Implement value type abstraction for types.

This is done by changing Type into a POD value type that wraps a pointer to underlying, uniqued storage, and by adding in-class support for isa/dyn_cast/cast.

PiperOrigin-RevId: 219372163
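
For illustration, a minimal, self-contained sketch of the value-type pattern this change adopts. The names below (sketch::Type, kindof, IntegerTypeStorage::width) are simplified stand-ins, not the exact classes in mlir/IR/Types.h and lib/IR/TypeDetail.h:

    // Toy sketch of a pointer-sized value type with in-class casting support.
    #include <cassert>
    #include <cstdint>

    namespace sketch {

    // Uniqued, immutable storage owned by the context (cf. detail::TypeStorage).
    struct TypeStorage {
      enum class Kind : uint8_t { Integer, Float } kind;
    };

    struct IntegerTypeStorage : TypeStorage {
      unsigned width;
    };

    // POD value type: just a pointer, passed and compared by value.
    class Type {
    public:
      using ImplType = TypeStorage;
      Type() = default;
      /*implicit*/ Type(ImplType *ptr) : ptr(ptr) {}

      bool operator==(Type other) const { return ptr == other.ptr; }
      explicit operator bool() const { return ptr != nullptr; }
      TypeStorage::Kind getKind() const { return ptr->kind; }

      // In-class casting support, replacing the free isa/dyn_cast/cast.
      template <typename U> bool isa() const { return U::kindof(getKind()); }
      template <typename U> U dyn_cast() const {
        return isa<U>() ? U(ptr) : U(nullptr);
      }
      template <typename U> U cast() const {
        assert(isa<U>() && "invalid cast");
        return U(ptr);
      }

    protected:
      ImplType *ptr = nullptr;
    };

    class IntegerType : public Type {
    public:
      using Type::Type;
      static bool kindof(TypeStorage::Kind kind) {
        return kind == TypeStorage::Kind::Integer;
      }
      unsigned getWidth() const {
        return static_cast<IntegerTypeStorage *>(ptr)->width;
      }
    };

    } // namespace sketch

Because the wrapper is a single pointer, Type can be passed and stored by value (e.g. ArrayRef<Type>, DenseMap<Type, ...>) and member access uses '.' instead of '->', which is what most of the diff below mechanically updates.
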
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 1b3c24f..1904a63 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -118,15 +118,15 @@
   return tripCountExpr.getLargestKnownDivisor();
 }
 
-bool mlir::isAccessInvariant(const MLValue &input, MemRefType *memRefType,
+bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
                              ArrayRef<MLValue *> indices, unsigned dim) {
-  assert(indices.size() == memRefType->getRank());
+  assert(indices.size() == memRefType.getRank());
   assert(dim < indices.size());
-  auto layoutMap = memRefType->getAffineMaps();
-  assert(memRefType->getAffineMaps().size() <= 1);
+  auto layoutMap = memRefType.getAffineMaps();
+  assert(memRefType.getAffineMaps().size() <= 1);
   // TODO(ntv): remove dependency on Builder once we support non-identity
   // layout map.
-  Builder b(memRefType->getContext());
+  Builder b(memRefType.getContext());
   assert(layoutMap.empty() ||
          layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
   (void)layoutMap;
@@ -170,7 +170,7 @@
   using namespace functional;
   auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
                      memoryOp->getIndices());
-  auto *memRefType = memoryOp->getMemRefType();
+  auto memRefType = memoryOp->getMemRefType();
   for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
     if (fastestVaryingDim == (numIndices - 1) - d) {
       continue;
@@ -184,8 +184,8 @@
 
 template <typename LoadOrStoreOpPointer>
 static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
-  auto *memRefType = memoryOp->getMemRefType();
-  return isa<VectorType>(memRefType->getElementType());
+  auto memRefType = memoryOp->getMemRefType();
+  return memRefType.getElementType().template isa<VectorType>();
 }
 
 bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
diff --git a/lib/Analysis/Verifier.cpp b/lib/Analysis/Verifier.cpp
index bfbcb16..0dd030d 100644
--- a/lib/Analysis/Verifier.cpp
+++ b/lib/Analysis/Verifier.cpp
@@ -195,7 +195,7 @@
 
   // Verify that the argument list of the function and the arg list of the first
   // block line up.
-  auto fnInputTypes = fn.getType()->getInputs();
+  auto fnInputTypes = fn.getType().getInputs();
   if (fnInputTypes.size() != firstBB->getNumArguments())
     return failure("first block of cfgfunc must have " +
                        Twine(fnInputTypes.size()) +
@@ -306,7 +306,7 @@
 
 bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
   // Verify that the return operands match the results of the function.
-  auto results = fn.getType()->getResults();
+  auto results = fn.getType().getResults();
   if (inst.getNumOperands() != results.size())
     return failure("return has " + Twine(inst.getNumOperands()) +
                        " operands, but enclosing function returns " +
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 454a28a..cb5e96f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -122,7 +122,7 @@
   void visitForStmt(const ForStmt *forStmt);
   void visitIfStmt(const IfStmt *ifStmt);
   void visitOperationStmt(const OperationStmt *opStmt);
-  void visitType(const Type *type);
+  void visitType(Type type);
   void visitAttribute(Attribute attr);
   void visitOperation(const Operation *op);
 
@@ -135,16 +135,16 @@
 } // end anonymous namespace
 
 // TODO Support visiting other types/instructions when implemented.
-void ModuleState::visitType(const Type *type) {
-  if (auto *funcType = dyn_cast<FunctionType>(type)) {
+void ModuleState::visitType(Type type) {
+  if (auto funcType = type.dyn_cast<FunctionType>()) {
     // Visit input and result types for functions.
-    for (auto *input : funcType->getInputs())
+    for (auto input : funcType.getInputs())
       visitType(input);
-    for (auto *result : funcType->getResults())
+    for (auto result : funcType.getResults())
       visitType(result);
-  } else if (auto *memref = dyn_cast<MemRefType>(type)) {
+  } else if (auto memref = type.dyn_cast<MemRefType>()) {
     // Visit affine maps in memref type.
-    for (auto map : memref->getAffineMaps()) {
+    for (auto map : memref.getAffineMaps()) {
       recordAffineMapReference(map);
     }
   }
@@ -271,7 +271,7 @@
   void print(const Module *module);
   void printFunctionReference(const Function *func);
   void printAttribute(Attribute attr);
-  void printType(const Type *type);
+  void printType(Type type);
   void print(const Function *fn);
   void print(const ExtFunction *fn);
   void print(const CFGFunction *fn);
@@ -290,7 +290,7 @@
   void printFunctionAttributes(const Function *fn);
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                              ArrayRef<const char *> elidedAttrs = {});
-  void printFunctionResultType(const FunctionType *type);
+  void printFunctionResultType(FunctionType type);
   void printAffineMapId(int affineMapId) const;
   void printAffineMapReference(AffineMap affineMap);
   void printIntegerSetId(int integerSetId) const;
@@ -489,9 +489,9 @@
 }
 
 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
-  auto *type = attr.getType();
-  auto shape = type->getShape();
-  auto rank = type->getRank();
+  auto type = attr.getType();
+  auto shape = type.getShape();
+  auto rank = type.getRank();
 
   SmallVector<Attribute, 16> elements;
   attr.getValues(elements);
@@ -541,8 +541,8 @@
     os << ']';
 }
 
-void ModulePrinter::printType(const Type *type) {
-  switch (type->getKind()) {
+void ModulePrinter::printType(Type type) {
+  switch (type.getKind()) {
   case Type::Kind::Index:
     os << "index";
     return;
@@ -581,71 +581,71 @@
     return;
 
   case Type::Kind::Integer: {
-    auto *integer = cast<IntegerType>(type);
-    os << 'i' << integer->getWidth();
+    auto integer = type.cast<IntegerType>();
+    os << 'i' << integer.getWidth();
     return;
   }
   case Type::Kind::Function: {
-    auto *func = cast<FunctionType>(type);
+    auto func = type.cast<FunctionType>();
     os << '(';
-    interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
+    interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
     os << ") -> ";
-    auto results = func->getResults();
+    auto results = func.getResults();
     if (results.size() == 1)
-      os << *results[0];
+      os << results[0];
     else {
       os << '(';
-      interleaveComma(results, [&](Type *type) { printType(type); });
+      interleaveComma(results, [&](Type type) { printType(type); });
       os << ')';
     }
     return;
   }
   case Type::Kind::Vector: {
-    auto *v = cast<VectorType>(type);
+    auto v = type.cast<VectorType>();
     os << "vector<";
-    for (auto dim : v->getShape())
+    for (auto dim : v.getShape())
       os << dim << 'x';
-    os << *v->getElementType() << '>';
+    os << v.getElementType() << '>';
     return;
   }
   case Type::Kind::RankedTensor: {
-    auto *v = cast<RankedTensorType>(type);
+    auto v = type.cast<RankedTensorType>();
     os << "tensor<";
-    for (auto dim : v->getShape()) {
+    for (auto dim : v.getShape()) {
       if (dim < 0)
         os << '?';
       else
         os << dim;
       os << 'x';
     }
-    os << *v->getElementType() << '>';
+    os << v.getElementType() << '>';
     return;
   }
   case Type::Kind::UnrankedTensor: {
-    auto *v = cast<UnrankedTensorType>(type);
+    auto v = type.cast<UnrankedTensorType>();
     os << "tensor<*x";
-    printType(v->getElementType());
+    printType(v.getElementType());
     os << '>';
     return;
   }
   case Type::Kind::MemRef: {
-    auto *v = cast<MemRefType>(type);
+    auto v = type.cast<MemRefType>();
     os << "memref<";
-    for (auto dim : v->getShape()) {
+    for (auto dim : v.getShape()) {
       if (dim < 0)
         os << '?';
       else
         os << dim;
       os << 'x';
     }
-    printType(v->getElementType());
-    for (auto map : v->getAffineMaps()) {
+    printType(v.getElementType());
+    for (auto map : v.getAffineMaps()) {
       os << ", ";
       printAffineMapReference(map);
     }
     // Only print the memory space if it is the non-default one.
-    if (v->getMemorySpace())
-      os << ", " << v->getMemorySpace();
+    if (v.getMemorySpace())
+      os << ", " << v.getMemorySpace();
     os << '>';
     return;
   }
@@ -842,18 +842,18 @@
 // Function printing
 //===----------------------------------------------------------------------===//
 
-void ModulePrinter::printFunctionResultType(const FunctionType *type) {
-  switch (type->getResults().size()) {
+void ModulePrinter::printFunctionResultType(FunctionType type) {
+  switch (type.getResults().size()) {
   case 0:
     break;
   case 1:
     os << " -> ";
-    printType(type->getResults()[0]);
+    printType(type.getResults()[0]);
     break;
   default:
     os << " -> (";
-    interleaveComma(type->getResults(),
-                    [&](Type *eltType) { printType(eltType); });
+    interleaveComma(type.getResults(),
+                    [&](Type eltType) { printType(eltType); });
     os << ')';
     break;
   }
@@ -871,8 +871,7 @@
   auto type = fn->getType();
 
   os << "@" << fn->getName() << '(';
-  interleaveComma(type->getInputs(),
-                  [&](Type *eltType) { printType(eltType); });
+  interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
   os << ')';
 
   printFunctionResultType(type);
@@ -937,7 +936,7 @@
 
   // Implement OpAsmPrinter.
   raw_ostream &getStream() const { return os; }
-  void printType(const Type *type) { ModulePrinter::printType(type); }
+  void printType(Type type) { ModulePrinter::printType(type); }
   void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
   void printAffineMap(AffineMap map) {
     return ModulePrinter::printAffineMapReference(map);
@@ -974,10 +973,10 @@
     if (auto *op = value->getDefiningOperation()) {
       if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
         // i1 constants get special names.
-        if (intOp->getType()->isInteger(1)) {
+        if (intOp->getType().isInteger(1)) {
           specialName << (intOp->getValue() ? "true" : "false");
         } else {
-          specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
+          specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
         }
       } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
         specialName << 'c' << intOp->getValue();
@@ -1579,7 +1578,7 @@
 
 void Type::print(raw_ostream &os) const {
   ModuleState state(getContext());
-  ModulePrinter(os, state).printType(this);
+  ModulePrinter(os, state).printType(*this);
 }
 
 void Type::dump() const { print(llvm::errs()); }
diff --git a/lib/IR/AttributeDetail.h b/lib/IR/AttributeDetail.h
index a0e9afb..63ad544 100644
--- a/lib/IR/AttributeDetail.h
+++ b/lib/IR/AttributeDetail.h
@@ -26,6 +26,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
 #include "llvm/Support/TrailingObjects.h"
 
 namespace mlir {
@@ -86,7 +87,7 @@
 
 /// An attribute representing a reference to a type.
 struct TypeAttributeStorage : public AttributeStorage {
-  Type *value;
+  Type value;
 };
 
 /// An attribute representing a reference to a function.
@@ -96,7 +97,7 @@
 
 /// A base attribute representing a reference to a vector or tensor constant.
 struct ElementsAttributeStorage : public AttributeStorage {
-  VectorOrTensorType *type;
+  VectorOrTensorType type;
 };
 
 /// An attribute representing a reference to a vector or tensor constant,
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
index 34312b8..58b5b90 100644
--- a/lib/IR/Attributes.cpp
+++ b/lib/IR/Attributes.cpp
@@ -75,9 +75,7 @@
 
 TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
 
-Type *TypeAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->value;
-}
+Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
 
 FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
 
@@ -85,11 +83,11 @@
   return static_cast<ImplType *>(attr)->value;
 }
 
-FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
+FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
 
 ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
 
-VectorOrTensorType *ElementsAttr::getType() const {
+VectorOrTensorType ElementsAttr::getType() const {
   return static_cast<ImplType *>(attr)->type;
 }
 
@@ -166,8 +164,8 @@
 
 void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
   auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
-  auto elementNum = getType()->getNumElements();
-  auto context = getType()->getContext();
+  auto elementNum = getType().getNumElements();
+  auto context = getType().getContext();
   values.reserve(elementNum);
   if (bitsWidth == 64) {
     ArrayRef<int64_t> vs(
@@ -192,8 +190,8 @@
     : DenseElementsAttr(ptr) {}
 
 void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
-  auto elementNum = getType()->getNumElements();
-  auto context = getType()->getContext();
+  auto elementNum = getType().getNumElements();
+  auto context = getType().getContext();
   ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
                        getRawData().size() / 8});
   values.reserve(elementNum);
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index bb8ac75..29a5ce1 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -33,18 +33,18 @@
 // Argument list management.
 //===----------------------------------------------------------------------===//
 
-BBArgument *BasicBlock::addArgument(Type *type) {
+BBArgument *BasicBlock::addArgument(Type type) {
   auto *arg = new BBArgument(type, this);
   arguments.push_back(arg);
   return arg;
 }
 
 /// Add one argument to the argument list for each type specified in the list.
-auto BasicBlock::addArguments(ArrayRef<Type *> types)
+auto BasicBlock::addArguments(ArrayRef<Type> types)
     -> llvm::iterator_range<args_iterator> {
   arguments.reserve(arguments.size() + types.size());
   auto initialSize = arguments.size();
-  for (auto *type : types) {
+  for (auto type : types) {
     addArgument(type);
   }
   return {arguments.data() + initialSize, arguments.data() + arguments.size()};
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 22d749a..906b580 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -52,59 +52,58 @@
 // Types.
 //===----------------------------------------------------------------------===//
 
-FloatType *Builder::getBF16Type() { return Type::getBF16(context); }
+FloatType Builder::getBF16Type() { return Type::getBF16(context); }
 
-FloatType *Builder::getF16Type() { return Type::getF16(context); }
+FloatType Builder::getF16Type() { return Type::getF16(context); }
 
-FloatType *Builder::getF32Type() { return Type::getF32(context); }
+FloatType Builder::getF32Type() { return Type::getF32(context); }
 
-FloatType *Builder::getF64Type() { return Type::getF64(context); }
+FloatType Builder::getF64Type() { return Type::getF64(context); }
 
-OtherType *Builder::getIndexType() { return Type::getIndex(context); }
+OtherType Builder::getIndexType() { return Type::getIndex(context); }
 
-OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
+OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
 
-OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
+OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
 
-OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
+OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
 
-OtherType *Builder::getTFComplex64Type() {
+OtherType Builder::getTFComplex64Type() {
   return Type::getTFComplex64(context);
 }
 
-OtherType *Builder::getTFComplex128Type() {
+OtherType Builder::getTFComplex128Type() {
   return Type::getTFComplex128(context);
 }
 
-OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
+OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
 
-OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
+OtherType Builder::getTFStringType() { return Type::getTFString(context); }
 
-IntegerType *Builder::getIntegerType(unsigned width) {
+IntegerType Builder::getIntegerType(unsigned width) {
   return Type::getInteger(width, context);
 }
 
-FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
-                                       ArrayRef<Type *> results) {
+FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
+                                      ArrayRef<Type> results) {
   return FunctionType::get(inputs, results, context);
 }
 
-MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
-                                   ArrayRef<AffineMap> affineMapComposition,
-                                   unsigned memorySpace) {
+MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType,
+                                  ArrayRef<AffineMap> affineMapComposition,
+                                  unsigned memorySpace) {
   return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
 }
 
-VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
+VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) {
   return VectorType::get(shape, elementType);
 }
 
-RankedTensorType *Builder::getTensorType(ArrayRef<int> shape,
-                                         Type *elementType) {
+RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) {
   return RankedTensorType::get(shape, elementType);
 }
 
-UnrankedTensorType *Builder::getTensorType(Type *elementType) {
+UnrankedTensorType Builder::getTensorType(Type elementType) {
   return UnrankedTensorType::get(elementType);
 }
 
@@ -144,7 +143,7 @@
   return IntegerSetAttr::get(set);
 }
 
-TypeAttr Builder::getTypeAttr(Type *type) {
+TypeAttr Builder::getTypeAttr(Type type) {
   return TypeAttr::get(type, context);
 }
 
@@ -152,23 +151,23 @@
   return FunctionAttr::get(value, context);
 }
 
-ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
                                            Attribute elt) {
   return SplatElementsAttr::get(type, elt);
 }
 
-ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
                                            ArrayRef<char> data) {
   return DenseElementsAttr::get(type, data);
 }
 
-ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
                                             DenseIntElementsAttr indices,
                                             DenseElementsAttr values) {
   return SparseElementsAttr::get(type, indices, values);
 }
 
-ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type,
                                             StringRef bytes) {
   return OpaqueElementsAttr::get(type, bytes);
 }
@@ -296,7 +295,7 @@
 OperationStmt *MLFuncBuilder::createOperation(Location *location,
                                               OperationName name,
                                               ArrayRef<MLValue *> operands,
-                                              ArrayRef<Type *> types,
+                                              ArrayRef<Type> types,
                                               ArrayRef<NamedAttribute> attrs) {
   auto *op = OperationStmt::create(location, name, operands, types, attrs,
                                    getContext());
diff --git a/lib/IR/BuiltinOps.cpp b/lib/IR/BuiltinOps.cpp
index 542e67e..e4bca03 100644
--- a/lib/IR/BuiltinOps.cpp
+++ b/lib/IR/BuiltinOps.cpp
@@ -63,7 +63,7 @@
   numDims = opInfos.size();
 
   // Parse the optional symbol operands.
-  auto *affineIntTy = parser->getBuilder().getIndexType();
+  auto affineIntTy = parser->getBuilder().getIndexType();
   if (parser->parseOperandList(opInfos, -1,
                                OpAsmParser::Delimiter::OptionalSquare) ||
       parser->resolveOperands(opInfos, affineIntTy, operands))
@@ -84,7 +84,7 @@
 
 bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
   auto &builder = parser->getBuilder();
-  auto *affineIntTy = builder.getIndexType();
+  auto affineIntTy = builder.getIndexType();
 
   AffineMapAttr mapAttr;
   unsigned numDims;
@@ -171,7 +171,7 @@
 
 /// Builds a constant op with the specified attribute value and result type.
 void ConstantOp::build(Builder *builder, OperationState *result,
-                       Attribute value, Type *type) {
+                       Attribute value, Type type) {
   result->addAttribute("value", value);
   result->types.push_back(type);
 }
@@ -181,12 +181,12 @@
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
 
   if (!getValue().isa<FunctionAttr>())
-    *p << " : " << *getType();
+    *p << " : " << getType();
 }
 
 bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
   Attribute valueAttr;
-  Type *type;
+  Type type;
 
   if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
       parser->parseOptionalAttributeDict(result->attributes))
@@ -208,33 +208,33 @@
   if (!value)
     return emitOpError("requires a 'value' attribute");
 
-  auto *type = this->getType();
-  if (isa<IntegerType>(type) || type->isIndex()) {
+  auto type = this->getType();
+  if (type.isa<IntegerType>() || type.isIndex()) {
     if (!value.isa<IntegerAttr>())
       return emitOpError(
           "requires 'value' to be an integer for an integer result type");
     return false;
   }
 
-  if (isa<FloatType>(type)) {
+  if (type.isa<FloatType>()) {
     if (!value.isa<FloatAttr>())
       return emitOpError("requires 'value' to be a floating point constant");
     return false;
   }
 
-  if (isa<VectorOrTensorType>(type)) {
+  if (type.isa<VectorOrTensorType>()) {
     if (!value.isa<ElementsAttr>())
       return emitOpError("requires 'value' to be a vector/tensor constant");
     return false;
   }
 
-  if (type->isTFString()) {
+  if (type.isTFString()) {
     if (!value.isa<StringAttr>())
       return emitOpError("requires 'value' to be a string constant");
     return false;
   }
 
-  if (isa<FunctionType>(type)) {
+  if (type.isa<FunctionType>()) {
     if (!value.isa<FunctionAttr>())
       return emitOpError("requires 'value' to be a function reference");
     return false;
@@ -251,19 +251,19 @@
 }
 
 void ConstantFloatOp::build(Builder *builder, OperationState *result,
-                            const APFloat &value, FloatType *type) {
+                            const APFloat &value, FloatType type) {
   ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
 }
 
 bool ConstantFloatOp::isClassFor(const Operation *op) {
   return ConstantOp::isClassFor(op) &&
-         isa<FloatType>(op->getResult(0)->getType());
+         op->getResult(0)->getType().isa<FloatType>();
 }
 
 /// ConstantIntOp only matches values whose result type is an IntegerType.
 bool ConstantIntOp::isClassFor(const Operation *op) {
   return ConstantOp::isClassFor(op) &&
-         isa<IntegerType>(op->getResult(0)->getType());
+         op->getResult(0)->getType().isa<IntegerType>();
 }
 
 void ConstantIntOp::build(Builder *builder, OperationState *result,
@@ -275,14 +275,14 @@
 /// Build a constant int op producing an integer with the specified type,
 /// which must be an integer type.
 void ConstantIntOp::build(Builder *builder, OperationState *result,
-                          int64_t value, Type *type) {
-  assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type");
+                          int64_t value, Type type) {
+  assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
   ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
 }
 
 /// ConstantIndexOp only matches values whose result type is Index.
 bool ConstantIndexOp::isClassFor(const Operation *op) {
-  return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
+  return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
 }
 
 void ConstantIndexOp::build(Builder *builder, OperationState *result,
@@ -302,7 +302,7 @@
 
 bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 2> opInfo;
-  SmallVector<Type *, 2> types;
+  SmallVector<Type, 2> types;
   llvm::SMLoc loc;
   return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
          (!opInfo.empty() && parser->parseColonTypeList(types)) ||
@@ -330,7 +330,7 @@
 
     // The operand number and types must match the function signature.
     MLFunction *function = cast<MLFunction>(block);
-    const auto &results = function->getType()->getResults();
+    const auto &results = function->getType().getResults();
     if (stmt->getNumOperands() != results.size())
       return emitOpError("has " + Twine(stmt->getNumOperands()) +
                          " operands, but enclosing function returns " +
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index efeb16b..70c0e12 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -28,8 +28,8 @@
 using namespace mlir;
 
 Function::Function(Kind kind, Location *location, StringRef name,
-                   FunctionType *type, ArrayRef<NamedAttribute> attrs)
-    : nameAndKind(Identifier::get(name, type->getContext()), kind),
+                   FunctionType type, ArrayRef<NamedAttribute> attrs)
+    : nameAndKind(Identifier::get(name, type.getContext()), kind),
       location(location), type(type) {
   this->attrs = AttributeListStorage::get(attrs, getContext());
 }
@@ -46,7 +46,7 @@
     return {};
 }
 
-MLIRContext *Function::getContext() const { return getType()->getContext(); }
+MLIRContext *Function::getContext() const { return getType().getContext(); }
 
 /// Delete this object.
 void Function::destroy() {
@@ -159,7 +159,7 @@
 // ExtFunction implementation.
 //===----------------------------------------------------------------------===//
 
-ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
+ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type,
                          ArrayRef<NamedAttribute> attrs)
     : Function(Kind::ExtFunc, location, name, type, attrs) {}
 
@@ -167,7 +167,7 @@
 // CFGFunction implementation.
 //===----------------------------------------------------------------------===//
 
-CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type,
+CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type,
                          ArrayRef<NamedAttribute> attrs)
     : Function(Kind::CFGFunc, location, name, type, attrs) {}
 
@@ -188,9 +188,9 @@
 
 /// Create a new MLFunction with the specific fields.
 MLFunction *MLFunction::create(Location *location, StringRef name,
-                               FunctionType *type,
+                               FunctionType type,
                                ArrayRef<NamedAttribute> attrs) {
-  const auto &argTypes = type->getInputs();
+  const auto &argTypes = type.getInputs();
   auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
   void *rawMem = malloc(byteSize);
 
@@ -204,7 +204,7 @@
   return function;
 }
 
-MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type,
+MLFunction::MLFunction(Location *location, StringRef name, FunctionType type,
                        ArrayRef<NamedAttribute> attrs)
     : Function(Kind::MLFunc, location, name, type, attrs),
       StmtBlock(StmtBlockKind::MLFunc) {}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 422636b..d2f49dd 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -143,7 +143,7 @@
 /// Create a new OperationInst with the specified fields.
 OperationInst *OperationInst::create(Location *location, OperationName name,
                                      ArrayRef<CFGValue *> operands,
-                                     ArrayRef<Type *> resultTypes,
+                                     ArrayRef<Type> resultTypes,
                                      ArrayRef<NamedAttribute> attributes,
                                      MLIRContext *context) {
   auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(),
@@ -167,7 +167,7 @@
 
 OperationInst *OperationInst::clone() const {
   SmallVector<CFGValue *, 8> operands;
-  SmallVector<Type *, 8> resultTypes;
+  SmallVector<Type, 8> resultTypes;
 
   // Put together the operands and results.
   for (auto *operand : getOperands())
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 0a2e941..8811f7b 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -21,6 +21,7 @@
 #include "AttributeDetail.h"
 #include "AttributeListStorage.h"
 #include "IntegerSetDetail.h"
+#include "TypeDetail.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
@@ -44,11 +45,11 @@
 using namespace llvm;
 
 namespace {
-struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
+struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> {
   // Functions are uniqued based on their inputs and results.
-  using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>;
-  using DenseMapInfo<FunctionType *>::getHashValue;
-  using DenseMapInfo<FunctionType *>::isEqual;
+  using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
+  using DenseMapInfo<FunctionTypeStorage *>::getHashValue;
+  using DenseMapInfo<FunctionTypeStorage *>::isEqual;
 
   static unsigned getHashValue(KeyTy key) {
     return hash_combine(
@@ -56,7 +57,7 @@
         hash_combine_range(key.second.begin(), key.second.end()));
   }
 
-  static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
+  static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
     return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
@@ -109,65 +110,64 @@
   }
 };
 
-struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
+struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> {
   // Vectors are uniqued based on their element type and shape.
-  using KeyTy = std::pair<Type *, ArrayRef<int>>;
-  using DenseMapInfo<VectorType *>::getHashValue;
-  using DenseMapInfo<VectorType *>::isEqual;
+  using KeyTy = std::pair<Type, ArrayRef<int>>;
+  using DenseMapInfo<VectorTypeStorage *>::getHashValue;
+  using DenseMapInfo<VectorTypeStorage *>::isEqual;
 
   static unsigned getHashValue(KeyTy key) {
     return hash_combine(
-        DenseMapInfo<Type *>::getHashValue(key.first),
+        DenseMapInfo<Type>::getHashValue(key.first),
         hash_combine_range(key.second.begin(), key.second.end()));
   }
 
-  static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
+  static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
-    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+    return lhs == KeyTy(rhs->elementType, rhs->getShape());
   }
 };
 
-struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> {
+struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> {
   // Ranked tensors are uniqued based on their element type and shape.
-  using KeyTy = std::pair<Type *, ArrayRef<int>>;
-  using DenseMapInfo<RankedTensorType *>::getHashValue;
-  using DenseMapInfo<RankedTensorType *>::isEqual;
+  using KeyTy = std::pair<Type, ArrayRef<int>>;
+  using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue;
+  using DenseMapInfo<RankedTensorTypeStorage *>::isEqual;
 
   static unsigned getHashValue(KeyTy key) {
     return hash_combine(
-        DenseMapInfo<Type *>::getHashValue(key.first),
+        DenseMapInfo<Type>::getHashValue(key.first),
         hash_combine_range(key.second.begin(), key.second.end()));
   }
 
-  static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
+  static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
-    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+    return lhs == KeyTy(rhs->elementType, rhs->getShape());
   }
 };
 
-struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
+struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> {
   // MemRefs are uniqued based on their element type, shape, affine map
   // composition, and memory space.
-  using KeyTy =
-      std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
-  using DenseMapInfo<MemRefType *>::getHashValue;
-  using DenseMapInfo<MemRefType *>::isEqual;
+  using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
+  using DenseMapInfo<MemRefTypeStorage *>::getHashValue;
+  using DenseMapInfo<MemRefTypeStorage *>::isEqual;
 
   static unsigned getHashValue(KeyTy key) {
     return hash_combine(
-        DenseMapInfo<Type *>::getHashValue(std::get<0>(key)),
+        DenseMapInfo<Type>::getHashValue(std::get<0>(key)),
         hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()),
         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
         std::get<3>(key));
   }
 
-  static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) {
+  static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
-    return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(),
-                                  rhs->getAffineMaps(), rhs->getMemorySpace());
+    return lhs == std::make_tuple(rhs->elementType, rhs->getShape(),
+                                  rhs->getAffineMaps(), rhs->memorySpace);
   }
 };
 
@@ -221,7 +221,7 @@
 };
 
 struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
-  using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
+  using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
   using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
   using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
 
@@ -239,7 +239,7 @@
 };
 
 struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
-  using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
+  using KeyTy = std::pair<VectorOrTensorType, StringRef>;
   using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
   using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
 
@@ -295,13 +295,14 @@
   llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
 
   // Uniquing table for 'other' types.
-  OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
-                        int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr};
+  OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
+                               int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {
+      nullptr};
 
   // Uniquing table for 'float' types.
-  FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
-                        int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = {
-      nullptr};
+  FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
+                               int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] =
+      {nullptr};
 
   // Affine map uniquing.
   using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
@@ -324,26 +325,26 @@
   DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
 
   /// Integer type uniquing.
-  DenseMap<unsigned, IntegerType *> integers;
+  DenseMap<unsigned, IntegerTypeStorage *> integers;
 
   /// Function type uniquing.
-  using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>;
+  using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>;
   FunctionTypeSet functions;
 
   /// Vector type uniquing.
-  using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>;
+  using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>;
   VectorTypeSet vectors;
 
   /// Ranked tensor type uniquing.
   using RankedTensorTypeSet =
-      DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>;
+      DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>;
   RankedTensorTypeSet rankedTensors;
 
   /// Unranked tensor type uniquing.
-  DenseMap<Type *, UnrankedTensorType *> unrankedTensors;
+  DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors;
 
   /// MemRef type uniquing.
-  using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>;
+  using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>;
   MemRefTypeSet memrefs;
 
   // Attribute uniquing.
@@ -355,13 +356,12 @@
   ArrayAttrSet arrayAttrs;
   DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
   DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
-  DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
+  DenseMap<Type, TypeAttributeStorage *> typeAttrs;
   using AttributeListSet =
       DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
   AttributeListSet attributeLists;
   DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
-  DenseMap<std::pair<VectorOrTensorType *, Attribute>,
-           SplatElementsAttributeStorage *>
+  DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
       splatElementsAttrs;
   using DenseElementsAttrSet =
       DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
@@ -369,7 +369,7 @@
   using OpaqueElementsAttrSet =
       DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
   OpaqueElementsAttrSet opaqueElementsAttrs;
-  DenseMap<std::tuple<Type *, Attribute, Attribute>,
+  DenseMap<std::tuple<Type, Attribute, Attribute>,
            SparseElementsAttributeStorage *>
       sparseElementsAttrs;
 
@@ -556,19 +556,20 @@
 // Type uniquing
 //===----------------------------------------------------------------------===//
 
-IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
+IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
+  assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
   auto &impl = context->getImpl();
 
   auto *&result = impl.integers[width];
   if (!result) {
-    result = impl.allocator.Allocate<IntegerType>();
-    new (result) IntegerType(width, context);
+    result = impl.allocator.Allocate<IntegerTypeStorage>();
+    new (result) IntegerTypeStorage{{Kind::Integer, context}, width};
   }
 
   return result;
 }
 
-FloatType *FloatType::get(Kind kind, MLIRContext *context) {
+FloatType FloatType::get(Kind kind, MLIRContext *context) {
   assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
          kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
   auto &impl = context->getImpl();
@@ -580,16 +581,16 @@
     return entry;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *ptr = impl.allocator.Allocate<FloatType>();
+  auto *ptr = impl.allocator.Allocate<FloatTypeStorage>();
 
   // Initialize the memory using placement new.
-  new (ptr) FloatType(kind, context);
+  new (ptr) FloatTypeStorage{{kind, context}};
 
   // Cache and return it.
   return entry = ptr;
 }
 
-OtherType *OtherType::get(Kind kind, MLIRContext *context) {
+OtherType OtherType::get(Kind kind, MLIRContext *context) {
   assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
          "Not an 'other' type kind");
   auto &impl = context->getImpl();
@@ -600,18 +601,17 @@
     return entry;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *ptr = impl.allocator.Allocate<OtherType>();
+  auto *ptr = impl.allocator.Allocate<OtherTypeStorage>();
 
   // Initialize the memory using placement new.
-  new (ptr) OtherType(kind, context);
+  new (ptr) OtherTypeStorage{{kind, context}};
 
   // Cache and return it.
   return entry = ptr;
 }
 
-FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
-                                ArrayRef<Type *> results,
-                                MLIRContext *context) {
+FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
+                               MLIRContext *context) {
   auto &impl = context->getImpl();
 
   // Look to see if we already have this function type.
@@ -623,32 +623,34 @@
     return *existing.first;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *result = impl.allocator.Allocate<FunctionType>();
+  auto *result = impl.allocator.Allocate<FunctionTypeStorage>();
 
   // Copy the inputs and results into the bump pointer.
-  SmallVector<Type *, 16> types;
+  SmallVector<Type, 16> types;
   types.reserve(inputs.size() + results.size());
   types.append(inputs.begin(), inputs.end());
   types.append(results.begin(), results.end());
-  auto typesList = impl.copyInto(ArrayRef<Type *>(types));
+  auto typesList = impl.copyInto(ArrayRef<Type>(types));
 
   // Initialize the memory using placement new.
-  new (result)
-      FunctionType(typesList.data(), inputs.size(), results.size(), context);
+  new (result) FunctionTypeStorage{
+      {Kind::Function, context, static_cast<unsigned int>(inputs.size())},
+      static_cast<unsigned int>(results.size()),
+      typesList.data()};
 
   // Cache and return it.
   return *existing.first = result;
 }
 
-VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
+VectorType VectorType::get(ArrayRef<int> shape, Type elementType) {
   assert(!shape.empty() && "vector types must have at least one dimension");
-  assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
+  assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) &&
          "vectors elements must be primitives");
   assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
     return i < 0;
   }) && "vector types must have static shape");
 
-  auto *context = elementType->getContext();
+  auto *context = elementType.getContext();
   auto &impl = context->getImpl();
 
   // Look to see if we already have this vector type.
@@ -660,21 +662,23 @@
     return *existing.first;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *result = impl.allocator.Allocate<VectorType>();
+  auto *result = impl.allocator.Allocate<VectorTypeStorage>();
 
   // Copy the shape into the bump pointer.
   shape = impl.copyInto(shape);
 
   // Initialize the memory using placement new.
-  new (result) VectorType(shape, elementType, context);
+  new (result) VectorTypeStorage{
+      {{Kind::Vector, context, static_cast<unsigned int>(shape.size())},
+       elementType},
+      shape.data()};
 
   // Cache and return it.
   return *existing.first = result;
 }
 
-RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
-                                        Type *elementType) {
-  auto *context = elementType->getContext();
+RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
+  auto *context = elementType.getContext();
   auto &impl = context->getImpl();
 
   // Look to see if we already have this ranked tensor type.
@@ -686,20 +690,23 @@
     return *existing.first;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *result = impl.allocator.Allocate<RankedTensorType>();
+  auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>();
 
   // Copy the shape into the bump pointer.
   shape = impl.copyInto(shape);
 
   // Initialize the memory using placement new.
-  new (result) RankedTensorType(shape, elementType, context);
+  new (result) RankedTensorTypeStorage{
+      {{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())},
+        elementType}},
+      shape.data()};
 
   // Cache and return it.
   return *existing.first = result;
 }
 
-UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
-  auto *context = elementType->getContext();
+UnrankedTensorType UnrankedTensorType::get(Type elementType) {
+  auto *context = elementType.getContext();
   auto &impl = context->getImpl();
 
   // Look to see if we already have this unranked tensor type.
@@ -710,17 +717,18 @@
     return result;
 
   // On the first use, we allocate them into the bump pointer.
-  result = impl.allocator.Allocate<UnrankedTensorType>();
+  result = impl.allocator.Allocate<UnrankedTensorTypeStorage>();
 
   // Initialize the memory using placement new.
-  new (result) UnrankedTensorType(elementType, context);
+  new (result) UnrankedTensorTypeStorage{
+      {{{Kind::UnrankedTensor, context}, elementType}}};
   return result;
 }
 
-MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
-                            ArrayRef<AffineMap> affineMapComposition,
-                            unsigned memorySpace) {
-  auto *context = elementType->getContext();
+MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType,
+                           ArrayRef<AffineMap> affineMapComposition,
+                           unsigned memorySpace) {
+  auto *context = elementType.getContext();
   auto &impl = context->getImpl();
 
   // Drop the unbounded identity maps from the composition.
@@ -744,7 +752,7 @@
     return *existing.first;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *result = impl.allocator.Allocate<MemRefType>();
+  auto *result = impl.allocator.Allocate<MemRefTypeStorage>();
 
   // Copy the shape into the bump pointer.
   shape = impl.copyInto(shape);
@@ -755,8 +763,13 @@
       impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
 
   // Initialize the memory using placement new.
-  new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
-                          context);
+  new (result) MemRefTypeStorage{
+      {Kind::MemRef, context, static_cast<unsigned int>(shape.size())},
+      elementType,
+      shape.data(),
+      static_cast<unsigned int>(affineMapComposition.size()),
+      affineMapComposition.data(),
+      memorySpace};
   // Cache and return it.
   return *existing.first = result;
 }
@@ -895,7 +908,7 @@
   return result;
 }
 
-TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
+TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
   auto *&result = context->getImpl().typeAttrs[type];
   if (result)
     return result;
@@ -1009,9 +1022,9 @@
   return *existing.first = result;
 }
 
-SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
+SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
                                          Attribute elt) {
-  auto &impl = type->getContext()->getImpl();
+  auto &impl = type.getContext()->getImpl();
 
   // Look to see if we already have this.
   auto *&result = impl.splatElementsAttrs[{type, elt}];
@@ -1030,14 +1043,14 @@
   return result;
 }
 
-DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
+DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
                                          ArrayRef<char> data) {
-  auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
+  auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
   (void)bitsRequired;
   assert((bitsRequired <= data.size() * 8L) &&
          "Input data bit size should be larger than that type requires");
 
-  auto &impl = type->getContext()->getImpl();
+  auto &impl = type.getContext()->getImpl();
 
   // Look to see if this constant is already defined.
   DenseElementsAttrInfo::KeyTy key({type, data});
@@ -1048,8 +1061,8 @@
     return *existing.first;
 
   // Otherwise, allocate a new one, unique it and return it.
-  auto *eltType = type->getElementType();
-  switch (eltType->getKind()) {
+  auto eltType = type.getElementType();
+  switch (eltType.getKind()) {
   case Type::Kind::BF16:
   case Type::Kind::F16:
   case Type::Kind::F32:
@@ -1064,7 +1077,7 @@
     return *existing.first = result;
   }
   case Type::Kind::Integer: {
-    auto width = ::cast<IntegerType>(eltType)->getWidth();
+    auto width = eltType.cast<IntegerType>().getWidth();
     auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
     auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
     std::uninitialized_copy(data.begin(), data.end(), copy);
@@ -1080,12 +1093,12 @@
   }
 }
 
-OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
+OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
                                            StringRef bytes) {
-  assert(isValidTensorElementType(type->getElementType()) &&
+  assert(isValidTensorElementType(type.getElementType()) &&
          "Input element type should be a valid tensor element type");
 
-  auto &impl = type->getContext()->getImpl();
+  auto &impl = type.getContext()->getImpl();
 
   // Look to see if this constant is already defined.
   OpaqueElementsAttrInfo::KeyTy key({type, bytes});
@@ -1104,10 +1117,10 @@
   return *existing.first = result;
 }
 
-SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
+SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
                                            DenseIntElementsAttr indices,
                                            DenseElementsAttr values) {
-  auto &impl = type->getContext()->getImpl();
+  auto &impl = type.getContext()->getImpl();
 
   // Look to see if we already have this.
   auto key = std::make_tuple(type, indices, values);
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index 2ed09b8..0722421 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -377,7 +377,7 @@
 }
 
 bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
-  auto *type = op->getResult(0)->getType();
+  auto type = op->getResult(0)->getType();
   for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
     if (op->getResult(i)->getType() != type)
       return op->emitOpError(
@@ -393,19 +393,19 @@
 
 /// If this is a vector type, or a tensor type, return the scalar element type
 /// that it is built around, otherwise return the type unmodified.
-static Type *getTensorOrVectorElementType(Type *type) {
-  if (auto *vec = dyn_cast<VectorType>(type))
-    return vec->getElementType();
+static Type getTensorOrVectorElementType(Type type) {
+  if (auto vec = type.dyn_cast<VectorType>())
+    return vec.getElementType();
 
   // Look through tensor<vector<...>> to find the underlying element type.
-  if (auto *tensor = dyn_cast<TensorType>(type))
-    return getTensorOrVectorElementType(tensor->getElementType());
+  if (auto tensor = type.dyn_cast<TensorType>())
+    return getTensorOrVectorElementType(tensor.getElementType());
   return type;
 }
 
 bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
   for (auto *result : op->getResults()) {
-    if (!isa<FloatType>(getTensorOrVectorElementType(result->getType())))
+    if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
       return op->emitOpError("requires a floating point type");
   }
 
@@ -414,7 +414,7 @@
 
 bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
   for (auto *result : op->getResults()) {
-    if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType())))
+    if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>())
       return op->emitOpError("requires an integer type");
   }
   return false;
@@ -436,7 +436,7 @@
 
 bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 2> ops;
-  Type *type;
+  Type type;
   return parser->parseOperandList(ops, 2) ||
          parser->parseOptionalAttributeDict(result->attributes) ||
          parser->parseColonType(type) ||
@@ -448,7 +448,7 @@
   *p << op->getName() << ' ' << *op->getOperand(0) << ", "
      << *op->getOperand(1);
   p->printOptionalAttrDict(op->getAttrs());
-  *p << " : " << *op->getResult(0)->getType();
+  *p << " : " << op->getResult(0)->getType();
 }
 
 //===----------------------------------------------------------------------===//
@@ -456,14 +456,14 @@
 //===----------------------------------------------------------------------===//
 
 void impl::buildCastOp(Builder *builder, OperationState *result,
-                       SSAValue *source, Type *destType) {
+                       SSAValue *source, Type destType) {
   result->addOperands(source);
   result->addTypes(destType);
 }
 
 bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType srcInfo;
-  Type *srcType, *dstType;
+  Type srcType, dstType;
   return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
          parser->resolveOperand(srcInfo, srcType, result->operands) ||
          parser->parseKeywordType("to", dstType) ||
@@ -472,5 +472,5 @@
 
 void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
   *p << op->getName() << ' ' << *op->getOperand(0) << " : "
-     << *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType();
+     << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
 }
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index e9c46d6..698089a1 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -239,7 +239,7 @@
 /// Create a new OperationStmt with the specific fields.
 OperationStmt *OperationStmt::create(Location *location, OperationName name,
                                      ArrayRef<MLValue *> operands,
-                                     ArrayRef<Type *> resultTypes,
+                                     ArrayRef<Type> resultTypes,
                                      ArrayRef<NamedAttribute> attributes,
                                      MLIRContext *context) {
   auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
@@ -288,9 +288,9 @@
   // If we have a result or operand type, that is a constant time way to get
   // to the context.
   if (getNumResults())
-    return getResult(0)->getType()->getContext();
+    return getResult(0)->getType().getContext();
   if (getNumOperands())
-    return getOperand(0)->getType()->getContext();
+    return getOperand(0)->getType().getContext();
 
   // In the very odd case where we have no operands or results, fall back to
   // doing a find.
@@ -474,7 +474,7 @@
   if (operands.empty())
     return findFunction()->getContext();
 
-  return getOperand(0)->getType()->getContext();
+  return getOperand(0)->getType().getContext();
 }
 
 //===----------------------------------------------------------------------===//
@@ -501,7 +501,7 @@
     operands.push_back(remapOperand(opValue));
 
   if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
-    SmallVector<Type *, 8> resultTypes;
+    SmallVector<Type, 8> resultTypes;
     resultTypes.reserve(opStmt->getNumResults());
     for (auto *result : opStmt->getResults())
       resultTypes.push_back(result->getType());
diff --git a/lib/IR/TypeDetail.h b/lib/IR/TypeDetail.h
new file mode 100644
index 0000000..c22e87a
--- /dev/null
+++ b/lib/IR/TypeDetail.h
@@ -0,0 +1,126 @@
+//===- TypeDetail.h - MLIR Type storage details -----------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of Type.
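+// Each kind of type is backed by a TypeStorage-derived struct that is uniqued
+// in the MLIRContext; the Type classes in mlir/IR/Types.h are value types
+// holding a pointer to one of these storage structs. Illustrative usage only:
+//
+//   Type t = ...;
+//   if (auto memRefTy = t.dyn_cast<MemRefType>())
+//     (void)memRefTy.getElementType();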
+//
+//===----------------------------------------------------------------------===//
+#ifndef TYPEDETAIL_H_
+#define TYPEDETAIL_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+
+class AffineMap;
+class MLIRContext;
+
+namespace detail {
+
+/// Base storage class appearing in a Type.
+struct alignas(8) TypeStorage {
+  TypeStorage(Type::Kind kind, MLIRContext *context)
+      : context(context), kind(kind), subclassData(0) {}
+  TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData)
+      : context(context), kind(kind), subclassData(subclassData) {}
+
+  unsigned getSubclassData() const { return subclassData; }
+
+  void setSubclassData(unsigned val) {
+    subclassData = val;
+    // Ensure we don't have any accidental truncation.
+    assert(getSubclassData() == val && "Subclass data too large for field");
+  }
+
+  /// This refers to the MLIRContext in which this type was uniqued.
+  MLIRContext *const context;
+
+  /// Classification of the subclass, used for type checking.
+  Type::Kind kind : 8;
+
+  /// Space for subclasses to store data.
+  unsigned subclassData : 24;
+};
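+
+/// The storage structs below extend TypeStorage with the data specific to
+/// each kind of type. Several of them reuse 'subclassData' to record a count
+/// (e.g. the number of function inputs or the shape rank) instead of adding a
+/// dedicated field.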
+
+struct IntegerTypeStorage : public TypeStorage {
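+  /// The bitwidth of the integer (e.g. 32 for i32).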
+  unsigned width;
+};
+
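+/// Float types and the builtin "other" types carry no data beyond their kind.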
+struct FloatTypeStorage : public TypeStorage {};
+
+struct OtherTypeStorage : public TypeStorage {};
+
+struct FunctionTypeStorage : public TypeStorage {
+  ArrayRef<Type> getInputs() const {
+    return ArrayRef<Type>(inputsAndResults, subclassData);
+  }
+  ArrayRef<Type> getResults() const {
+    return ArrayRef<Type>(inputsAndResults + subclassData, numResults);
+  }
+
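+  /// The number of result types; the number of input types is kept in
+  /// 'subclassData'.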
+  unsigned numResults;
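+  /// Pointer to a list containing the input types followed by the result
+  /// types.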
+  Type const *inputsAndResults;
+};
+
+struct VectorOrTensorTypeStorage : public TypeStorage {
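+  /// The type of each element of the vector or tensor.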
+  Type elementType;
+};
+
+struct VectorTypeStorage : public VectorOrTensorTypeStorage {
+  ArrayRef<int> getShape() const {
+    return ArrayRef<int>(shapeElements, getSubclassData());
+  }
+
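+  /// Pointer to the shape dimensions; the rank is kept in 'subclassData'.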
+  const int *shapeElements;
+};
+
+struct TensorTypeStorage : public VectorOrTensorTypeStorage {};
+
+struct RankedTensorTypeStorage : public TensorTypeStorage {
+  ArrayRef<int> getShape() const {
+    return ArrayRef<int>(shapeElements, getSubclassData());
+  }
+
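+  /// Pointer to the shape dimensions; the rank is kept in 'subclassData'.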
+  const int *shapeElements;
+};
+
+struct UnrankedTensorTypeStorage : public TensorTypeStorage {};
+
+struct MemRefTypeStorage : public TypeStorage {
+  ArrayRef<int> getShape() const {
+    return ArrayRef<int>(shapeElements, getSubclassData());
+  }
+
+  ArrayRef<AffineMap> getAffineMaps() const {
+    return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
+  }
+
+  /// The type of each scalar element of the memref.
+  Type elementType;
+  /// An array of integers which stores the shape dimension sizes.
+  const int *shapeElements;
+  /// The number of affine maps in the 'affineMapList' array.
+  const unsigned numAffineMaps;
+  /// List of affine maps in the memref's layout/index map composition.
+  AffineMap const *affineMapList;
+  /// Memory space in which data referenced by memref resides.
+  const unsigned memorySpace;
+};
+
+} // namespace detail
+} // namespace mlir
+#endif // TYPEDETAIL_H_
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 0ad3f47..1a71695 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -16,10 +16,17 @@
 // =============================================================================
 
 #include "mlir/IR/Types.h"
+#include "TypeDetail.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/Support/raw_ostream.h"
+
 using namespace mlir;
+using namespace mlir::detail;
+
+Type::Kind Type::getKind() const { return type->kind; }
+
+MLIRContext *Type::getContext() const { return type->context; }
 
 unsigned Type::getBitWidth() const {
   switch (getKind()) {
@@ -32,34 +39,49 @@
   case Type::Kind::F64:
     return 64;
   case Type::Kind::Integer:
-    return cast<IntegerType>(this)->getWidth();
+    return cast<IntegerType>().getWidth();
   case Type::Kind::Vector:
   case Type::Kind::RankedTensor:
   case Type::Kind::UnrankedTensor:
-    return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth();
+    return cast<VectorOrTensorType>().getElementType().getBitWidth();
     // TODO: Handle more types.
   default:
     llvm_unreachable("unexpected type");
   }
 }
 
-IntegerType::IntegerType(unsigned width, MLIRContext *context)
-    : Type(Kind::Integer, context), width(width) {
-  assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
+unsigned Type::getSubclassData() const { return type->getSubclassData(); }
+void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
+
+IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
+
+unsigned IntegerType::getWidth() const {
+  return static_cast<ImplType *>(type)->width;
 }
 
-FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {}
+FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
 
-OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {}
+OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
 
-FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
-                           unsigned numResults, MLIRContext *context)
-    : Type(Kind::Function, context, numInputs), numResults(numResults),
-      inputsAndResults(inputsAndResults) {}
+FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
 
-VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
-                                       Type *elementType, unsigned subClassData)
-    : Type(kind, context, subClassData), elementType(elementType) {}
+ArrayRef<Type> FunctionType::getInputs() const {
+  return static_cast<ImplType *>(type)->getInputs();
+}
+
+unsigned FunctionType::getNumResults() const {
+  return static_cast<ImplType *>(type)->numResults;
+}
+
+ArrayRef<Type> FunctionType::getResults() const {
+  return static_cast<ImplType *>(type)->getResults();
+}
+
+VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {}
+
+Type VectorOrTensorType::getElementType() const {
+  return static_cast<ImplType *>(type)->elementType;
+}
 
 unsigned VectorOrTensorType::getNumElements() const {
   switch (getKind()) {
@@ -103,11 +125,11 @@
 ArrayRef<int> VectorOrTensorType::getShape() const {
   switch (getKind()) {
   case Kind::Vector:
-    return cast<VectorType>(this)->getShape();
+    return cast<VectorType>().getShape();
   case Kind::RankedTensor:
-    return cast<RankedTensorType>(this)->getShape();
+    return cast<RankedTensorType>().getShape();
   case Kind::UnrankedTensor:
-    return cast<RankedTensorType>(this)->getShape();
+    return cast<RankedTensorType>().getShape();
   default:
     llvm_unreachable("not a VectorOrTensorType");
   }
@@ -118,35 +140,38 @@
   return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
 }
 
-VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
-                       MLIRContext *context)
-    : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
-      shapeElements(shape.data()) {}
+VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
 
-TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
-    : VectorOrTensorType(kind, context, elementType) {
-  assert(isValidTensorElementType(elementType));
+ArrayRef<int> VectorType::getShape() const {
+  return static_cast<ImplType *>(type)->getShape();
 }
 
-RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
-                                   MLIRContext *context)
-    : TensorType(Kind::RankedTensor, elementType, context),
-      shapeElements(shape.data()) {
-  setSubclassData(shape.size());
+TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
+
+RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
+
+ArrayRef<int> RankedTensorType::getShape() const {
+  return static_cast<ImplType *>(type)->getShape();
 }
 
-UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
-    : TensorType(Kind::UnrankedTensor, elementType, context) {}
+UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
 
-MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
-                       ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
-                       MLIRContext *context)
-    : Type(Kind::MemRef, context, shape.size()), elementType(elementType),
-      shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
-      affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
+MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {}
+
+ArrayRef<int> MemRefType::getShape() const {
+  return static_cast<ImplType *>(type)->getShape();
+}
+
+Type MemRefType::getElementType() const {
+  return static_cast<ImplType *>(type)->elementType;
+}
 
 ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
-  return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
+  return static_cast<ImplType *>(type)->getAffineMaps();
+}
+
+unsigned MemRefType::getMemorySpace() const {
+  return static_cast<ImplType *>(type)->memorySpace;
 }
 
 unsigned MemRefType::getNumDynamicDims() const {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 7974c7c..ceb8931 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -182,19 +182,19 @@
   // as the results of their action.
 
   // Type parsing.
-  VectorType *parseVectorType();
+  VectorType parseVectorType();
   ParseResult parseXInDimensionList();
   ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
-  Type *parseTensorType();
-  Type *parseMemRefType();
-  Type *parseFunctionType();
-  Type *parseType();
-  ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
-  ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
+  Type parseTensorType();
+  Type parseMemRefType();
+  Type parseFunctionType();
+  Type parseType();
+  ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
+  ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
 
   // Attribute parsing.
   Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
-                                     FunctionType *type);
+                                     FunctionType type);
   Attribute parseAttribute();
 
   ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@@ -206,9 +206,9 @@
   AffineMap parseAffineMapReference();
   IntegerSet parseIntegerSetInline();
   IntegerSet parseIntegerSetReference();
-  DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
-  DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
-  VectorOrTensorType *parseVectorOrTensorType();
+  DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type);
+  DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector);
+  VectorOrTensorType parseVectorOrTensorType();
 
 private:
   // The Parser is subclassed and reinstantiated.  Do not add additional
@@ -299,7 +299,7 @@
 ///   float-type ::= `f16` | `bf16` | `f32` | `f64`
 ///   other-type ::= `index` | `tf_control`
 ///
-Type *Parser::parseType() {
+Type Parser::parseType() {
   switch (getToken().getKind()) {
   default:
     return (emitError("expected type"), nullptr);
@@ -368,7 +368,7 @@
 ///   vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
 ///   const-dimension-list ::= (integer-literal `x`)+
 ///
-VectorType *Parser::parseVectorType() {
+VectorType Parser::parseVectorType() {
   consumeToken(Token::kw_vector);
 
   if (parseToken(Token::less, "expected '<' in vector type"))
@@ -402,11 +402,11 @@
 
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
-  auto *elementType = parseType();
+  auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
+  if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
     return (emitError(typeLoc, "invalid vector element type"), nullptr);
 
   return VectorType::get(dimensions, elementType);
@@ -461,7 +461,7 @@
 ///   tensor-type ::= `tensor` `<` dimension-list element-type `>`
 ///   dimension-list ::= dimension-list-ranked | `*x`
 ///
-Type *Parser::parseTensorType() {
+Type Parser::parseTensorType() {
   consumeToken(Token::kw_tensor);
 
   if (parseToken(Token::less, "expected '<' in tensor type"))
@@ -485,7 +485,7 @@
 
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
-  auto *elementType = parseType();
+  auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
     return nullptr;
 
@@ -505,7 +505,7 @@
 ///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
 ///   memory-space ::= integer-literal /* | TODO: address-space-id */
 ///
-Type *Parser::parseMemRefType() {
+Type Parser::parseMemRefType() {
   consumeToken(Token::kw_memref);
 
   if (parseToken(Token::less, "expected '<' in memref type"))
@@ -517,12 +517,12 @@
 
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
-  auto *elementType = parseType();
+  auto elementType = parseType();
   if (!elementType)
     return nullptr;
 
-  if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
-      !isa<VectorType>(elementType))
+  if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() &&
+      !elementType.isa<VectorType>())
     return (emitError(typeLoc, "invalid memref element type"), nullptr);
 
   // Parse semi-affine-map-composition.
@@ -581,10 +581,10 @@
 ///
 ///   function-type ::= type-list-parens `->` type-list
 ///
-Type *Parser::parseFunctionType() {
+Type Parser::parseFunctionType() {
   assert(getToken().is(Token::l_paren));
 
-  SmallVector<Type *, 4> arguments, results;
+  SmallVector<Type, 4> arguments, results;
   if (parseTypeList(arguments) ||
       parseToken(Token::arrow, "expected '->' in function type") ||
       parseTypeList(results))
@@ -598,7 +598,7 @@
 ///
 ///   type-list-no-parens ::=  type (`,` type)*
 ///
-ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
   auto parseElt = [&]() -> ParseResult {
     auto elt = parseType();
     elements.push_back(elt);
@@ -615,7 +615,7 @@
 ///   type-list-parens ::= `(` `)`
 ///                      | `(` type-list-no-parens `)`
 ///
-ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
+ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
   auto parseElt = [&]() -> ParseResult {
     auto elt = parseType();
     elements.push_back(elt);
@@ -639,8 +639,8 @@
 namespace {
 class TensorLiteralParser {
 public:
-  TensorLiteralParser(Parser &p, Type *eltTy)
-      : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {}
+  TensorLiteralParser(Parser &p, Type eltTy)
+      : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {}
 
   ParseResult parse() { return parseList(shape); }
 
@@ -676,7 +676,7 @@
   }
 
   Parser &p;
-  Type *eltTy;
+  Type eltTy;
   size_t currBitPos;
   size_t bitsWidth;
   SmallVector<int, 4> shape;
@@ -698,7 +698,7 @@
     if (!result)
       return p.emitError("expected tensor element");
     // check result matches the element type.
-    switch (eltTy->getKind()) {
+    switch (eltTy.getKind()) {
     case Type::Kind::BF16:
     case Type::Kind::F16:
     case Type::Kind::F32:
@@ -779,7 +779,7 @@
 /// synthesizing a forward reference) or emit an error and return null on
 /// failure.
 Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
-                                           FunctionType *type) {
+                                           FunctionType type) {
   Identifier name = builder.getIdentifier(nameStr.drop_front());
 
   // See if the function has already been defined in the module.
@@ -902,10 +902,10 @@
     if (parseToken(Token::colon, "expected ':' and function type"))
       return nullptr;
     auto typeLoc = getToken().getLoc();
-    Type *type = parseType();
+    Type type = parseType();
     if (!type)
       return nullptr;
-    auto *fnType = dyn_cast<FunctionType>(type);
+    auto fnType = type.dyn_cast<FunctionType>();
     if (!fnType)
       return (emitError(typeLoc, "expected function type"), nullptr);
 
@@ -916,7 +916,7 @@
     consumeToken(Token::kw_opaque);
     if (parseToken(Token::less, "expected '<' after 'opaque'"))
       return nullptr;
-    auto *type = parseVectorOrTensorType();
+    auto type = parseVectorOrTensorType();
     if (!type)
       return nullptr;
     auto val = getToken().getStringValue();
@@ -937,7 +937,7 @@
     if (parseToken(Token::less, "expected '<' after 'splat'"))
       return nullptr;
 
-    auto *type = parseVectorOrTensorType();
+    auto type = parseVectorOrTensorType();
     if (!type)
       return nullptr;
     switch (getToken().getKind()) {
@@ -959,7 +959,7 @@
     if (parseToken(Token::less, "expected '<' after 'dense'"))
       return nullptr;
 
-    auto *type = parseVectorOrTensorType();
+    auto type = parseVectorOrTensorType();
     if (!type)
       return nullptr;
 
@@ -981,41 +981,41 @@
     if (parseToken(Token::less, "Expected '<' after 'sparse'"))
       return nullptr;
 
-    auto *type = parseVectorOrTensorType();
+    auto type = parseVectorOrTensorType();
     if (!type)
       return nullptr;
 
     switch (getToken().getKind()) {
     case Token::l_square: {
       /// Parse indices
-      auto *indicesEltType = builder.getIntegerType(32);
+      auto indicesEltType = builder.getIntegerType(32);
       auto indices =
-          parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
+          parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
 
       if (parseToken(Token::comma, "expected ','"))
         return nullptr;
 
       /// Parse values.
-      auto *valuesEltType = type->getElementType();
+      auto valuesEltType = type.getElementType();
       auto values =
-          parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
+          parseDenseElementsAttr(valuesEltType, type.isa<VectorType>());
 
       /// Sanity check.
-      auto *indicesType = indices.getType();
-      auto *valuesType = values.getType();
-      auto sameShape = (indicesType->getRank() == 1) ||
-                       (type->getRank() == indicesType->getDimSize(1));
+      auto indicesType = indices.getType();
+      auto valuesType = values.getType();
+      auto sameShape = (indicesType.getRank() == 1) ||
+                       (type.getRank() == indicesType.getDimSize(1));
       auto sameElementNum =
-          indicesType->getDimSize(0) == valuesType->getDimSize(0);
+          indicesType.getDimSize(0) == valuesType.getDimSize(0);
       if (!sameShape || !sameElementNum) {
         std::string str;
         llvm::raw_string_ostream s(str);
         s << "expected shape ([";
-        interleaveComma(type->getShape(), s);
+        interleaveComma(type.getShape(), s);
         s << "]); inferred shape of indices literal ([";
-        interleaveComma(indicesType->getShape(), s);
+        interleaveComma(indicesType.getShape(), s);
         s << "]); inferred shape of values literal ([";
-        interleaveComma(valuesType->getShape(), s);
+        interleaveComma(valuesType.getShape(), s);
         s << "])";
         return (emitError(s.str()), nullptr);
       }
@@ -1035,7 +1035,7 @@
             nullptr);
   }
   default: {
-    if (Type *type = parseType())
+    if (Type type = parseType())
       return builder.getTypeAttr(type);
     return nullptr;
   }
@@ -1051,12 +1051,12 @@
 ///
 /// This method returns a constructed dense elements attribute with the shape
 /// from the parsing result.
-DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
+DenseElementsAttr Parser::parseDenseElementsAttr(Type eltType, bool isVector) {
   TensorLiteralParser literalParser(*this, eltType);
   if (literalParser.parse())
     return nullptr;
 
-  VectorOrTensorType *type;
+  VectorOrTensorType type;
   if (isVector) {
     type = builder.getVectorType(literalParser.getShape(), eltType);
   } else {
@@ -1076,18 +1076,18 @@
 /// This method compares the shapes from the parsing result and that from the
 /// input argument. It returns a constructed dense elements attribute if both
 /// match.
-DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
-  auto *eltTy = type->getElementType();
+DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
+  auto eltTy = type.getElementType();
   TensorLiteralParser literalParser(*this, eltTy);
   if (literalParser.parse())
     return nullptr;
-  if (literalParser.getShape() != type->getShape()) {
+  if (literalParser.getShape() != type.getShape()) {
     std::string str;
     llvm::raw_string_ostream s(str);
     s << "inferred shape of elements literal ([";
     interleaveComma(literalParser.getShape(), s);
     s << "]) does not match type ([";
-    interleaveComma(type->getShape(), s);
+    interleaveComma(type.getShape(), s);
     s << "])";
     return (emitError(s.str()), nullptr);
   }
@@ -1100,8 +1100,8 @@
 ///   vector-or-tensor-type ::= vector-type | tensor-type
 ///
 /// This method also checks the type has static shape and ranked.
-VectorOrTensorType *Parser::parseVectorOrTensorType() {
-  auto *type = dyn_cast<VectorOrTensorType>(parseType());
+VectorOrTensorType Parser::parseVectorOrTensorType() {
+  auto type = parseType().dyn_cast<VectorOrTensorType>();
   if (!type) {
     return (emitError("expected elements literal has a tensor or vector type"),
             nullptr);
@@ -1110,7 +1110,7 @@
   if (parseToken(Token::comma, "expected ','"))
     return nullptr;
 
-  if (!type->hasStaticShape() || type->getRank() == -1) {
+  if (!type.hasStaticShape() || type.getRank() == -1) {
     return (emitError("tensor literals must be ranked and have static shape"),
             nullptr);
   }
@@ -1834,7 +1834,7 @@
 
   /// Given a reference to an SSA value and its type, return a reference. This
   /// returns null on failure.
-  SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type);
+  SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
 
   /// Register a definition of a value with the symbol table.
   ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
@@ -1845,11 +1845,11 @@
 
   template <typename ResultType>
   ResultType parseSSADefOrUseAndType(
-      const std::function<ResultType(SSAUseInfo, Type *)> &action);
+      const std::function<ResultType(SSAUseInfo, Type)> &action);
 
   SSAValue *parseSSAUseAndType() {
     return parseSSADefOrUseAndType<SSAValue *>(
-        [&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
+        [&](SSAUseInfo useInfo, Type type) -> SSAValue * {
           return resolveSSAUse(useInfo, type);
         });
   }
@@ -1880,7 +1880,7 @@
   /// their first reference, to allow checking for use of undefined values.
   DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
 
-  SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type);
+  SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
 
   /// Return true if this is a forward reference.
   bool isForwardReferencePlaceholder(SSAValue *value) {
@@ -1891,7 +1891,7 @@
 
 /// Create and remember a new placeholder for a forward reference.
 SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
-                                                            Type *type) {
+                                                            Type type) {
   // Forward references are always created as instructions, even in ML
   // functions, because we just need something with a def/use chain.
   //
@@ -1908,7 +1908,7 @@
 
 /// Given an unbound reference to an SSA value and its type, return the value
 /// it specifies.  This returns null on failure.
-SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
+SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
   auto &entries = values[useInfo.name];
 
   // If we have already seen a value of this name, return it.
@@ -2057,14 +2057,14 @@
 ///   ssa-use-and-type ::= ssa-use `:` type
 template <typename ResultType>
 ResultType FunctionParser::parseSSADefOrUseAndType(
-    const std::function<ResultType(SSAUseInfo, Type *)> &action) {
+    const std::function<ResultType(SSAUseInfo, Type)> &action) {
 
   SSAUseInfo useInfo;
   if (parseSSAUse(useInfo) ||
       parseToken(Token::colon, "expected ':' and type for SSA operand"))
     return nullptr;
 
-  auto *type = parseType();
+  auto type = parseType();
   if (!type)
     return nullptr;
 
@@ -2101,7 +2101,7 @@
   if (valueIDs.empty())
     return ParseSuccess;
 
-  SmallVector<Type *, 4> types;
+  SmallVector<Type, 4> types;
   if (parseToken(Token::colon, "expected ':' in operand list") ||
       parseTypeListNoParens(types))
     return ParseFailure;
@@ -2209,14 +2209,14 @@
   auto type = parseType();
   if (!type)
     return nullptr;
-  auto fnType = dyn_cast<FunctionType>(type);
+  auto fnType = type.dyn_cast<FunctionType>();
   if (!fnType)
     return (emitError(typeLoc, "expected function type"), nullptr);
 
-  result.addTypes(fnType->getResults());
+  result.addTypes(fnType.getResults());
 
   // Check that we have the right number of types for the operands.
-  auto operandTypes = fnType->getInputs();
+  auto operandTypes = fnType.getInputs();
   if (operandTypes.size() != operandInfos.size()) {
     auto plural = "s"[operandInfos.size() == 1];
     return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
@@ -2253,17 +2253,17 @@
     return parser.parseToken(Token::comma, "expected ','");
   }
 
-  bool parseColonType(Type *&result) override {
+  bool parseColonType(Type &result) override {
     return parser.parseToken(Token::colon, "expected ':'") ||
            !(result = parser.parseType());
   }
 
-  bool parseColonTypeList(SmallVectorImpl<Type *> &result) override {
+  bool parseColonTypeList(SmallVectorImpl<Type> &result) override {
     if (parser.parseToken(Token::colon, "expected ':'"))
       return true;
 
     do {
-      if (auto *type = parser.parseType())
+      if (auto type = parser.parseType())
         result.push_back(type);
       else
         return true;
@@ -2273,7 +2273,7 @@
   }
 
   /// Parse a keyword followed by a type.
-  bool parseKeywordType(const char *keyword, Type *&result) override {
+  bool parseKeywordType(const char *keyword, Type &result) override {
     if (parser.getTokenSpelling() != keyword)
       return parser.emitError("expected '" + Twine(keyword) + "'");
     parser.consumeToken();
@@ -2396,7 +2396,7 @@
   }
 
   /// Resolve a parse function name and a type into a function reference.
-  virtual bool resolveFunctionName(StringRef name, FunctionType *type,
+  virtual bool resolveFunctionName(StringRef name, FunctionType type,
                                    llvm::SMLoc loc, Function *&result) {
     result = parser.resolveFunctionReference(name, loc, type);
     return result == nullptr;
@@ -2410,7 +2410,7 @@
 
   llvm::SMLoc getNameLoc() const override { return nameLoc; }
 
-  bool resolveOperand(const OperandType &operand, Type *type,
+  bool resolveOperand(const OperandType &operand, Type type,
                       SmallVectorImpl<SSAValue *> &result) override {
     FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
                                               operand.location};
@@ -2559,11 +2559,11 @@
     return ParseSuccess;
 
   return parseCommaSeparatedList([&]() -> ParseResult {
-    auto type = parseSSADefOrUseAndType<Type *>(
-        [&](SSAUseInfo useInfo, Type *type) -> Type * {
+    auto type = parseSSADefOrUseAndType<Type>(
+        [&](SSAUseInfo useInfo, Type type) -> Type {
           BBArgument *arg = owner->addArgument(type);
           if (addDefinition(useInfo, arg))
-            return nullptr;
+            return {};
           return type;
         });
     return type ? ParseSuccess : ParseFailure;
@@ -2908,7 +2908,7 @@
                      " symbol count must match");
 
   // Resolve SSA uses.
-  Type *indexType = builder.getIndexType();
+  Type indexType = builder.getIndexType();
   for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
     SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
     if (!sval)
@@ -3187,9 +3187,9 @@
   ParseResult parseAffineStructureDef();
 
   // Functions.
-  ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+  ParseResult parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
                                   SmallVectorImpl<StringRef> &argNames);
-  ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
+  ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
                                      SmallVectorImpl<StringRef> *argNames);
   ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs);
   ParseResult parseExtFunc();
@@ -3248,7 +3248,7 @@
 /// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
 ///
 ParseResult
-ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+ModuleParser::parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
                                   SmallVectorImpl<StringRef> &argNames) {
   consumeToken(Token::l_paren);
 
@@ -3284,7 +3284,7 @@
 ///   type-list)?
 ///
 ParseResult
-ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
+ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
                                      SmallVectorImpl<StringRef> *argNames) {
   if (getToken().isNot(Token::at_identifier))
     return emitError("expected a function identifier like '@foo'");
@@ -3295,7 +3295,7 @@
   if (getToken().isNot(Token::l_paren))
     return emitError("expected '(' in function signature");
 
-  SmallVector<Type *, 4> argTypes;
+  SmallVector<Type, 4> argTypes;
   ParseResult parseResult;
 
   if (argNames)
@@ -3307,7 +3307,7 @@
     return ParseFailure;
 
   // Parse the return type if present.
-  SmallVector<Type *, 4> results;
+  SmallVector<Type, 4> results;
   if (consumeIf(Token::arrow)) {
     if (parseTypeList(results))
       return ParseFailure;
@@ -3340,7 +3340,7 @@
   auto loc = getToken().getLoc();
 
   StringRef name;
-  FunctionType *type = nullptr;
+  FunctionType type;
   if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
     return ParseFailure;
 
@@ -3372,7 +3372,7 @@
   auto loc = getToken().getLoc();
 
   StringRef name;
-  FunctionType *type = nullptr;
+  FunctionType type;
   if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
     return ParseFailure;
 
@@ -3405,7 +3405,7 @@
   consumeToken(Token::kw_mlfunc);
 
   StringRef name;
-  FunctionType *type = nullptr;
+  FunctionType type;
   SmallVector<StringRef, 4> argNames;
 
   auto loc = getToken().getLoc();
diff --git a/lib/StandardOps/StandardOps.cpp b/lib/StandardOps/StandardOps.cpp
index b60d209..e2bdfd7 100644
--- a/lib/StandardOps/StandardOps.cpp
+++ b/lib/StandardOps/StandardOps.cpp
@@ -138,23 +138,23 @@
 //===----------------------------------------------------------------------===//
 
 void AllocOp::build(Builder *builder, OperationState *result,
-                    MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
+                    MemRefType memrefType, ArrayRef<SSAValue *> operands) {
   result->addOperands(operands);
   result->types.push_back(memrefType);
 }
 
 void AllocOp::print(OpAsmPrinter *p) const {
-  MemRefType *type = getType();
+  MemRefType type = getType();
   *p << "alloc";
   // Print dynamic dimension operands.
   printDimAndSymbolList(operand_begin(), operand_end(),
-                        type->getNumDynamicDims(), p);
+                        type.getNumDynamicDims(), p);
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
-  *p << " : " << *type;
+  *p << " : " << type;
 }
 
 bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
-  MemRefType *type;
+  MemRefType type;
 
   // Parse the dimension operands and optional symbol operands, followed by a
   // memref type.
@@ -170,7 +170,7 @@
   // Verification still checks that the total number of operands matches
   // the number of symbols in the affine map, plus the number of dynamic
   // dimensions in the memref.
-  if (numDimOperands != type->getNumDynamicDims()) {
+  if (numDimOperands != type.getNumDynamicDims()) {
     return parser->emitError(parser->getNameLoc(),
                              "dimension operand count does not equal memref "
                              "dynamic dimension count");
@@ -180,13 +180,13 @@
 }
 
 bool AllocOp::verify() const {
-  auto *memRefType = dyn_cast<MemRefType>(getResult()->getType());
+  auto memRefType = getResult()->getType().dyn_cast<MemRefType>();
   if (!memRefType)
     return emitOpError("result must be a memref");
 
   unsigned numSymbols = 0;
-  if (!memRefType->getAffineMaps().empty()) {
-    AffineMap affineMap = memRefType->getAffineMaps()[0];
+  if (!memRefType.getAffineMaps().empty()) {
+    AffineMap affineMap = memRefType.getAffineMaps()[0];
     // Store number of symbols used in affine map (used in subsequent check).
     numSymbols = affineMap.getNumSymbols();
     // TODO(zinenko): this check does not belong to AllocOp, or any other op but
@@ -195,10 +195,10 @@
     // Remove when we can emit errors directly from *Type::get(...) functions.
     //
     // Verify that the layout affine map matches the rank of the memref.
-    if (affineMap.getNumDims() != memRefType->getRank())
+    if (affineMap.getNumDims() != memRefType.getRank())
       return emitOpError("affine map dimension count must equal memref rank");
   }
-  unsigned numDynamicDims = memRefType->getNumDynamicDims();
+  unsigned numDynamicDims = memRefType.getNumDynamicDims();
   // Check that the total number of operands matches the number of symbols in
   // the affine map, plus the number of dynamic dimensions specified in the
   // memref type.
@@ -208,7 +208,7 @@
   }
   // Verify that all operands are of type Index.
   for (auto *operand : getOperands()) {
-    if (!operand->getType()->isIndex())
+    if (!operand->getType().isIndex())
       return emitOpError("requires operands to be of type Index");
   }
   return false;
@@ -239,13 +239,13 @@
     // Ok, we have one or more constant operands.  Collect the non-constant ones
     // and keep track of the resultant memref type to build.
     SmallVector<int, 4> newShapeConstants;
-    newShapeConstants.reserve(memrefType->getRank());
+    newShapeConstants.reserve(memrefType.getRank());
     SmallVector<SSAValue *, 4> newOperands;
     SmallVector<SSAValue *, 4> droppedOperands;
 
     unsigned dynamicDimPos = 0;
-    for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
-      int dimSize = memrefType->getDimSize(dim);
+    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
+      int dimSize = memrefType.getDimSize(dim);
       // If this is already static dimension, keep it.
       if (dimSize != -1) {
         newShapeConstants.push_back(dimSize);
@@ -267,10 +267,10 @@
     }
 
     // Create new memref type (which will have fewer dynamic dimensions).
-    auto *newMemRefType = MemRefType::get(
-        newShapeConstants, memrefType->getElementType(),
-        memrefType->getAffineMaps(), memrefType->getMemorySpace());
-    assert(newOperands.size() == newMemRefType->getNumDynamicDims());
+    auto newMemRefType = MemRefType::get(
+        newShapeConstants, memrefType.getElementType(),
+        memrefType.getAffineMaps(), memrefType.getMemorySpace());
+    assert(newOperands.size() == newMemRefType.getNumDynamicDims());
 
     // Create and insert the alloc op for the new memref.
     auto newAlloc =
@@ -297,13 +297,13 @@
                    ArrayRef<SSAValue *> operands) {
   result->addOperands(operands);
   result->addAttribute("callee", builder->getFunctionAttr(callee));
-  result->addTypes(callee->getType()->getResults());
+  result->addTypes(callee->getType().getResults());
 }
 
 bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
   StringRef calleeName;
   llvm::SMLoc calleeLoc;
-  FunctionType *calleeType = nullptr;
+  FunctionType calleeType;
   SmallVector<OpAsmParser::OperandType, 4> operands;
   Function *callee = nullptr;
   if (parser->parseFunctionName(calleeName, calleeLoc) ||
@@ -312,8 +312,8 @@
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(calleeType) ||
       parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
-      parser->addTypesToList(calleeType->getResults(), result->types) ||
-      parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
+      parser->addTypesToList(calleeType.getResults(), result->types) ||
+      parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
                               result->operands))
     return true;
 
@@ -328,7 +328,7 @@
   p->printOperands(getOperands());
   *p << ')';
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
-  *p << " : " << *getCallee()->getType();
+  *p << " : " << getCallee()->getType();
 }
 
 bool CallOp::verify() const {
@@ -338,20 +338,20 @@
     return emitOpError("requires a 'callee' function attribute");
 
   // Verify that the operand and result types match the callee.
-  auto *fnType = fnAttr.getValue()->getType();
-  if (fnType->getNumInputs() != getNumOperands())
+  auto fnType = fnAttr.getValue()->getType();
+  if (fnType.getNumInputs() != getNumOperands())
     return emitOpError("incorrect number of operands for callee");
 
-  for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
-    if (getOperand(i)->getType() != fnType->getInput(i))
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+    if (getOperand(i)->getType() != fnType.getInput(i))
       return emitOpError("operand type mismatch");
   }
 
-  if (fnType->getNumResults() != getNumResults())
+  if (fnType.getNumResults() != getNumResults())
     return emitOpError("incorrect number of results for callee");
 
-  for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
-    if (getResult(i)->getType() != fnType->getResult(i))
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+    if (getResult(i)->getType() != fnType.getResult(i))
       return emitOpError("result type mismatch");
   }
 
@@ -364,14 +364,14 @@
 
 void CallIndirectOp::build(Builder *builder, OperationState *result,
                            SSAValue *callee, ArrayRef<SSAValue *> operands) {
-  auto *fnType = cast<FunctionType>(callee->getType());
+  auto fnType = callee->getType().cast<FunctionType>();
   result->operands.push_back(callee);
   result->addOperands(operands);
-  result->addTypes(fnType->getResults());
+  result->addTypes(fnType.getResults());
 }
 
 bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
-  FunctionType *calleeType = nullptr;
+  FunctionType calleeType;
   OpAsmParser::OperandType callee;
   llvm::SMLoc operandsLoc;
   SmallVector<OpAsmParser::OperandType, 4> operands;
@@ -382,9 +382,9 @@
          parser->parseOptionalAttributeDict(result->attributes) ||
          parser->parseColonType(calleeType) ||
          parser->resolveOperand(callee, calleeType, result->operands) ||
-         parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
+         parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
                                  result->operands) ||
-         parser->addTypesToList(calleeType->getResults(), result->types);
+         parser->addTypesToList(calleeType.getResults(), result->types);
 }
 
 void CallIndirectOp::print(OpAsmPrinter *p) const {
@@ -395,29 +395,29 @@
   p->printOperands(++operandRange.begin(), operandRange.end());
   *p << ')';
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
-  *p << " : " << *getCallee()->getType();
+  *p << " : " << getCallee()->getType();
 }
 
 bool CallIndirectOp::verify() const {
   // The callee must be a function.
-  auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
+  auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
   if (!fnType)
     return emitOpError("callee must have function type");
 
   // Verify that the operand and result types match the callee.
-  if (fnType->getNumInputs() != getNumOperands() - 1)
+  if (fnType.getNumInputs() != getNumOperands() - 1)
     return emitOpError("incorrect number of operands for callee");
 
-  for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
-    if (getOperand(i + 1)->getType() != fnType->getInput(i))
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+    if (getOperand(i + 1)->getType() != fnType.getInput(i))
       return emitOpError("operand type mismatch");
   }
 
-  if (fnType->getNumResults() != getNumResults())
+  if (fnType.getNumResults() != getNumResults())
     return emitOpError("incorrect number of results for callee");
 
-  for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
-    if (getResult(i)->getType() != fnType->getResult(i))
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+    if (getResult(i)->getType() != fnType.getResult(i))
       return emitOpError("result type mismatch");
   }
 
@@ -434,19 +434,19 @@
 }
 
 void DeallocOp::print(OpAsmPrinter *p) const {
-  *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
+  *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
 }
 
 bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType memrefInfo;
-  MemRefType *type;
+  MemRefType type;
 
   return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
          parser->resolveOperand(memrefInfo, type, result->operands);
 }
 
 bool DeallocOp::verify() const {
-  if (!isa<MemRefType>(getMemRef()->getType()))
+  if (!getMemRef()->getType().isa<MemRefType>())
     return emitOpError("operand must be a memref");
   return false;
 }
@@ -472,13 +472,13 @@
 void DimOp::print(OpAsmPrinter *p) const {
   *p << "dim " << *getOperand() << ", " << getIndex();
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
-  *p << " : " << *getOperand()->getType();
+  *p << " : " << getOperand()->getType();
 }
 
 bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType operandInfo;
   IntegerAttr indexAttr;
-  Type *type;
+  Type type;
 
   return parser->parseOperand(operandInfo) || parser->parseComma() ||
          parser->parseAttribute(indexAttr, "index", result->attributes) ||
@@ -496,15 +496,15 @@
     return emitOpError("requires an integer attribute named 'index'");
   uint64_t index = (uint64_t)indexAttr.getValue();
 
-  auto *type = getOperand()->getType();
-  if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
-    if (index >= tensorType->getRank())
+  auto type = getOperand()->getType();
+  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+    if (index >= tensorType.getRank())
       return emitOpError("index is out of range");
-  } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
-    if (index >= memrefType->getRank())
+  } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
+    if (index >= memrefType.getRank())
       return emitOpError("index is out of range");
 
-  } else if (isa<UnrankedTensorType>(type)) {
+  } else if (type.isa<UnrankedTensorType>()) {
     // ok, assumed to be in-range.
   } else {
     return emitOpError("requires an operand with tensor or memref type");
@@ -516,12 +516,12 @@
 Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
                               MLIRContext *context) const {
   // Constant fold dim when the size along the index referred to is a constant.
-  auto *opType = getOperand()->getType();
+  auto opType = getOperand()->getType();
   int indexSize = -1;
-  if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) {
-    indexSize = tensorType->getShape()[getIndex()];
-  } else if (auto *memrefType = dyn_cast<MemRefType>(opType)) {
-    indexSize = memrefType->getShape()[getIndex()];
+  if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
+    indexSize = tensorType.getShape()[getIndex()];
+  } else if (auto memrefType = opType.dyn_cast<MemRefType>()) {
+    indexSize = memrefType.getShape()[getIndex()];
   }
 
   if (indexSize >= 0)
@@ -544,9 +544,9 @@
   p->printOperands(getTagIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getSrcMemRef()->getType();
-  *p << ", " << *getDstMemRef()->getType();
-  *p << ", " << *getTagMemRef()->getType();
+  *p << " : " << getSrcMemRef()->getType();
+  *p << ", " << getDstMemRef()->getType();
+  *p << ", " << getTagMemRef()->getType();
 }
 
 // Parse DmaStartOp.
@@ -566,8 +566,8 @@
   OpAsmParser::OperandType tagMemrefInfo;
   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
 
-  SmallVector<Type *, 3> types;
-  auto *indexType = parser->getBuilder().getIndexType();
+  SmallVector<Type, 3> types;
+  auto indexType = parser->getBuilder().getIndexType();
 
   // Parse and resolve the following list of operands:
   // *) source memref followed by its indices (in square brackets).
@@ -601,12 +601,12 @@
     return true;
 
   // Check that source/destination index list size matches associated rank.
-  if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
-      dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
+  if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
+      dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
     return parser->emitError(parser->getNameLoc(),
                              "memref rank not equal to indices count");
 
-  if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
+  if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
     return parser->emitError(parser->getNameLoc(),
                              "tag memref rank not equal to indices count");
 
@@ -632,7 +632,7 @@
   p->printOperands(getTagIndices());
   *p << "], ";
   p->printOperand(getNumElements());
-  *p << " : " << *getTagMemRef()->getType();
+  *p << " : " << getTagMemRef()->getType();
 }
 
 // Parse DmaWaitOp.
@@ -642,8 +642,8 @@
 bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType tagMemrefInfo;
   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
-  Type *type;
-  auto *indexType = parser->getBuilder().getIndexType();
+  Type type;
+  auto indexType = parser->getBuilder().getIndexType();
   OpAsmParser::OperandType numElementsInfo;
 
   // Parse tag memref, its indices, and dma size.
@@ -657,7 +657,7 @@
       parser->resolveOperand(numElementsInfo, indexType, result->operands))
     return true;
 
-  if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
+  if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
     return parser->emitError(parser->getNameLoc(),
                              "tag memref rank not equal to indices count");
 
@@ -678,10 +678,10 @@
 void ExtractElementOp::build(Builder *builder, OperationState *result,
                              SSAValue *aggregate,
                              ArrayRef<SSAValue *> indices) {
-  auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
+  auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
   result->addOperands(aggregate);
   result->addOperands(indices);
-  result->types.push_back(aggregateType->getElementType());
+  result->types.push_back(aggregateType.getElementType());
 }
 
 void ExtractElementOp::print(OpAsmPrinter *p) const {
@@ -689,13 +689,13 @@
   p->printOperands(getIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getAggregate()->getType();
+  *p << " : " << getAggregate()->getType();
 }
 
 bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType aggregateInfo;
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  VectorOrTensorType *type;
+  VectorOrTensorType type;
 
   auto affineIntTy = parser->getBuilder().getIndexType();
   return parser->parseOperand(aggregateInfo) ||
@@ -705,26 +705,26 @@
          parser->parseColonType(type) ||
          parser->resolveOperand(aggregateInfo, type, result->operands) ||
          parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
-         parser->addTypeToList(type->getElementType(), result->types);
+         parser->addTypeToList(type.getElementType(), result->types);
 }
 
 bool ExtractElementOp::verify() const {
   if (getNumOperands() == 0)
     return emitOpError("expected an aggregate to index into");
 
-  auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
+  auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
   if (!aggregateType)
     return emitOpError("first operand must be a vector or tensor");
 
-  if (getType() != aggregateType->getElementType())
+  if (getType() != aggregateType.getElementType())
     return emitOpError("result type must match element type of aggregate");
 
   for (auto *idx : getIndices())
-    if (!idx->getType()->isIndex())
+    if (!idx->getType().isIndex())
       return emitOpError("index to extract_element must have 'index' type");
 
   // Verify the # indices match if we have a ranked type.
-  auto aggregateRank = aggregateType->getRank();
+  auto aggregateRank = aggregateType.getRank();
   if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
     return emitOpError("incorrect number of indices for extract_element");
 
@@ -737,10 +737,10 @@
 
 void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
                    ArrayRef<SSAValue *> indices) {
-  auto *memrefType = cast<MemRefType>(memref->getType());
+  auto memrefType = memref->getType().cast<MemRefType>();
   result->addOperands(memref);
   result->addOperands(indices);
-  result->types.push_back(memrefType->getElementType());
+  result->types.push_back(memrefType.getElementType());
 }
 
 void LoadOp::print(OpAsmPrinter *p) const {
@@ -748,13 +748,13 @@
   p->printOperands(getIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getMemRefType();
+  *p << " : " << getMemRefType();
 }
 
 bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType memrefInfo;
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  MemRefType *type;
+  MemRefType type;
 
   auto affineIntTy = parser->getBuilder().getIndexType();
   return parser->parseOperand(memrefInfo) ||
@@ -764,25 +764,25 @@
          parser->parseColonType(type) ||
          parser->resolveOperand(memrefInfo, type, result->operands) ||
          parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
-         parser->addTypeToList(type->getElementType(), result->types);
+         parser->addTypeToList(type.getElementType(), result->types);
 }
 
 bool LoadOp::verify() const {
   if (getNumOperands() == 0)
     return emitOpError("expected a memref to load from");
 
-  auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
+  auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
   if (!memRefType)
     return emitOpError("first operand must be a memref");
 
-  if (getType() != memRefType->getElementType())
+  if (getType() != memRefType.getElementType())
     return emitOpError("result type must match element type of memref");
 
-  if (memRefType->getRank() != getNumOperands() - 1)
+  if (memRefType.getRank() != getNumOperands() - 1)
     return emitOpError("incorrect number of indices for load");
 
   for (auto *idx : getIndices())
-    if (!idx->getType()->isIndex())
+    if (!idx->getType().isIndex())
       return emitOpError("index to load must have 'index' type");
 
   // TODO: Verify we have the right number of indices.
@@ -804,31 +804,31 @@
 //===----------------------------------------------------------------------===//
 
 bool MemRefCastOp::verify() const {
-  auto *opType = dyn_cast<MemRefType>(getOperand()->getType());
-  auto *resType = dyn_cast<MemRefType>(getType());
+  auto opType = getOperand()->getType().dyn_cast<MemRefType>();
+  auto resType = getType().dyn_cast<MemRefType>();
   if (!opType || !resType)
     return emitOpError("requires input and result types to be memrefs");
 
   if (opType == resType)
     return emitOpError("requires the input and result type to be different");
 
-  if (opType->getElementType() != resType->getElementType())
+  if (opType.getElementType() != resType.getElementType())
     return emitOpError(
         "requires input and result element types to be the same");
 
-  if (opType->getAffineMaps() != resType->getAffineMaps())
+  if (opType.getAffineMaps() != resType.getAffineMaps())
     return emitOpError("requires input and result mappings to be the same");
 
-  if (opType->getMemorySpace() != resType->getMemorySpace())
+  if (opType.getMemorySpace() != resType.getMemorySpace())
     return emitOpError(
         "requires input and result memory spaces to be the same");
 
   // They must have the same rank, and any specified dimensions must match.
-  if (opType->getRank() != resType->getRank())
+  if (opType.getRank() != resType.getRank())
     return emitOpError("requires input and result ranks to match");
 
-  for (unsigned i = 0, e = opType->getRank(); i != e; ++i) {
-    int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i);
+  for (unsigned i = 0, e = opType.getRank(); i != e; ++i) {
+    int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i);
     if (opDim != -1 && resultDim != -1 && opDim != resultDim)
       return emitOpError("requires static dimensions to match");
   }
@@ -923,14 +923,14 @@
   p->printOperands(getIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getMemRefType();
+  *p << " : " << getMemRefType();
 }
 
 bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType storeValueInfo;
   OpAsmParser::OperandType memrefInfo;
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  MemRefType *memrefType;
+  MemRefType memrefType;
 
   auto affineIntTy = parser->getBuilder().getIndexType();
   return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
@@ -939,7 +939,7 @@
                                   OpAsmParser::Delimiter::Square) ||
          parser->parseOptionalAttributeDict(result->attributes) ||
          parser->parseColonType(memrefType) ||
-         parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
+         parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
                                 result->operands) ||
          parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
          parser->resolveOperands(indexInfo, affineIntTy, result->operands);
@@ -950,19 +950,19 @@
     return emitOpError("expected a value to store and a memref");
 
   // Second operand is a memref type.
-  auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
+  auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
   if (!memRefType)
     return emitOpError("second operand must be a memref");
 
   // First operand must have same type as memref element type.
-  if (getValueToStore()->getType() != memRefType->getElementType())
+  if (getValueToStore()->getType() != memRefType.getElementType())
     return emitOpError("first operand must have same type memref element type");
 
-  if (getNumOperands() != 2 + memRefType->getRank())
+  if (getNumOperands() != 2 + memRefType.getRank())
     return emitOpError("store index operand count not equal to memref rank");
 
   for (auto *idx : getIndices())
-    if (!idx->getType()->isIndex())
+    if (!idx->getType().isIndex())
       return emitOpError("index to load must have 'index' type");
 
   // TODO: Verify we have the right number of indices.
@@ -1046,31 +1046,31 @@
 //===----------------------------------------------------------------------===//
 
 bool TensorCastOp::verify() const {
-  auto *opType = dyn_cast<TensorType>(getOperand()->getType());
-  auto *resType = dyn_cast<TensorType>(getType());
+  auto opType = getOperand()->getType().dyn_cast<TensorType>();
+  auto resType = getType().dyn_cast<TensorType>();
   if (!opType || !resType)
     return emitOpError("requires input and result types to be tensors");
 
   if (opType == resType)
     return emitOpError("requires the input and result type to be different");
 
-  if (opType->getElementType() != resType->getElementType())
+  if (opType.getElementType() != resType.getElementType())
     return emitOpError(
         "requires input and result element types to be the same");
 
   // If the source or destination are unranked, then the cast is valid.
-  auto *opRType = dyn_cast<RankedTensorType>(opType);
-  auto *resRType = dyn_cast<RankedTensorType>(resType);
+  auto opRType = opType.dyn_cast<RankedTensorType>();
+  auto resRType = resType.dyn_cast<RankedTensorType>();
   if (!opRType || !resRType)
     return false;
 
   // If they are both ranked, they have to have the same rank, and any specified
   // dimensions must match.
-  if (opRType->getRank() != resRType->getRank())
+  if (opRType.getRank() != resRType.getRank())
     return emitOpError("requires input and result ranks to match");
 
-  for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
-    int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
+  for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) {
+    int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i);
     if (opDim != -1 && resultDim != -1 && opDim != resultDim)
       return emitOpError("requires static dimensions to match");
   }
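
For readers skimming the new cast API: below is a minimal sketch, not part of
this patch, of the value-semantic convention the verifiers above now follow.
Casts are member functions on the Type handle, and a failed dyn_cast produces a
null handle that tests false. The include path, helper name, and the use of a
default-constructed Type as the null handle are assumptions for illustration.

    // Illustrative only -- not code from this change.
    #include "mlir/IR/Types.h" // assumed location of Type/MemRefType

    using namespace mlir;

    // Returns the element type a store into `type` would need, or a null
    // Type if `type` is not a memref at all.
    static Type getStoredElementType(Type type) {
      auto memRefType = type.dyn_cast<MemRefType>(); // null handle on failure
      if (!memRefType)   // a null handle converts to false
        return Type();   // assumed: default-constructed Type is the null handle
      return memRefType.getElementType();
    }
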
diff --git a/lib/Transforms/ConstantFold.cpp b/lib/Transforms/ConstantFold.cpp
index 81994dd..15dd89b 100644
--- a/lib/Transforms/ConstantFold.cpp
+++ b/lib/Transforms/ConstantFold.cpp
@@ -31,7 +31,7 @@
   SmallVector<SSAValue *, 8> existingConstants;
   // Operation statements that were folded and that need to be erased.
   std::vector<OperationStmt *> opStmtsToErase;
-  using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
+  using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>;
 
   bool foldOperation(Operation *op,
                      SmallVectorImpl<SSAValue *> &existingConstants,
@@ -106,7 +106,7 @@
     for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
       auto &inst = *instIt++;
 
-      auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
+      auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
         builder.setInsertionPoint(&inst);
         return builder.create<ConstantOp>(inst.getLoc(), value, type);
       };
@@ -134,7 +134,7 @@
 
 // Override the walker's operation statement visit for constant folding.
 void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
-  auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
+  auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
     MLFuncBuilder builder(stmt);
     return builder.create<ConstantOp>(stmt->getLoc(), value, type);
   };
diff --git a/lib/Transforms/PipelineDataTransfer.cpp b/lib/Transforms/PipelineDataTransfer.cpp
index d96d65b..9042181 100644
--- a/lib/Transforms/PipelineDataTransfer.cpp
+++ b/lib/Transforms/PipelineDataTransfer.cpp
@@ -77,23 +77,23 @@
   bInner.setInsertionPoint(forStmt, forStmt->begin());
 
   // Doubles the shape with a leading dimension extent of 2.
-  auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
+  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
     // Add the leading dimension in the shape for the double buffer.
-    ArrayRef<int> shape = oldMemRefType->getShape();
+    ArrayRef<int> shape = oldMemRefType.getShape();
     SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
     shapeSizes.insert(shapeSizes.begin(), 2);
 
-    auto *newMemRefType =
-        bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
-                             oldMemRefType->getMemorySpace());
+    auto newMemRefType =
+        bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {},
+                             oldMemRefType.getMemorySpace());
     return newMemRefType;
   };
 
-  auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
+  auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>());
 
   // Create and place the alloc at the top level.
   MLFuncBuilder topBuilder(forStmt->getFunction());
   auto *newMemRef = cast<MLValue>(
       topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
           ->getResult());
 
diff --git a/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cdf5b71..4ec8942 100644
--- a/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -78,7 +78,7 @@
   /// As part of canonicalization, we move constants to the top of the entry
   /// block of the current function and de-duplicate them.  This keeps track of
   /// constants we have done this for.
-  DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
+  DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
 };
 }; // end anonymous namespace
 
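
The uniquedConstants change above also shows why the by-value form matters for
containers: a Type is just a pointer-sized handle to uniqued storage, so it can
be copied and compared like a pointer and used directly as a map key. A minimal
sketch follows, not part of this patch; it assumes the DenseMapInfo support for
Type that the new key type relies on, and the helper name and header paths are
guesses.

    // Illustrative only -- not code from this change.
    #include <utility>
    #include "llvm/ADT/DenseMap.h"
    #include "mlir/IR/Attributes.h"
    #include "mlir/IR/Operation.h"
    #include "mlir/IR/Types.h"

    using namespace mlir;

    // Looks up a previously created constant for (value, type), if any.
    static Operation *lookupConstant(
        const llvm::DenseMap<std::pair<Attribute, Type>, Operation *> &uniqued,
        Attribute value, Type type) {
      // Both key members are small value handles, so building the key is cheap.
      auto it = uniqued.find({value, type});
      return it == uniqued.end() ? nullptr : it->second;
    }
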
diff --git a/lib/Transforms/Utils/Utils.cpp b/lib/Transforms/Utils/Utils.cpp
index edd8ce8..ad9d6dc 100644
--- a/lib/Transforms/Utils/Utils.cpp
+++ b/lib/Transforms/Utils/Utils.cpp
@@ -52,9 +52,9 @@
                                     MLValue *newMemRef,
                                     ArrayRef<MLValue *> extraIndices,
                                     AffineMap indexRemap) {
-  unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
-  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
+  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
   (void)oldMemRefRank; // unused in opt mode
   if (indexRemap) {
     assert(indexRemap.getNumInputs() == oldMemRefRank);
@@ -64,8 +64,8 @@
   }
 
   // Assert same elemental type.
-  assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
-         cast<MemRefType>(newMemRef->getType())->getElementType());
+  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
+         newMemRef->getType().cast<MemRefType>().getElementType());
 
   // Check if memref was used in a non-dereferencing context.
   for (const StmtOperand &use : oldMemRef->getUses()) {
@@ -139,7 +139,7 @@
                     opStmt->operand_end());
 
     // Result types don't change. Both memrefs are of the same elemental type.
-    SmallVector<Type *, 8> resultTypes;
+    SmallVector<Type, 8> resultTypes;
     resultTypes.reserve(opStmt->getNumResults());
     for (const auto *result : opStmt->getResults())
       resultTypes.push_back(result->getType());
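
Passing memref types through helpers works the same way: both doubleShape in
PipelineDataTransfer.cpp above and getVectorizedMemRefType below take and
return MemRefType by value, since the handle is trivially copyable and the
underlying storage stays uniqued in the context. A minimal sketch, not part of
this patch; the helper is made up (roughly the inverse of doubleShape) and the
include paths are assumptions.

    // Illustrative only -- not code from this change.
    #include <cassert>
    #include "llvm/ADT/SmallVector.h"
    #include "mlir/IR/Types.h"

    using namespace mlir;

    // Strips the leading double-buffer dimension added by doubleShape.
    static MemRefType dropLeadingDim(MemRefType doubled) {
      auto shape = doubled.getShape(); // ArrayRef<int> view of the shape
      assert(shape.size() > 1 && shape.front() == 2 && "expected a double buffer");
      llvm::SmallVector<int, 4> newShape(shape.begin() + 1, shape.end());
      // MemRefType::get hands back another value handle; nothing is owned here.
      return MemRefType::get(newShape, doubled.getElementType(), {},
                             doubled.getMemorySpace());
    }
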
diff --git a/lib/Transforms/Vectorize.cpp b/lib/Transforms/Vectorize.cpp
index d7a1f53..511afa9 100644
--- a/lib/Transforms/Vectorize.cpp
+++ b/lib/Transforms/Vectorize.cpp
@@ -202,15 +202,15 @@
 /// sizes specified by vectorSizes. The MemRef lives in the same memory space as
 /// tmpl. The MemRef should be promoted to a closer memory address space in a
 /// later pass.
-static MemRefType *getVectorizedMemRefType(MemRefType *tmpl,
-                                           ArrayRef<int> vectorSizes) {
-  auto *elementType = tmpl->getElementType();
-  assert(!dyn_cast<VectorType>(elementType) &&
+static MemRefType getVectorizedMemRefType(MemRefType tmpl,
+                                          ArrayRef<int> vectorSizes) {
+  auto elementType = tmpl.getElementType();
+  assert(!elementType.dyn_cast<VectorType>() &&
          "Can't vectorize an already vector type");
-  assert(tmpl->getAffineMaps().empty() &&
+  assert(tmpl.getAffineMaps().empty() &&
          "Unsupported non-implicit identity map");
   return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {},
-                         tmpl->getMemorySpace());
+                         tmpl.getMemorySpace());
 }
 
 /// Creates an unaligned load with the following semantics:
@@ -258,7 +258,7 @@
   operands.insert(operands.end(), dstMemRef);
   operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
   using functional::map;
-  std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
+  std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
     return v->getType();
   };
   auto types = map(getType, operands);
@@ -310,7 +310,7 @@
   operands.insert(operands.end(), dstMemRef);
   operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
   using functional::map;
-  std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
+  std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
     return v->getType();
   };
   auto types = map(getType, operands);
@@ -348,8 +348,9 @@
 template <typename LoadOrStoreOpPointer>
 static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
                                   ArrayRef<int> vectorSize) {
-  auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
-  auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
+  auto memRefType =
+      memoryOp->getMemRef()->getType().template cast<MemRefType>();
+  auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
 
   // Materialize a MemRef with 1 vector.
   auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());