Implement value type abstraction for types.

This is done by changing Type to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast.

PiperOrigin-RevId: 219372163
diff --git a/include/mlir/Analysis/LoopAnalysis.h b/include/mlir/Analysis/LoopAnalysis.h
index 7f6d799..6820ee8 100644
--- a/include/mlir/Analysis/LoopAnalysis.h
+++ b/include/mlir/Analysis/LoopAnalysis.h
@@ -51,7 +51,7 @@
 /// whether indices[dim] is independent of the value `input`.
 // For now we assume no layout map or identity layout map in the MemRef.
 // TODO(ntv): support more than identity layout map.
-bool isAccessInvariant(const MLValue &input, MemRefType *memRefType,
+bool isAccessInvariant(const MLValue &input, MemRefType memRefType,
                        llvm::ArrayRef<MLValue *> indices, unsigned dim);
 
 /// Checks whether all the LoadOp and StoreOp matched have access indexing
diff --git a/include/mlir/IR/Attributes.h b/include/mlir/IR/Attributes.h
index b84d20f..7c30397 100644
--- a/include/mlir/IR/Attributes.h
+++ b/include/mlir/IR/Attributes.h
@@ -250,9 +250,9 @@
   TypeAttr() = default;
   /* implicit */ TypeAttr(Attribute::ImplType *ptr);
 
-  static TypeAttr get(Type *type, MLIRContext *context);
+  static TypeAttr get(Type type, MLIRContext *context);
 
-  Type *getValue() const;
+  Type getValue() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool kindof(Kind kind) { return kind == Kind::Type; }
@@ -277,7 +277,7 @@
 
   Function *getValue() const;
 
-  FunctionType *getType() const;
+  FunctionType getType() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool kindof(Kind kind) { return kind == Kind::Function; }
@@ -294,7 +294,7 @@
   ElementsAttr() = default;
   /* implicit */ ElementsAttr(Attribute::ImplType *ptr);
 
-  VectorOrTensorType *getType() const;
+  VectorOrTensorType getType() const;
 
   /// Method for support type inquiry through isa, cast and dyn_cast.
   static bool kindof(Kind kind) {
@@ -313,7 +313,7 @@
   SplatElementsAttr() = default;
   /* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
 
-  static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt);
+  static SplatElementsAttr get(VectorOrTensorType type, Attribute elt);
   Attribute getValue() const;
 
   /// Method for support type inquiry through isa, cast and dyn_cast.
@@ -335,12 +335,12 @@
   /// width specified by the element type (note all float type are 64 bits).
   /// When the value is retrieved, the bits are read from the storage and extend
   /// to 64 bits if necessary.
-  static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data);
+  static DenseElementsAttr get(VectorOrTensorType type, ArrayRef<char> data);
 
   // TODO: Read the data from the attribute list and compress them
   // to a character array. Then call the above method to construct the
   // attribute.
-  static DenseElementsAttr get(VectorOrTensorType *type,
+  static DenseElementsAttr get(VectorOrTensorType type,
                                ArrayRef<Attribute> values);
 
   void getValues(SmallVectorImpl<Attribute> &values) const;
@@ -410,7 +410,7 @@
   OpaqueElementsAttr() = default;
   /* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
 
-  static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes);
+  static OpaqueElementsAttr get(VectorOrTensorType type, StringRef bytes);
 
   StringRef getValue() const;
 
@@ -440,7 +440,7 @@
   SparseElementsAttr() = default;
   /* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
 
-  static SparseElementsAttr get(VectorOrTensorType *type,
+  static SparseElementsAttr get(VectorOrTensorType type,
                                 DenseIntElementsAttr indices,
                                 DenseElementsAttr values);
 
diff --git a/include/mlir/IR/BasicBlock.h b/include/mlir/IR/BasicBlock.h
index c55d09c..cfae6af 100644
--- a/include/mlir/IR/BasicBlock.h
+++ b/include/mlir/IR/BasicBlock.h
@@ -64,10 +64,10 @@
   bool args_empty() const { return arguments.empty(); }
 
   /// Add one value to the operand list.
-  BBArgument *addArgument(Type *type);
+  BBArgument *addArgument(Type type);
 
   /// Add one argument to the argument list for each type specified in the list.
-  llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type *> types);
+  llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
 
   unsigned getNumArguments() const { return arguments.size(); }
   BBArgument *getArgument(unsigned i) { return arguments[i]; }
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 2e48008..46952e2 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -68,29 +68,28 @@
                                     unsigned column);
 
   // Types.
-  FloatType *getBF16Type();
-  FloatType *getF16Type();
-  FloatType *getF32Type();
-  FloatType *getF64Type();
+  FloatType getBF16Type();
+  FloatType getF16Type();
+  FloatType getF32Type();
+  FloatType getF64Type();
 
-  OtherType *getIndexType();
-  OtherType *getTFControlType();
-  OtherType *getTFStringType();
-  OtherType *getTFResourceType();
-  OtherType *getTFVariantType();
-  OtherType *getTFComplex64Type();
-  OtherType *getTFComplex128Type();
-  OtherType *getTFF32REFType();
+  OtherType getIndexType();
+  OtherType getTFControlType();
+  OtherType getTFStringType();
+  OtherType getTFResourceType();
+  OtherType getTFVariantType();
+  OtherType getTFComplex64Type();
+  OtherType getTFComplex128Type();
+  OtherType getTFF32REFType();
 
-  IntegerType *getIntegerType(unsigned width);
-  FunctionType *getFunctionType(ArrayRef<Type *> inputs,
-                                ArrayRef<Type *> results);
-  MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
-                            ArrayRef<AffineMap> affineMapComposition = {},
-                            unsigned memorySpace = 0);
-  VectorType *getVectorType(ArrayRef<int> shape, Type *elementType);
-  RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
-  UnrankedTensorType *getTensorType(Type *elementType);
+  IntegerType getIntegerType(unsigned width);
+  FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
+  MemRefType getMemRefType(ArrayRef<int> shape, Type elementType,
+                           ArrayRef<AffineMap> affineMapComposition = {},
+                           unsigned memorySpace = 0);
+  VectorType getVectorType(ArrayRef<int> shape, Type elementType);
+  RankedTensorType getTensorType(ArrayRef<int> shape, Type elementType);
+  UnrankedTensorType getTensorType(Type elementType);
 
   // Attributes.
 
@@ -102,15 +101,15 @@
   ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
   AffineMapAttr getAffineMapAttr(AffineMap map);
   IntegerSetAttr getIntegerSetAttr(IntegerSet set);
-  TypeAttr getTypeAttr(Type *type);
+  TypeAttr getTypeAttr(Type type);
   FunctionAttr getFunctionAttr(const Function *value);
-  ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt);
-  ElementsAttr getDenseElementsAttr(VectorOrTensorType *type,
+  ElementsAttr getSplatElementsAttr(VectorOrTensorType type, Attribute elt);
+  ElementsAttr getDenseElementsAttr(VectorOrTensorType type,
                                     ArrayRef<char> data);
-  ElementsAttr getSparseElementsAttr(VectorOrTensorType *type,
+  ElementsAttr getSparseElementsAttr(VectorOrTensorType type,
                                      DenseIntElementsAttr indices,
                                      DenseElementsAttr values);
-  ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes);
+  ElementsAttr getOpaqueElementsAttr(VectorOrTensorType type, StringRef bytes);
 
   // Affine expressions and affine maps.
   AffineExpr getAffineDimExpr(unsigned position);
@@ -366,7 +365,7 @@
   /// Creates an operation given the fields.
   OperationStmt *createOperation(Location *location, OperationName name,
                                  ArrayRef<MLValue *> operands,
-                                 ArrayRef<Type *> types,
+                                 ArrayRef<Type> types,
                                  ArrayRef<NamedAttribute> attrs);
 
   /// Create operation of specific op type at the current insertion point.
diff --git a/include/mlir/IR/BuiltinOps.h b/include/mlir/IR/BuiltinOps.h
index 88d4d81..5d810a9 100644
--- a/include/mlir/IR/BuiltinOps.h
+++ b/include/mlir/IR/BuiltinOps.h
@@ -96,7 +96,7 @@
 public:
   /// Builds a constant op with the specified attribute value and result type.
   static void build(Builder *builder, OperationState *result, Attribute value,
-                    Type *type);
+                    Type type);
 
   Attribute getValue() const { return getAttr("value"); }
 
@@ -123,7 +123,7 @@
 public:
   /// Builds a constant float op producing a float of the specified type.
   static void build(Builder *builder, OperationState *result,
-                    const APFloat &value, FloatType *type);
+                    const APFloat &value, FloatType type);
 
   APFloat getValue() const {
     return getAttrOfType<FloatAttr>("value").getValue();
@@ -150,7 +150,7 @@
   /// Build a constant int op producing an integer with the specified type,
   /// which must be an integer type.
   static void build(Builder *builder, OperationState *result, int64_t value,
-                    Type *type);
+                    Type type);
 
   int64_t getValue() const {
     return getAttrOfType<IntegerAttr>("value").getValue();
diff --git a/include/mlir/IR/CFGFunction.h b/include/mlir/IR/CFGFunction.h
index f3c1da3..fb20a6b 100644
--- a/include/mlir/IR/CFGFunction.h
+++ b/include/mlir/IR/CFGFunction.h
@@ -27,7 +27,7 @@
 // blocks, each of which includes instructions.
 class CFGFunction : public Function {
 public:
-  CFGFunction(Location *location, StringRef name, FunctionType *type,
+  CFGFunction(Location *location, StringRef name, FunctionType type,
               ArrayRef<NamedAttribute> attrs = {});
 
   ~CFGFunction();
diff --git a/include/mlir/IR/CFGValue.h b/include/mlir/IR/CFGValue.h
index 939073c..45b36c1 100644
--- a/include/mlir/IR/CFGValue.h
+++ b/include/mlir/IR/CFGValue.h
@@ -66,7 +66,7 @@
   }
 
 protected:
-  CFGValue(CFGValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
+  CFGValue(CFGValueKind kind, Type type) : SSAValueImpl(kind, type) {}
 };
 
 /// Basic block arguments are CFG Values.
@@ -87,7 +87,7 @@
 
 private:
   friend class BasicBlock; // For access to private constructor.
-  BBArgument(Type *type, BasicBlock *owner)
+  BBArgument(Type type, BasicBlock *owner)
       : CFGValue(CFGValueKind::BBArgument, type), owner(owner) {}
 
   /// The owner of this operand.
@@ -99,7 +99,7 @@
 /// Instruction results are CFG Values.
 class InstResult : public CFGValue {
 public:
-  InstResult(Type *type, OperationInst *owner)
+  InstResult(Type type, OperationInst *owner)
       : CFGValue(CFGValueKind::InstResult, type), owner(owner) {}
 
   static bool classof(const SSAValue *value) {
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index d42f528..04acc59 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -26,6 +26,7 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Identifier.h"
+#include "mlir/IR/Types.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ilist.h"
 
@@ -55,7 +56,7 @@
   Identifier getName() const { return nameAndKind.getPointer(); }
 
   /// Return the type of this function.
-  FunctionType *getType() const { return type; }
+  FunctionType getType() const { return type; }
 
   /// Returns all of the attributes on this function.
   ArrayRef<NamedAttribute> getAttrs() const;
@@ -93,7 +94,7 @@
   void emitNote(const Twine &message) const;
 
 protected:
-  Function(Kind kind, Location *location, StringRef name, FunctionType *type,
+  Function(Kind kind, Location *location, StringRef name, FunctionType type,
            ArrayRef<NamedAttribute> attrs = {});
   ~Function();
 
@@ -108,7 +109,7 @@
   Location *location;
 
   /// The type of the function.
-  FunctionType *const type;
+  FunctionType type;
 
   /// This holds general named attributes for the function.
   AttributeListStorage *attrs;
@@ -121,7 +122,7 @@
 /// defined in some other module.
 class ExtFunction : public Function {
 public:
-  ExtFunction(Location *location, StringRef name, FunctionType *type,
+  ExtFunction(Location *location, StringRef name, FunctionType type,
               ArrayRef<NamedAttribute> attrs = {});
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index e74c561..6d5a1ca 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -202,7 +202,7 @@
   /// Create a new OperationInst with the specified fields.
   static OperationInst *create(Location *location, OperationName name,
                                ArrayRef<CFGValue *> operands,
-                               ArrayRef<Type *> resultTypes,
+                               ArrayRef<Type> resultTypes,
                                ArrayRef<NamedAttribute> attributes,
                                MLIRContext *context);
 
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index 104692f..cf0deb9 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -41,7 +41,7 @@
 public:
   /// Creates a new MLFunction with the specific type.
   static MLFunction *create(Location *location, StringRef name,
-                            FunctionType *type,
+                            FunctionType type,
                             ArrayRef<NamedAttribute> attrs = {});
 
   /// Destroys this statement and its subclass data.
@@ -52,7 +52,7 @@
   //===--------------------------------------------------------------------===//
 
   /// Returns number of arguments.
-  unsigned getNumArguments() const { return getType()->getInputs().size(); }
+  unsigned getNumArguments() const { return getType().getInputs().size(); }
 
   /// Gets argument.
   MLFuncArgument *getArgument(unsigned idx) {
@@ -103,13 +103,13 @@
   }
 
 private:
-  MLFunction(Location *location, StringRef name, FunctionType *type,
+  MLFunction(Location *location, StringRef name, FunctionType type,
              ArrayRef<NamedAttribute> attrs = {});
 
   // This stuff is used by the TrailingObjects template.
   friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
   size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const {
-    return getType()->getInputs().size();
+    return getType().getInputs().size();
   }
 
   // Internal functions to get argument list used by getArgument() methods.
diff --git a/include/mlir/IR/MLValue.h b/include/mlir/IR/MLValue.h
index 1961da1..0c6c0b2 100644
--- a/include/mlir/IR/MLValue.h
+++ b/include/mlir/IR/MLValue.h
@@ -73,7 +73,7 @@
   }
 
 protected:
-  MLValue(MLValueKind kind, Type *type) : SSAValueImpl(kind, type) {}
+  MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
 };
 
 /// This is the value defined by an argument of an ML function.
@@ -93,7 +93,7 @@
 
 private:
   friend class MLFunction; // For access to private constructor.
-  MLFuncArgument(Type *type, MLFunction *owner)
+  MLFuncArgument(Type type, MLFunction *owner)
       : MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {}
 
   /// The owner of this operand.
@@ -105,7 +105,7 @@
 /// This is a value defined by a result of an operation instruction.
 class StmtResult : public MLValue {
 public:
-  StmtResult(Type *type, OperationStmt *owner)
+  StmtResult(Type type, OperationStmt *owner)
       : MLValue(MLValueKind::StmtResult, type), owner(owner) {}
 
   static bool classof(const SSAValue *value) {
diff --git a/include/mlir/IR/Matchers.h b/include/mlir/IR/Matchers.h
index 06013a7..ad97dd2 100644
--- a/include/mlir/IR/Matchers.h
+++ b/include/mlir/IR/Matchers.h
@@ -71,13 +71,13 @@
 
   bool match(Operation *op) {
     if (auto constOp = op->dyn_cast<ConstantOp>()) {
-      auto *type = constOp->getResult()->getType();
+      auto type = constOp->getResult()->getType();
       auto attr = constOp->getAttr("value");
 
-      if (isa<IntegerType>(type)) {
+      if (type.isa<IntegerType>()) {
         return attr_value_binder<IntegerAttr>(bind_value).match(attr);
       }
-      if (isa<VectorOrTensorType>(type)) {
+      if (type.isa<VectorOrTensorType>()) {
         if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
           return attr_value_binder<IntegerAttr>(bind_value)
               .match(splatAttr.getValue());
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index 821beb2..c2bc357 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -493,7 +493,7 @@
     return this->getOperation()->getResult(0);
   }
 
-  Type *getType() const { return getResult()->getType(); }
+  Type getType() const { return getResult()->getType(); }
 
   /// Replace all uses of 'this' value with the new value, updating anything in
   /// the IR that uses 'this' to use the other value instead.  When this returns
@@ -539,7 +539,7 @@
       return this->getOperation()->getResult(i);
     }
 
-    Type *getType(unsigned i) const { return getResult(i)->getType(); }
+    Type getType(unsigned i) const { return getResult(i)->getType(); }
 
     static bool verifyTrait(const Operation *op) {
       return impl::verifyNResults(op, N);
@@ -565,7 +565,7 @@
       return this->getOperation()->getResult(i);
     }
 
-    Type *getType(unsigned i) const { return getResult(i)->getType(); }
+    Type getType(unsigned i) const { return getResult(i)->getType(); }
 
     static bool verifyTrait(const Operation *op) {
       return impl::verifyAtLeastNResults(op, N);
@@ -803,7 +803,7 @@
 // which avoids them being template instantiated/duplicated.
 namespace impl {
 void buildCastOp(Builder *builder, OperationState *result, SSAValue *source,
-                 Type *destType);
+                 Type destType);
 bool parseCastOp(OpAsmParser *parser, OperationState *result);
 void printCastOp(const Operation *op, OpAsmPrinter *p);
 } // namespace impl
@@ -819,7 +819,7 @@
                          OpTrait::HasNoSideEffect, Traits...> {
 public:
   static void build(Builder *builder, OperationState *result, SSAValue *source,
-                    Type *destType) {
+                    Type destType) {
     impl::buildCastOp(builder, result, source, destType);
   }
   static bool parse(OpAsmParser *parser, OperationState *result) {
diff --git a/include/mlir/IR/OpImplementation.h b/include/mlir/IR/OpImplementation.h
index 09ec3f9..ae8df55 100644
--- a/include/mlir/IR/OpImplementation.h
+++ b/include/mlir/IR/OpImplementation.h
@@ -67,7 +67,7 @@
       printOperand(*it);
     }
   }
-  virtual void printType(const Type *type) = 0;
+  virtual void printType(Type type) = 0;
   virtual void printFunctionReference(const Function *func) = 0;
   virtual void printAttribute(Attribute attr) = 0;
   virtual void printAffineMap(AffineMap map) = 0;
@@ -95,8 +95,8 @@
   return p;
 }
 
-inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) {
-  p.printType(&type);
+inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Type type) {
+  p.printType(type);
   return p;
 }
 
@@ -163,20 +163,20 @@
   virtual bool parseComma() = 0;
 
   /// Parse a colon followed by a type.
-  virtual bool parseColonType(Type *&result) = 0;
+  virtual bool parseColonType(Type &result) = 0;
 
   /// Parse a type of a specific kind, e.g. a FunctionType.
-  template <typename TypeType> bool parseColonType(TypeType *&result) {
+  template <typename TypeType> bool parseColonType(TypeType &result) {
     llvm::SMLoc loc;
     getCurrentLocation(&loc);
 
     // Parse any kind of type.
-    Type *type;
+    Type type;
     if (parseColonType(type))
       return true;
 
     // Check for the right kind of attribute.
-    result = dyn_cast<TypeType>(type);
+    result = type.dyn_cast<TypeType>();
     if (!result) {
       emitError(loc, "invalid kind of type specified");
       return true;
@@ -186,15 +186,15 @@
   }
 
   /// Parse a colon followed by a type list, which must have at least one type.
-  virtual bool parseColonTypeList(SmallVectorImpl<Type *> &result) = 0;
+  virtual bool parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
 
   /// Parse a keyword followed by a type.
-  virtual bool parseKeywordType(const char *keyword, Type *&result) = 0;
+  virtual bool parseKeywordType(const char *keyword, Type &result) = 0;
 
   /// Add the specified type to the end of the specified type list and return
   /// false.  This is a helper designed to allow parse methods to be simple and
   /// chain through || operators.
-  bool addTypeToList(Type *type, SmallVectorImpl<Type *> &result) {
+  bool addTypeToList(Type type, SmallVectorImpl<Type> &result) {
     result.push_back(type);
     return false;
   }
@@ -202,7 +202,7 @@
   /// Add the specified types to the end of the specified type list and return
   /// false.  This is a helper designed to allow parse methods to be simple and
   /// chain through || operators.
-  bool addTypesToList(ArrayRef<Type *> types, SmallVectorImpl<Type *> &result) {
+  bool addTypesToList(ArrayRef<Type> types, SmallVectorImpl<Type> &result) {
     result.append(types.begin(), types.end());
     return false;
   }
@@ -288,13 +288,13 @@
 
   /// Resolve an operand to an SSA value, emitting an error and returning true
   /// on failure.
-  virtual bool resolveOperand(const OperandType &operand, Type *type,
+  virtual bool resolveOperand(const OperandType &operand, Type type,
                               SmallVectorImpl<SSAValue *> &result) = 0;
 
   /// Resolve a list of operands to SSA values, emitting an error and returning
   /// true on failure, or appending the results to the list on success.
   /// This method should be used when all operands have the same type.
-  virtual bool resolveOperands(ArrayRef<OperandType> operands, Type *type,
+  virtual bool resolveOperands(ArrayRef<OperandType> operands, Type type,
                                SmallVectorImpl<SSAValue *> &result) {
     for (auto elt : operands)
       if (resolveOperand(elt, type, result))
@@ -306,7 +306,7 @@
   /// emitting an error and returning true on failure, or appending the results
   /// to the list on success.
   virtual bool resolveOperands(ArrayRef<OperandType> operands,
-                               ArrayRef<Type *> types, llvm::SMLoc loc,
+                               ArrayRef<Type> types, llvm::SMLoc loc,
                                SmallVectorImpl<SSAValue *> &result) {
     if (operands.size() != types.size())
       return emitError(loc, Twine(operands.size()) +
@@ -321,7 +321,7 @@
   }
 
   /// Resolve a parse function name and a type into a function reference.
-  virtual bool resolveFunctionName(StringRef name, FunctionType *type,
+  virtual bool resolveFunctionName(StringRef name, FunctionType type,
                                    llvm::SMLoc loc, Function *&result) = 0;
 
   /// Emit a diagnostic at the specified location and return true.
diff --git a/include/mlir/IR/OperationSupport.h b/include/mlir/IR/OperationSupport.h
index ddf1aee..6833e45 100644
--- a/include/mlir/IR/OperationSupport.h
+++ b/include/mlir/IR/OperationSupport.h
@@ -25,6 +25,7 @@
 
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Identifier.h"
+#include "mlir/IR/Types.h"
 #include "llvm/ADT/PointerUnion.h"
 #include <memory>
 
@@ -191,7 +192,7 @@
   OperationName name;
   SmallVector<SSAValue *, 4> operands;
   /// Types of the results of this operation.
-  SmallVector<Type *, 4> types;
+  SmallVector<Type, 4> types;
   SmallVector<NamedAttribute, 4> attributes;
 
 public:
@@ -202,7 +203,7 @@
       : context(context), location(location), name(name) {}
 
   OperationState(MLIRContext *context, Location *location, StringRef name,
-                 ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
+                 ArrayRef<SSAValue *> operands, ArrayRef<Type> types,
                  ArrayRef<NamedAttribute> attributes = {})
       : context(context), location(location), name(name, context),
         operands(operands.begin(), operands.end()),
@@ -213,7 +214,7 @@
     operands.append(newOperands.begin(), newOperands.end());
   }
 
-  void addTypes(ArrayRef<Type *> newTypes) {
+  void addTypes(ArrayRef<Type> newTypes) {
     types.append(newTypes.begin(), newTypes.end());
   }
 
diff --git a/include/mlir/IR/SSAValue.h b/include/mlir/IR/SSAValue.h
index 93db6fa..ab16c98 100644
--- a/include/mlir/IR/SSAValue.h
+++ b/include/mlir/IR/SSAValue.h
@@ -25,7 +25,6 @@
 #include "mlir/IR/Types.h"
 #include "mlir/IR/UseDefLists.h"
 #include "mlir/Support/LLVM.h"
-#include "llvm/ADT/PointerIntPair.h"
 
 namespace mlir {
 class Function;
@@ -51,7 +50,7 @@
 
   SSAValueKind getKind() const { return typeAndKind.getInt(); }
 
-  Type *getType() const { return typeAndKind.getPointer(); }
+  Type getType() const { return typeAndKind.getPointer(); }
 
   /// Replace all uses of 'this' value with the new value, updating anything in
   /// the IR that uses 'this' to use the other value instead.  When this returns
@@ -93,9 +92,10 @@
   void dump() const;
 
 protected:
-  SSAValue(SSAValueKind kind, Type *type) : typeAndKind(type, kind) {}
+  SSAValue(SSAValueKind kind, Type type) : typeAndKind(type, kind) {}
+
 private:
-  const llvm::PointerIntPair<Type *, 3, SSAValueKind> typeAndKind;
+  const llvm::PointerIntPair<Type, 3, SSAValueKind> typeAndKind;
 };
 
 inline raw_ostream &operator<<(raw_ostream &os, const SSAValue &value) {
@@ -127,7 +127,7 @@
   inline use_range getUses() const;
 
 protected:
-  SSAValueImpl(KindTy kind, Type *type) : SSAValue((SSAValueKind)kind, type) {}
+  SSAValueImpl(KindTy kind, Type type) : SSAValue((SSAValueKind)kind, type) {}
 };
 
 // Utility functions for iterating through SSAValue uses.
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 7e7a49f..8a5a9b5 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -44,7 +44,7 @@
   /// Create a new OperationStmt with the specific fields.
   static OperationStmt *create(Location *location, OperationName name,
                                ArrayRef<MLValue *> operands,
-                               ArrayRef<Type *> resultTypes,
+                               ArrayRef<Type> resultTypes,
                                ArrayRef<NamedAttribute> attributes,
                                MLIRContext *context);
 
@@ -329,7 +329,7 @@
   //===--------------------------------------------------------------------===//
 
   /// Return the context this operation is associated with.
-  MLIRContext *getContext() const { return getType()->getContext(); }
+  MLIRContext *getContext() const { return getType().getContext(); }
 
   using Statement::dump;
   using Statement::print;
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index 3d0afdf..493f607 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -20,6 +20,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMapInfo.h"
 
 namespace mlir {
 class AffineMap;
@@ -28,6 +29,22 @@
 class FloatType;
 class OtherType;
 
+namespace detail {
+
+class TypeStorage;
+class IntegerTypeStorage;
+class FloatTypeStorage;
+struct OtherTypeStorage;
+struct FunctionTypeStorage;
+struct VectorOrTensorTypeStorage;
+struct VectorTypeStorage;
+struct TensorTypeStorage;
+struct RankedTensorTypeStorage;
+struct UnrankedTensorTypeStorage;
+struct MemRefTypeStorage;
+
+} // namespace detail
+
 /// Instances of the Type class are immutable, uniqued, immortal, and owned by
 /// MLIRContext.  As such, they are passed around by raw non-const pointer.
 ///
@@ -68,11 +85,34 @@
     MemRef,
   };
 
+  using ImplType = detail::TypeStorage;
+
+  Type() : type(nullptr) {}
+  /* implicit */ Type(const ImplType *type)
+      : type(const_cast<ImplType *>(type)) {}
+
+  Type(const Type &other) : type(other.type) {}
+  Type &operator=(Type other) {
+    type = other.type;
+    return *this;
+  }
+
+  bool operator==(Type other) const { return type == other.type; }
+  bool operator!=(Type other) const { return !(*this == other); }
+  explicit operator bool() const { return type; }
+
+  bool operator!() const { return type == nullptr; }
+
+  template <typename U> bool isa() const;
+  template <typename U> U dyn_cast() const;
+  template <typename U> U dyn_cast_or_null() const;
+  template <typename U> U cast() const;
+
   /// Return the classification for this type.
-  Kind getKind() const { return kind; }
+  Kind getKind() const;
 
   /// Return the LLVMContext in which this type was uniqued.
-  MLIRContext *getContext() const { return context; }
+  MLIRContext *getContext() const;
 
   // Convenience predicates.  This is only for 'other' and floating point types,
   // derived types should use isa/dyn_cast.
@@ -97,56 +137,42 @@
   unsigned getBitWidth() const;
 
   // Convenience factories.
-  static IntegerType *getInteger(unsigned width, MLIRContext *ctx);
-  static FloatType *getBF16(MLIRContext *ctx);
-  static FloatType *getF16(MLIRContext *ctx);
-  static FloatType *getF32(MLIRContext *ctx);
-  static FloatType *getF64(MLIRContext *ctx);
-  static OtherType *getIndex(MLIRContext *ctx);
-  static OtherType *getTFControl(MLIRContext *ctx);
-  static OtherType *getTFString(MLIRContext *ctx);
-  static OtherType *getTFResource(MLIRContext *ctx);
-  static OtherType *getTFVariant(MLIRContext *ctx);
-  static OtherType *getTFComplex64(MLIRContext *ctx);
-  static OtherType *getTFComplex128(MLIRContext *ctx);
-  static OtherType *getTFF32REF(MLIRContext *ctx);
+  static IntegerType getInteger(unsigned width, MLIRContext *ctx);
+  static FloatType getBF16(MLIRContext *ctx);
+  static FloatType getF16(MLIRContext *ctx);
+  static FloatType getF32(MLIRContext *ctx);
+  static FloatType getF64(MLIRContext *ctx);
+  static OtherType getIndex(MLIRContext *ctx);
+  static OtherType getTFControl(MLIRContext *ctx);
+  static OtherType getTFString(MLIRContext *ctx);
+  static OtherType getTFResource(MLIRContext *ctx);
+  static OtherType getTFVariant(MLIRContext *ctx);
+  static OtherType getTFComplex64(MLIRContext *ctx);
+  static OtherType getTFComplex128(MLIRContext *ctx);
+  static OtherType getTFF32REF(MLIRContext *ctx);
 
   /// Print the current type.
   void print(raw_ostream &os) const;
   void dump() const;
 
+  friend ::llvm::hash_code hash_value(Type arg);
+
+  unsigned getSubclassData() const;
+  void setSubclassData(unsigned val);
+
+  /// Methods for supporting PointerLikeTypeTraits.
+  const void *getAsOpaquePointer() const {
+    return static_cast<const void *>(type);
+  }
+  static Type getFromOpaquePointer(const void *pointer) {
+    return Type((ImplType *)(pointer));
+  }
+
 protected:
-  explicit Type(Kind kind, MLIRContext *context)
-      : context(context), kind(kind), subclassData(0) {}
-  explicit Type(Kind kind, MLIRContext *context, unsigned subClassData)
-      : Type(kind, context) {
-    setSubclassData(subClassData);
-  }
-
-  ~Type() {}
-
-  unsigned getSubclassData() const { return subclassData; }
-
-  void setSubclassData(unsigned val) {
-    subclassData = val;
-    // Ensure we don't have any accidental truncation.
-    assert(getSubclassData() == val && "Subclass data too large for field");
-  }
-
-private:
-  Type(const Type&) = delete;
-  void operator=(const Type&) = delete;
-  /// This refers to the MLIRContext in which this type was uniqued.
-  MLIRContext *const context;
-
-  /// Classification of the subclass, used for type checking.
-  Kind kind : 8;
-
-  // Space for subclasses to store data.
-  unsigned subclassData : 24;
+  ImplType *type;
 };
 
-inline raw_ostream &operator<<(raw_ostream &os, const Type &type) {
+inline raw_ostream &operator<<(raw_ostream &os, Type type) {
   type.print(os);
   return os;
 }
@@ -154,148 +180,138 @@
 /// Integer types can have arbitrary bitwidth up to a large fixed limit.
 class IntegerType : public Type {
 public:
-  static IntegerType *get(unsigned width, MLIRContext *context);
+  using ImplType = detail::IntegerTypeStorage;
+  IntegerType() = default;
+  /* implicit */ IntegerType(Type::ImplType *ptr);
+
+  static IntegerType get(unsigned width, MLIRContext *context);
 
   /// Return the bitwidth of this integer type.
-  unsigned getWidth() const {
-    return width;
-  }
+  unsigned getWidth() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Type *type) {
-    return type->getKind() == Kind::Integer;
-  }
+  static bool kindof(Kind kind) { return kind == Kind::Integer; }
 
   /// Integer representation maximal bitwidth.
   static constexpr unsigned kMaxWidth = 4096;
-private:
-  unsigned width;
-  IntegerType(unsigned width, MLIRContext *context);
-  ~IntegerType() = delete;
 };
 
-inline IntegerType *Type::getInteger(unsigned width, MLIRContext *ctx) {
+inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
   return IntegerType::get(width, ctx);
 }
 
 /// Return true if this is an integer type with the specified width.
 inline bool Type::isInteger(unsigned width) const {
-  if (auto *intTy = dyn_cast<IntegerType>(this))
-    return intTy->getWidth() == width;
+  if (auto intTy = dyn_cast<IntegerType>())
+    return intTy.getWidth() == width;
   return false;
 }
 
 class FloatType : public Type {
 public:
+  using ImplType = detail::FloatTypeStorage;
+  FloatType() = default;
+  /* implicit */ FloatType(Type::ImplType *ptr);
+
+  static FloatType get(Kind kind, MLIRContext *context);
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Type *type) {
-    return type->getKind() >= Kind::FIRST_FLOATING_POINT_TYPE &&
-           type->getKind() <= Kind::LAST_FLOATING_POINT_TYPE;
+  static bool kindof(Kind kind) {
+    return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
+           kind <= Kind::LAST_FLOATING_POINT_TYPE;
   }
-
-  static FloatType *get(Kind kind, MLIRContext *context);
-
-private:
-  FloatType(Kind kind, MLIRContext *context);
-  ~FloatType() = delete;
 };
 
-inline FloatType *Type::getBF16(MLIRContext *ctx) {
+inline FloatType Type::getBF16(MLIRContext *ctx) {
   return FloatType::get(Kind::BF16, ctx);
 }
-inline FloatType *Type::getF16(MLIRContext *ctx) {
+inline FloatType Type::getF16(MLIRContext *ctx) {
   return FloatType::get(Kind::F16, ctx);
 }
-inline FloatType *Type::getF32(MLIRContext *ctx) {
+inline FloatType Type::getF32(MLIRContext *ctx) {
   return FloatType::get(Kind::F32, ctx);
 }
-inline FloatType *Type::getF64(MLIRContext *ctx) {
+inline FloatType Type::getF64(MLIRContext *ctx) {
   return FloatType::get(Kind::F64, ctx);
 }
 
 /// This is a type for the random collection of special base types.
 class OtherType : public Type {
 public:
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Type *type) {
-    return type->getKind() >= Kind::FIRST_OTHER_TYPE &&
-           type->getKind() <= Kind::LAST_OTHER_TYPE;
-  }
-  static OtherType *get(Kind kind, MLIRContext *context);
+  using ImplType = detail::OtherTypeStorage;
+  OtherType() = default;
+  /* implicit */ OtherType(Type::ImplType *ptr);
 
-private:
-  OtherType(Kind kind, MLIRContext *context);
-  ~OtherType() = delete;
+  static OtherType get(Kind kind, MLIRContext *context);
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool kindof(Kind kind) {
+    return kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE;
+  }
 };
 
-inline OtherType *Type::getIndex(MLIRContext *ctx) {
+inline OtherType Type::getIndex(MLIRContext *ctx) {
   return OtherType::get(Kind::Index, ctx);
 }
-inline OtherType *Type::getTFControl(MLIRContext *ctx) {
+inline OtherType Type::getTFControl(MLIRContext *ctx) {
   return OtherType::get(Kind::TFControl, ctx);
 }
-inline OtherType *Type::getTFResource(MLIRContext *ctx) {
+inline OtherType Type::getTFResource(MLIRContext *ctx) {
   return OtherType::get(Kind::TFResource, ctx);
 }
-inline OtherType *Type::getTFString(MLIRContext *ctx) {
+inline OtherType Type::getTFString(MLIRContext *ctx) {
   return OtherType::get(Kind::TFString, ctx);
 }
-inline OtherType *Type::getTFVariant(MLIRContext *ctx) {
+inline OtherType Type::getTFVariant(MLIRContext *ctx) {
   return OtherType::get(Kind::TFVariant, ctx);
 }
-inline OtherType *Type::getTFComplex64(MLIRContext *ctx) {
+inline OtherType Type::getTFComplex64(MLIRContext *ctx) {
   return OtherType::get(Kind::TFComplex64, ctx);
 }
-inline OtherType *Type::getTFComplex128(MLIRContext *ctx) {
+inline OtherType Type::getTFComplex128(MLIRContext *ctx) {
   return OtherType::get(Kind::TFComplex128, ctx);
 }
-inline OtherType *Type::getTFF32REF(MLIRContext *ctx) {
+inline OtherType Type::getTFF32REF(MLIRContext *ctx) {
   return OtherType::get(Kind::TFF32REF, ctx);
 }
 
 /// Function types map from a list of inputs to a list of results.
 class FunctionType : public Type {
 public:
-  static FunctionType *get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
-                           MLIRContext *context);
+  using ImplType = detail::FunctionTypeStorage;
+  FunctionType() = default;
+  /* implicit */ FunctionType(Type::ImplType *ptr);
+
+  static FunctionType get(ArrayRef<Type> inputs, ArrayRef<Type> results,
+                          MLIRContext *context);
 
   // Input types.
   unsigned getNumInputs() const { return getSubclassData(); }
 
-  Type *getInput(unsigned i) const { return getInputs()[i]; }
+  Type getInput(unsigned i) const { return getInputs()[i]; }
 
-  ArrayRef<Type*> getInputs() const {
-    return ArrayRef<Type *>(inputsAndResults, getNumInputs());
-  }
+  ArrayRef<Type> getInputs() const;
 
   // Result types.
-  unsigned getNumResults() const { return numResults; }
+  unsigned getNumResults() const;
 
-  Type *getResult(unsigned i) const { return getResults()[i]; }
+  Type getResult(unsigned i) const { return getResults()[i]; }
 
-  ArrayRef<Type*> getResults() const {
-    return ArrayRef<Type *>(inputsAndResults + getSubclassData(), numResults);
-  }
+  ArrayRef<Type> getResults() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Type *type) {
-    return type->getKind() == Kind::Function;
-  }
-
-private:
-  unsigned numResults;
-  Type *const *inputsAndResults;
-
-  FunctionType(Type *const *inputsAndResults, unsigned numInputs,
-               unsigned numResults, MLIRContext *context);
-  ~FunctionType() = delete;
+  static bool kindof(Kind kind) { return kind == Kind::Function; }
 };
 
 /// This is a common base class between Vector, UnrankedTensor, and RankedTensor
 /// types, because many operations work on values of these aggregate types.
 class VectorOrTensorType : public Type {
 public:
-  Type *getElementType() const { return elementType; }
+  using ImplType = detail::VectorOrTensorTypeStorage;
+  VectorOrTensorType() = default;
+  /* implicit */ VectorOrTensorType(Type::ImplType *ptr);
+
+  Type getElementType() const;
 
   /// If this is ranked tensor or vector type, return the number of elements. If
   /// it is an unranked tensor or vector, abort.
@@ -319,56 +335,40 @@
   int getDimSize(unsigned i) const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Type *type) {
-    return type->getKind() == Kind::Vector ||
-           type->getKind() == Kind::RankedTensor ||
-           type->getKind() == Kind::UnrankedTensor;
+  static bool kindof(Kind kind) {
+    return kind == Kind::Vector || kind == Kind::RankedTensor ||
+           kind == Kind::UnrankedTensor;
   }
-
-public:
-  Type *elementType;
-
-  VectorOrTensorType(Kind kind, MLIRContext *context, Type *elementType,
-                     unsigned subClassData = 0);
 };
 
 /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
 /// known constant shape with one or more dimension.
 class VectorType : public VectorOrTensorType {
 public:
-  static VectorType *get(ArrayRef<int> shape, Type *elementType);
+  using ImplType = detail::VectorTypeStorage;
+  VectorType() = default;
+  /* implicit */ VectorType(Type::ImplType *ptr);
 
-  ArrayRef<int> getShape() const {
-    return ArrayRef<int>(shapeElements, getSubclassData());
-  }
+  static VectorType get(ArrayRef<int> shape, Type elementType);
+
+  ArrayRef<int> getShape() const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Type *type) {
-    return type->getKind() == Kind::Vector;
-  }
-
-private:
-  const int *shapeElements;
-  Type *elementType;
-
-  VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
-  ~VectorType() = delete;
+  static bool kindof(Kind kind) { return kind == Kind::Vector; }
 };
 
 /// Tensor types represent multi-dimensional arrays, and have two variants:
 /// RankedTensorType and UnrankedTensorType.
 class TensorType : public VectorOrTensorType {
 public:
+  using ImplType = detail::TensorTypeStorage;
+  TensorType() = default;
+  /* implicit */ TensorType(Type::ImplType *ptr);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(const Type *type) {
-    return type->getKind() == Kind::RankedTensor ||
-           type->getKind() == Kind::UnrankedTensor;
+  static bool kindof(Kind kind) {
+    return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor;
   }
-
-protected:
-  TensorType(Kind kind, Type *elementType, MLIRContext *context);
-  ~TensorType() {}
 };
 
 /// Ranked tensor types represent multi-dimensional arrays that have a shape
@@ -376,40 +376,30 @@
 /// integer or unknown (represented -1).
 class RankedTensorType : public TensorType {
 public:
-  static RankedTensorType *get(ArrayRef<int> shape,
-                               Type *elementType);
+  using ImplType = detail::RankedTensorTypeStorage;
+  RankedTensorType() = default;
+  /* implicit */ RankedTensorType(Type::ImplType *ptr);
 
-  ArrayRef<int> getShape() const {
-    return ArrayRef<int>(shapeElements, getSubclassData());
-  }
+  static RankedTensorType get(ArrayRef<int> shape, Type elementType);
 
-  static bool classof(const Type *type) {
-    return type->getKind() == Kind::RankedTensor;
-  }
+  ArrayRef<int> getShape() const;
 
-private:
-  const int *shapeElements;
-
-  RankedTensorType(ArrayRef<int> shape, Type *elementType,
-                   MLIRContext *context);
-  ~RankedTensorType() = delete;
+  static bool kindof(Kind kind) { return kind == Kind::RankedTensor; }
 };
 
 /// Unranked tensor types represent multi-dimensional arrays that have an
 /// unknown shape.
 class UnrankedTensorType : public TensorType {
 public:
-  static UnrankedTensorType *get(Type *elementType);
+  using ImplType = detail::UnrankedTensorTypeStorage;
+  UnrankedTensorType() = default;
+  /* implicit */ UnrankedTensorType(Type::ImplType *ptr);
+
+  static UnrankedTensorType get(Type elementType);
 
   ArrayRef<int> getShape() const { return ArrayRef<int>(); }
 
-  static bool classof(const Type *type) {
-    return type->getKind() == Kind::UnrankedTensor;
-  }
-
-private:
-  UnrankedTensorType(Type *elementType, MLIRContext *context);
-  ~UnrankedTensorType() = delete;
+  static bool kindof(Kind kind) { return kind == Kind::UnrankedTensor; }
 };
 
 /// MemRef types represent a region of memory that have a shape with a fixed
@@ -418,62 +408,96 @@
 /// affine map composition, represented as an array AffineMap pointers.
 class MemRefType : public Type {
 public:
+  using ImplType = detail::MemRefTypeStorage;
+  MemRefType() = default;
+  /* implicit */ MemRefType(Type::ImplType *ptr);
+
   /// Get or create a new MemRefType based on shape, element type, affine
   /// map composition, and memory space.
-  static MemRefType *get(ArrayRef<int> shape, Type *elementType,
-                         ArrayRef<AffineMap> affineMapComposition,
-                         unsigned memorySpace);
+  static MemRefType get(ArrayRef<int> shape, Type elementType,
+                        ArrayRef<AffineMap> affineMapComposition,
+                        unsigned memorySpace);
 
   unsigned getRank() const { return getShape().size(); }
 
   /// Returns an array of memref shape dimension sizes.
-  ArrayRef<int> getShape() const {
-    return ArrayRef<int>(shapeElements, getSubclassData());
-  }
+  ArrayRef<int> getShape() const;
 
   /// Return the size of the specified dimension, or -1 if unspecified.
   int getDimSize(unsigned i) const { return getShape()[i]; }
 
   /// Returns the elemental type for this memref shape.
-  Type *getElementType() const { return elementType; }
+  Type getElementType() const;
 
   /// Returns an array of affine map pointers representing the memref affine
   /// map composition.
   ArrayRef<AffineMap> getAffineMaps() const;
 
   /// Returns the memory space in which data referred to by this memref resides.
-  unsigned getMemorySpace() const { return memorySpace; }
+  unsigned getMemorySpace() const;
 
   /// Returns the number of dimensions with dynamic size.
   unsigned getNumDynamicDims() const;
 
-  static bool classof(const Type *type) {
-    return type->getKind() == Kind::MemRef;
-  }
-
-private:
-  /// The type of each scalar element of the memref.
-  Type *elementType;
-  /// An array of integers which stores the shape dimension sizes.
-  const int *shapeElements;
-  /// The number of affine maps in the 'affineMapList' array.
-  const unsigned numAffineMaps;
-  /// List of affine maps in the memref's layout/index map composition.
-  AffineMap const *affineMapList;
-  /// Memory space in which data referenced by memref resides.
-  const unsigned memorySpace;
-
-  MemRefType(ArrayRef<int> shape, Type *elementType,
-             ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
-             MLIRContext *context);
-  ~MemRefType() = delete;
+  static bool kindof(Kind kind) { return kind == Kind::MemRef; }
 };
 
-/// Return true if the specified element type is ok in a tensor.
-static bool isValidTensorElementType(Type *type) {
-  return isa<FloatType>(type) || isa<VectorType>(type) ||
-         isa<IntegerType>(type) || isa<OtherType>(type);
+// Make Type hashable.
+inline ::llvm::hash_code hash_value(Type arg) {
+  return ::llvm::hash_value(arg.type);
 }
+
+template <typename U> bool Type::isa() const {
+  assert(type && "isa<> used on a null type.");
+  return U::kindof(getKind());
+}
+template <typename U> U Type::dyn_cast() const {
+  return isa<U>() ? U(type) : U(nullptr);
+}
+template <typename U> U Type::dyn_cast_or_null() const {
+  return (type && isa<U>()) ? U(type) : U(nullptr);
+}
+template <typename U> U Type::cast() const {
+  assert(isa<U>());
+  return U(type);
+}
+
+/// Return true if the specified element type is ok in a tensor.
+static bool isValidTensorElementType(Type type) {
+  return type.isa<FloatType>() || type.isa<VectorType>() ||
+         type.isa<IntegerType>() || type.isa<OtherType>();
+}
+
 } // end namespace mlir
 
+namespace llvm {
+
+// Type hash just like pointers.
+template <> struct DenseMapInfo<mlir::Type> {
+  static mlir::Type getEmptyKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
+  }
+  static mlir::Type getTombstoneKey() {
+    auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
+  static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
+};
+
+/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
+template <> struct PointerLikeTypeTraits<mlir::Type> {
+public:
+  static inline void *getAsVoidPointer(mlir::Type I) {
+    return const_cast<void *>(I.getAsOpaquePointer());
+  }
+  static inline mlir::Type getFromVoidPointer(void *P) {
+    return mlir::Type::getFromOpaquePointer(P);
+  }
+  enum { NumLowBitsAvailable = 3 };
+};
+
+} // namespace llvm
+
 #endif  // MLIR_IR_TYPES_H
diff --git a/include/mlir/StandardOps/StandardOps.h b/include/mlir/StandardOps/StandardOps.h
index b733bad..c0fe4cf 100644
--- a/include/mlir/StandardOps/StandardOps.h
+++ b/include/mlir/StandardOps/StandardOps.h
@@ -104,15 +104,15 @@
     : public Op<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
 public:
   /// The result of an alloc is always a MemRefType.
-  MemRefType *getType() const {
-    return cast<MemRefType>(getResult()->getType());
+  MemRefType getType() const {
+    return getResult()->getType().cast<MemRefType>();
   }
 
   static StringRef getOperationName() { return "alloc"; }
 
   // Hooks to customize behavior of this op.
   static void build(Builder *builder, OperationState *result,
-                    MemRefType *memrefType, ArrayRef<SSAValue *> operands = {});
+                    MemRefType memrefType, ArrayRef<SSAValue *> operands = {});
   bool verify() const;
   static bool parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p) const;
@@ -276,7 +276,7 @@
   const SSAValue *getSrcMemRef() const { return getOperand(0); }
   // Returns the rank (number of indices) of the source MemRefType.
   unsigned getSrcMemRefRank() const {
-    return cast<MemRefType>(getSrcMemRef()->getType())->getRank();
+    return getSrcMemRef()->getType().cast<MemRefType>().getRank();
   }
   // Returns the source memerf indices for this DMA operation.
   llvm::iterator_range<Operation::const_operand_iterator>
@@ -291,13 +291,13 @@
   }
   // Returns the rank (number of indices) of the destination MemRefType.
   unsigned getDstMemRefRank() const {
-    return cast<MemRefType>(getDstMemRef()->getType())->getRank();
+    return getDstMemRef()->getType().cast<MemRefType>().getRank();
   }
   unsigned getSrcMemorySpace() const {
-    return cast<MemRefType>(getSrcMemRef()->getType())->getMemorySpace();
+    return getSrcMemRef()->getType().cast<MemRefType>().getMemorySpace();
   }
   unsigned getDstMemorySpace() const {
-    return cast<MemRefType>(getDstMemRef()->getType())->getMemorySpace();
+    return getDstMemRef()->getType().cast<MemRefType>().getMemorySpace();
   }
 
   // Returns the destination memref indices for this DMA operation.
@@ -387,7 +387,7 @@
 
   // Returns the rank (number of indices) of the tag memref.
   unsigned getTagMemRefRank() const {
-    return cast<MemRefType>(getTagMemRef()->getType())->getRank();
+    return getTagMemRef()->getType().cast<MemRefType>().getRank();
   }
 
   // Returns the number of elements transferred in the associated DMA operation.
@@ -460,8 +460,8 @@
   SSAValue *getMemRef() { return getOperand(0); }
   const SSAValue *getMemRef() const { return getOperand(0); }
   void setMemRef(SSAValue *value) { setOperand(0, value); }
-  MemRefType *getMemRefType() const {
-    return cast<MemRefType>(getMemRef()->getType());
+  MemRefType getMemRefType() const {
+    return getMemRef()->getType().cast<MemRefType>();
   }
 
   llvm::iterator_range<Operation::operand_iterator> getIndices() {
@@ -508,8 +508,8 @@
   static StringRef getOperationName() { return "memref_cast"; }
 
   /// The result of a memref_cast is always a memref.
-  MemRefType *getType() const {
-    return cast<MemRefType>(getResult()->getType());
+  MemRefType getType() const {
+    return getResult()->getType().cast<MemRefType>();
   }
 
   bool verify() const;
@@ -583,8 +583,8 @@
   SSAValue *getMemRef() { return getOperand(1); }
   const SSAValue *getMemRef() const { return getOperand(1); }
   void setMemRef(SSAValue *value) { setOperand(1, value); }
-  MemRefType *getMemRefType() const {
-    return cast<MemRefType>(getMemRef()->getType());
+  MemRefType getMemRefType() const {
+    return getMemRef()->getType().cast<MemRefType>();
   }
 
   llvm::iterator_range<Operation::operand_iterator> getIndices() {
@@ -671,8 +671,8 @@
   static StringRef getOperationName() { return "tensor_cast"; }
 
   /// The result of a tensor_cast is always a tensor.
-  TensorType *getType() const {
-    return cast<TensorType>(getResult()->getType());
+  TensorType getType() const {
+    return getResult()->getType().cast<TensorType>();
   }
 
   bool verify() const;
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 1b3c24f..1904a63 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -118,15 +118,15 @@
   return tripCountExpr.getLargestKnownDivisor();
 }
 
-bool mlir::isAccessInvariant(const MLValue &input, MemRefType *memRefType,
+bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
                              ArrayRef<MLValue *> indices, unsigned dim) {
-  assert(indices.size() == memRefType->getRank());
+  assert(indices.size() == memRefType.getRank());
   assert(dim < indices.size());
-  auto layoutMap = memRefType->getAffineMaps();
-  assert(memRefType->getAffineMaps().size() <= 1);
+  auto layoutMap = memRefType.getAffineMaps();
+  assert(memRefType.getAffineMaps().size() <= 1);
   // TODO(ntv): remove dependency on Builder once we support non-identity
   // layout map.
-  Builder b(memRefType->getContext());
+  Builder b(memRefType.getContext());
   assert(layoutMap.empty() ||
          layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
   (void)layoutMap;
@@ -170,7 +170,7 @@
   using namespace functional;
   auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
                      memoryOp->getIndices());
-  auto *memRefType = memoryOp->getMemRefType();
+  auto memRefType = memoryOp->getMemRefType();
   for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
     if (fastestVaryingDim == (numIndices - 1) - d) {
       continue;
@@ -184,8 +184,8 @@
 
 template <typename LoadOrStoreOpPointer>
 static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
-  auto *memRefType = memoryOp->getMemRefType();
-  return isa<VectorType>(memRefType->getElementType());
+  auto memRefType = memoryOp->getMemRefType();
+  return memRefType.getElementType().template isa<VectorType>();
 }
 
 bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
diff --git a/lib/Analysis/Verifier.cpp b/lib/Analysis/Verifier.cpp
index bfbcb16..0dd030d 100644
--- a/lib/Analysis/Verifier.cpp
+++ b/lib/Analysis/Verifier.cpp
@@ -195,7 +195,7 @@
 
   // Verify that the argument list of the function and the arg list of the first
   // block line up.
-  auto fnInputTypes = fn.getType()->getInputs();
+  auto fnInputTypes = fn.getType().getInputs();
   if (fnInputTypes.size() != firstBB->getNumArguments())
     return failure("first block of cfgfunc must have " +
                        Twine(fnInputTypes.size()) +
@@ -306,7 +306,7 @@
 
 bool CFGFuncVerifier::verifyReturn(const ReturnInst &inst) {
   // Verify that the return operands match the results of the function.
-  auto results = fn.getType()->getResults();
+  auto results = fn.getType().getResults();
   if (inst.getNumOperands() != results.size())
     return failure("return has " + Twine(inst.getNumOperands()) +
                        " operands, but enclosing function returns " +
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 454a28a..cb5e96f 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -122,7 +122,7 @@
   void visitForStmt(const ForStmt *forStmt);
   void visitIfStmt(const IfStmt *ifStmt);
   void visitOperationStmt(const OperationStmt *opStmt);
-  void visitType(const Type *type);
+  void visitType(Type type);
   void visitAttribute(Attribute attr);
   void visitOperation(const Operation *op);
 
@@ -135,16 +135,16 @@
 } // end anonymous namespace
 
 // TODO Support visiting other types/instructions when implemented.
-void ModuleState::visitType(const Type *type) {
-  if (auto *funcType = dyn_cast<FunctionType>(type)) {
+void ModuleState::visitType(Type type) {
+  if (auto funcType = type.dyn_cast<FunctionType>()) {
     // Visit input and result types for functions.
-    for (auto *input : funcType->getInputs())
+    for (auto input : funcType.getInputs())
       visitType(input);
-    for (auto *result : funcType->getResults())
+    for (auto result : funcType.getResults())
       visitType(result);
-  } else if (auto *memref = dyn_cast<MemRefType>(type)) {
+  } else if (auto memref = type.dyn_cast<MemRefType>()) {
     // Visit affine maps in memref type.
-    for (auto map : memref->getAffineMaps()) {
+    for (auto map : memref.getAffineMaps()) {
       recordAffineMapReference(map);
     }
   }
@@ -271,7 +271,7 @@
   void print(const Module *module);
   void printFunctionReference(const Function *func);
   void printAttribute(Attribute attr);
-  void printType(const Type *type);
+  void printType(Type type);
   void print(const Function *fn);
   void print(const ExtFunction *fn);
   void print(const CFGFunction *fn);
@@ -290,7 +290,7 @@
   void printFunctionAttributes(const Function *fn);
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                              ArrayRef<const char *> elidedAttrs = {});
-  void printFunctionResultType(const FunctionType *type);
+  void printFunctionResultType(FunctionType type);
   void printAffineMapId(int affineMapId) const;
   void printAffineMapReference(AffineMap affineMap);
   void printIntegerSetId(int integerSetId) const;
@@ -489,9 +489,9 @@
 }
 
 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
-  auto *type = attr.getType();
-  auto shape = type->getShape();
-  auto rank = type->getRank();
+  auto type = attr.getType();
+  auto shape = type.getShape();
+  auto rank = type.getRank();
 
   SmallVector<Attribute, 16> elements;
   attr.getValues(elements);
@@ -541,8 +541,8 @@
     os << ']';
 }
 
-void ModulePrinter::printType(const Type *type) {
-  switch (type->getKind()) {
+void ModulePrinter::printType(Type type) {
+  switch (type.getKind()) {
   case Type::Kind::Index:
     os << "index";
     return;
@@ -581,71 +581,71 @@
     return;
 
   case Type::Kind::Integer: {
-    auto *integer = cast<IntegerType>(type);
-    os << 'i' << integer->getWidth();
+    auto integer = type.cast<IntegerType>();
+    os << 'i' << integer.getWidth();
     return;
   }
   case Type::Kind::Function: {
-    auto *func = cast<FunctionType>(type);
+    auto func = type.cast<FunctionType>();
     os << '(';
-    interleaveComma(func->getInputs(), [&](Type *type) { printType(type); });
+    interleaveComma(func.getInputs(), [&](Type type) { printType(type); });
     os << ") -> ";
-    auto results = func->getResults();
+    auto results = func.getResults();
     if (results.size() == 1)
-      os << *results[0];
+      os << results[0];
     else {
       os << '(';
-      interleaveComma(results, [&](Type *type) { printType(type); });
+      interleaveComma(results, [&](Type type) { printType(type); });
       os << ')';
     }
     return;
   }
   case Type::Kind::Vector: {
-    auto *v = cast<VectorType>(type);
+    auto v = type.cast<VectorType>();
     os << "vector<";
-    for (auto dim : v->getShape())
+    for (auto dim : v.getShape())
       os << dim << 'x';
-    os << *v->getElementType() << '>';
+    os << v.getElementType() << '>';
     return;
   }
   case Type::Kind::RankedTensor: {
-    auto *v = cast<RankedTensorType>(type);
+    auto v = type.cast<RankedTensorType>();
     os << "tensor<";
-    for (auto dim : v->getShape()) {
+    for (auto dim : v.getShape()) {
       if (dim < 0)
         os << '?';
       else
         os << dim;
       os << 'x';
     }
-    os << *v->getElementType() << '>';
+    os << v.getElementType() << '>';
     return;
   }
   case Type::Kind::UnrankedTensor: {
-    auto *v = cast<UnrankedTensorType>(type);
+    auto v = type.cast<UnrankedTensorType>();
     os << "tensor<*x";
-    printType(v->getElementType());
+    printType(v.getElementType());
     os << '>';
     return;
   }
   case Type::Kind::MemRef: {
-    auto *v = cast<MemRefType>(type);
+    auto v = type.cast<MemRefType>();
     os << "memref<";
-    for (auto dim : v->getShape()) {
+    for (auto dim : v.getShape()) {
       if (dim < 0)
         os << '?';
       else
         os << dim;
       os << 'x';
     }
-    printType(v->getElementType());
-    for (auto map : v->getAffineMaps()) {
+    printType(v.getElementType());
+    for (auto map : v.getAffineMaps()) {
       os << ", ";
       printAffineMapReference(map);
     }
     // Only print the memory space if it is the non-default one.
-    if (v->getMemorySpace())
-      os << ", " << v->getMemorySpace();
+    if (v.getMemorySpace())
+      os << ", " << v.getMemorySpace();
     os << '>';
     return;
   }
@@ -842,18 +842,18 @@
 // Function printing
 //===----------------------------------------------------------------------===//
 
-void ModulePrinter::printFunctionResultType(const FunctionType *type) {
-  switch (type->getResults().size()) {
+void ModulePrinter::printFunctionResultType(FunctionType type) {
+  switch (type.getResults().size()) {
   case 0:
     break;
   case 1:
     os << " -> ";
-    printType(type->getResults()[0]);
+    printType(type.getResults()[0]);
     break;
   default:
     os << " -> (";
-    interleaveComma(type->getResults(),
-                    [&](Type *eltType) { printType(eltType); });
+    interleaveComma(type.getResults(),
+                    [&](Type eltType) { printType(eltType); });
     os << ')';
     break;
   }
@@ -871,8 +871,7 @@
   auto type = fn->getType();
 
   os << "@" << fn->getName() << '(';
-  interleaveComma(type->getInputs(),
-                  [&](Type *eltType) { printType(eltType); });
+  interleaveComma(type.getInputs(), [&](Type eltType) { printType(eltType); });
   os << ')';
 
   printFunctionResultType(type);
@@ -937,7 +936,7 @@
 
   // Implement OpAsmPrinter.
   raw_ostream &getStream() const { return os; }
-  void printType(const Type *type) { ModulePrinter::printType(type); }
+  void printType(Type type) { ModulePrinter::printType(type); }
   void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
   void printAffineMap(AffineMap map) {
     return ModulePrinter::printAffineMapReference(map);
@@ -974,10 +973,10 @@
     if (auto *op = value->getDefiningOperation()) {
       if (auto intOp = op->dyn_cast<ConstantIntOp>()) {
         // i1 constants get special names.
-        if (intOp->getType()->isInteger(1)) {
+        if (intOp->getType().isInteger(1)) {
           specialName << (intOp->getValue() ? "true" : "false");
         } else {
-          specialName << 'c' << intOp->getValue() << '_' << *intOp->getType();
+          specialName << 'c' << intOp->getValue() << '_' << intOp->getType();
         }
       } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
         specialName << 'c' << intOp->getValue();
@@ -1579,7 +1578,7 @@
 
 void Type::print(raw_ostream &os) const {
   ModuleState state(getContext());
-  ModulePrinter(os, state).printType(this);
+  ModulePrinter(os, state).printType(*this);
 }
 
 void Type::dump() const { print(llvm::errs()); }
diff --git a/lib/IR/AttributeDetail.h b/lib/IR/AttributeDetail.h
index a0e9afb..63ad544 100644
--- a/lib/IR/AttributeDetail.h
+++ b/lib/IR/AttributeDetail.h
@@ -26,6 +26,7 @@
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
 #include "llvm/Support/TrailingObjects.h"
 
 namespace mlir {
@@ -86,7 +87,7 @@
 
 /// An attribute representing a reference to a type.
 struct TypeAttributeStorage : public AttributeStorage {
-  Type *value;
+  Type value;
 };
 
 /// An attribute representing a reference to a function.
@@ -96,7 +97,7 @@
 
 /// A base attribute representing a reference to a vector or tensor constant.
 struct ElementsAttributeStorage : public AttributeStorage {
-  VectorOrTensorType *type;
+  VectorOrTensorType type;
 };
 
 /// An attribute representing a reference to a vector or tensor constant,
diff --git a/lib/IR/Attributes.cpp b/lib/IR/Attributes.cpp
index 34312b8..58b5b90 100644
--- a/lib/IR/Attributes.cpp
+++ b/lib/IR/Attributes.cpp
@@ -75,9 +75,7 @@
 
 TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
 
-Type *TypeAttr::getValue() const {
-  return static_cast<ImplType *>(attr)->value;
-}
+Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
 
 FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
 
@@ -85,11 +83,11 @@
   return static_cast<ImplType *>(attr)->value;
 }
 
-FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
+FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
 
 ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
 
-VectorOrTensorType *ElementsAttr::getType() const {
+VectorOrTensorType ElementsAttr::getType() const {
   return static_cast<ImplType *>(attr)->type;
 }
 
@@ -166,8 +164,8 @@
 
 void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
   auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
-  auto elementNum = getType()->getNumElements();
-  auto context = getType()->getContext();
+  auto elementNum = getType().getNumElements();
+  auto context = getType().getContext();
   values.reserve(elementNum);
   if (bitsWidth == 64) {
     ArrayRef<int64_t> vs(
@@ -192,8 +190,8 @@
     : DenseElementsAttr(ptr) {}
 
 void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
-  auto elementNum = getType()->getNumElements();
-  auto context = getType()->getContext();
+  auto elementNum = getType().getNumElements();
+  auto context = getType().getContext();
   ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
                        getRawData().size() / 8});
   values.reserve(elementNum);
diff --git a/lib/IR/BasicBlock.cpp b/lib/IR/BasicBlock.cpp
index bb8ac75..29a5ce1 100644
--- a/lib/IR/BasicBlock.cpp
+++ b/lib/IR/BasicBlock.cpp
@@ -33,18 +33,18 @@
 // Argument list management.
 //===----------------------------------------------------------------------===//
 
-BBArgument *BasicBlock::addArgument(Type *type) {
+BBArgument *BasicBlock::addArgument(Type type) {
   auto *arg = new BBArgument(type, this);
   arguments.push_back(arg);
   return arg;
 }
 
 /// Add one argument to the argument list for each type specified in the list.
-auto BasicBlock::addArguments(ArrayRef<Type *> types)
+auto BasicBlock::addArguments(ArrayRef<Type> types)
     -> llvm::iterator_range<args_iterator> {
   arguments.reserve(arguments.size() + types.size());
   auto initialSize = arguments.size();
-  for (auto *type : types) {
+  for (auto type : types) {
     addArgument(type);
   }
   return {arguments.data() + initialSize, arguments.data() + arguments.size()};
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 22d749a..906b580 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -52,59 +52,58 @@
 // Types.
 //===----------------------------------------------------------------------===//
 
-FloatType *Builder::getBF16Type() { return Type::getBF16(context); }
+FloatType Builder::getBF16Type() { return Type::getBF16(context); }
 
-FloatType *Builder::getF16Type() { return Type::getF16(context); }
+FloatType Builder::getF16Type() { return Type::getF16(context); }
 
-FloatType *Builder::getF32Type() { return Type::getF32(context); }
+FloatType Builder::getF32Type() { return Type::getF32(context); }
 
-FloatType *Builder::getF64Type() { return Type::getF64(context); }
+FloatType Builder::getF64Type() { return Type::getF64(context); }
 
-OtherType *Builder::getIndexType() { return Type::getIndex(context); }
+OtherType Builder::getIndexType() { return Type::getIndex(context); }
 
-OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
+OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
 
-OtherType *Builder::getTFResourceType() { return Type::getTFResource(context); }
+OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
 
-OtherType *Builder::getTFVariantType() { return Type::getTFVariant(context); }
+OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
 
-OtherType *Builder::getTFComplex64Type() {
+OtherType Builder::getTFComplex64Type() {
   return Type::getTFComplex64(context);
 }
 
-OtherType *Builder::getTFComplex128Type() {
+OtherType Builder::getTFComplex128Type() {
   return Type::getTFComplex128(context);
 }
 
-OtherType *Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
+OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
 
-OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
+OtherType Builder::getTFStringType() { return Type::getTFString(context); }
 
-IntegerType *Builder::getIntegerType(unsigned width) {
+IntegerType Builder::getIntegerType(unsigned width) {
   return Type::getInteger(width, context);
 }
 
-FunctionType *Builder::getFunctionType(ArrayRef<Type *> inputs,
-                                       ArrayRef<Type *> results) {
+FunctionType Builder::getFunctionType(ArrayRef<Type> inputs,
+                                      ArrayRef<Type> results) {
   return FunctionType::get(inputs, results, context);
 }
 
-MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
-                                   ArrayRef<AffineMap> affineMapComposition,
-                                   unsigned memorySpace) {
+MemRefType Builder::getMemRefType(ArrayRef<int> shape, Type elementType,
+                                  ArrayRef<AffineMap> affineMapComposition,
+                                  unsigned memorySpace) {
   return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
 }
 
-VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
+VectorType Builder::getVectorType(ArrayRef<int> shape, Type elementType) {
   return VectorType::get(shape, elementType);
 }
 
-RankedTensorType *Builder::getTensorType(ArrayRef<int> shape,
-                                         Type *elementType) {
+RankedTensorType Builder::getTensorType(ArrayRef<int> shape, Type elementType) {
   return RankedTensorType::get(shape, elementType);
 }
 
-UnrankedTensorType *Builder::getTensorType(Type *elementType) {
+UnrankedTensorType Builder::getTensorType(Type elementType) {
   return UnrankedTensorType::get(elementType);
 }
 
@@ -144,7 +143,7 @@
   return IntegerSetAttr::get(set);
 }
 
-TypeAttr Builder::getTypeAttr(Type *type) {
+TypeAttr Builder::getTypeAttr(Type type) {
   return TypeAttr::get(type, context);
 }
 
@@ -152,23 +151,23 @@
   return FunctionAttr::get(value, context);
 }
 
-ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,
                                            Attribute elt) {
   return SplatElementsAttr::get(type, elt);
 }
 
-ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType type,
                                            ArrayRef<char> data) {
   return DenseElementsAttr::get(type, data);
 }
 
-ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType type,
                                             DenseIntElementsAttr indices,
                                             DenseElementsAttr values) {
   return SparseElementsAttr::get(type, indices, values);
 }
 
-ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
+ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType type,
                                             StringRef bytes) {
   return OpaqueElementsAttr::get(type, bytes);
 }
@@ -296,7 +295,7 @@
 OperationStmt *MLFuncBuilder::createOperation(Location *location,
                                               OperationName name,
                                               ArrayRef<MLValue *> operands,
-                                              ArrayRef<Type *> types,
+                                              ArrayRef<Type> types,
                                               ArrayRef<NamedAttribute> attrs) {
   auto *op = OperationStmt::create(location, name, operands, types, attrs,
                                    getContext());
diff --git a/lib/IR/BuiltinOps.cpp b/lib/IR/BuiltinOps.cpp
index 542e67e..e4bca03 100644
--- a/lib/IR/BuiltinOps.cpp
+++ b/lib/IR/BuiltinOps.cpp
@@ -63,7 +63,7 @@
   numDims = opInfos.size();
 
   // Parse the optional symbol operands.
-  auto *affineIntTy = parser->getBuilder().getIndexType();
+  auto affineIntTy = parser->getBuilder().getIndexType();
   if (parser->parseOperandList(opInfos, -1,
                                OpAsmParser::Delimiter::OptionalSquare) ||
       parser->resolveOperands(opInfos, affineIntTy, operands))
@@ -84,7 +84,7 @@
 
 bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
   auto &builder = parser->getBuilder();
-  auto *affineIntTy = builder.getIndexType();
+  auto affineIntTy = builder.getIndexType();
 
   AffineMapAttr mapAttr;
   unsigned numDims;
@@ -171,7 +171,7 @@
 
 /// Builds a constant op with the specified attribute value and result type.
 void ConstantOp::build(Builder *builder, OperationState *result,
-                       Attribute value, Type *type) {
+                       Attribute value, Type type) {
   result->addAttribute("value", value);
   result->types.push_back(type);
 }
@@ -181,12 +181,12 @@
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
 
   if (!getValue().isa<FunctionAttr>())
-    *p << " : " << *getType();
+    *p << " : " << getType();
 }
 
 bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
   Attribute valueAttr;
-  Type *type;
+  Type type;
 
   if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
       parser->parseOptionalAttributeDict(result->attributes))
@@ -208,33 +208,33 @@
   if (!value)
     return emitOpError("requires a 'value' attribute");
 
-  auto *type = this->getType();
-  if (isa<IntegerType>(type) || type->isIndex()) {
+  auto type = this->getType();
+  if (type.isa<IntegerType>() || type.isIndex()) {
     if (!value.isa<IntegerAttr>())
       return emitOpError(
           "requires 'value' to be an integer for an integer result type");
     return false;
   }
 
-  if (isa<FloatType>(type)) {
+  if (type.isa<FloatType>()) {
     if (!value.isa<FloatAttr>())
       return emitOpError("requires 'value' to be a floating point constant");
     return false;
   }
 
-  if (isa<VectorOrTensorType>(type)) {
+  if (type.isa<VectorOrTensorType>()) {
     if (!value.isa<ElementsAttr>())
       return emitOpError("requires 'value' to be a vector/tensor constant");
     return false;
   }
 
-  if (type->isTFString()) {
+  if (type.isTFString()) {
     if (!value.isa<StringAttr>())
       return emitOpError("requires 'value' to be a string constant");
     return false;
   }
 
-  if (isa<FunctionType>(type)) {
+  if (type.isa<FunctionType>()) {
     if (!value.isa<FunctionAttr>())
       return emitOpError("requires 'value' to be a function reference");
     return false;
@@ -251,19 +251,19 @@
 }
 
 void ConstantFloatOp::build(Builder *builder, OperationState *result,
-                            const APFloat &value, FloatType *type) {
+                            const APFloat &value, FloatType type) {
   ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
 }
 
 bool ConstantFloatOp::isClassFor(const Operation *op) {
   return ConstantOp::isClassFor(op) &&
-         isa<FloatType>(op->getResult(0)->getType());
+         op->getResult(0)->getType().isa<FloatType>();
 }
 
 /// ConstantIntOp only matches values whose result type is an IntegerType.
 bool ConstantIntOp::isClassFor(const Operation *op) {
   return ConstantOp::isClassFor(op) &&
-         isa<IntegerType>(op->getResult(0)->getType());
+         op->getResult(0)->getType().isa<IntegerType>();
 }
 
 void ConstantIntOp::build(Builder *builder, OperationState *result,
@@ -275,14 +275,14 @@
 /// Build a constant int op producing an integer with the specified type,
 /// which must be an integer type.
 void ConstantIntOp::build(Builder *builder, OperationState *result,
-                          int64_t value, Type *type) {
-  assert(isa<IntegerType>(type) && "ConstantIntOp can only have integer type");
+                          int64_t value, Type type) {
+  assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
   ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
 }
 
 /// ConstantIndexOp only matches values whose result type is Index.
 bool ConstantIndexOp::isClassFor(const Operation *op) {
-  return ConstantOp::isClassFor(op) && op->getResult(0)->getType()->isIndex();
+  return ConstantOp::isClassFor(op) && op->getResult(0)->getType().isIndex();
 }
 
 void ConstantIndexOp::build(Builder *builder, OperationState *result,
@@ -302,7 +302,7 @@
 
 bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 2> opInfo;
-  SmallVector<Type *, 2> types;
+  SmallVector<Type, 2> types;
   llvm::SMLoc loc;
   return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
          (!opInfo.empty() && parser->parseColonTypeList(types)) ||
@@ -330,7 +330,7 @@
 
     // The operand number and types must match the function signature.
     MLFunction *function = cast<MLFunction>(block);
-    const auto &results = function->getType()->getResults();
+    const auto &results = function->getType().getResults();
     if (stmt->getNumOperands() != results.size())
       return emitOpError("has " + Twine(stmt->getNumOperands()) +
                          " operands, but enclosing function returns " +
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index efeb16b..70c0e12 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -28,8 +28,8 @@
 using namespace mlir;
 
 Function::Function(Kind kind, Location *location, StringRef name,
-                   FunctionType *type, ArrayRef<NamedAttribute> attrs)
-    : nameAndKind(Identifier::get(name, type->getContext()), kind),
+                   FunctionType type, ArrayRef<NamedAttribute> attrs)
+    : nameAndKind(Identifier::get(name, type.getContext()), kind),
       location(location), type(type) {
   this->attrs = AttributeListStorage::get(attrs, getContext());
 }
@@ -46,7 +46,7 @@
     return {};
 }
 
-MLIRContext *Function::getContext() const { return getType()->getContext(); }
+MLIRContext *Function::getContext() const { return getType().getContext(); }
 
 /// Delete this object.
 void Function::destroy() {
@@ -159,7 +159,7 @@
 // ExtFunction implementation.
 //===----------------------------------------------------------------------===//
 
-ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type,
+ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType type,
                          ArrayRef<NamedAttribute> attrs)
     : Function(Kind::ExtFunc, location, name, type, attrs) {}
 
@@ -167,7 +167,7 @@
 // CFGFunction implementation.
 //===----------------------------------------------------------------------===//
 
-CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type,
+CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType type,
                          ArrayRef<NamedAttribute> attrs)
     : Function(Kind::CFGFunc, location, name, type, attrs) {}
 
@@ -188,9 +188,9 @@
 
 /// Create a new MLFunction with the specific fields.
 MLFunction *MLFunction::create(Location *location, StringRef name,
-                               FunctionType *type,
+                               FunctionType type,
                                ArrayRef<NamedAttribute> attrs) {
-  const auto &argTypes = type->getInputs();
+  const auto &argTypes = type.getInputs();
   auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
   void *rawMem = malloc(byteSize);
 
@@ -204,7 +204,7 @@
   return function;
 }
 
-MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type,
+MLFunction::MLFunction(Location *location, StringRef name, FunctionType type,
                        ArrayRef<NamedAttribute> attrs)
     : Function(Kind::MLFunc, location, name, type, attrs),
       StmtBlock(StmtBlockKind::MLFunc) {}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 422636b..d2f49dd 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -143,7 +143,7 @@
 /// Create a new OperationInst with the specified fields.
 OperationInst *OperationInst::create(Location *location, OperationName name,
                                      ArrayRef<CFGValue *> operands,
-                                     ArrayRef<Type *> resultTypes,
+                                     ArrayRef<Type> resultTypes,
                                      ArrayRef<NamedAttribute> attributes,
                                      MLIRContext *context) {
   auto byteSize = totalSizeToAlloc<InstOperand, InstResult>(operands.size(),
@@ -167,7 +167,7 @@
 
 OperationInst *OperationInst::clone() const {
   SmallVector<CFGValue *, 8> operands;
-  SmallVector<Type *, 8> resultTypes;
+  SmallVector<Type, 8> resultTypes;
 
   // Put together the operands and results.
   for (auto *operand : getOperands())
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 0a2e941..8811f7b 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -21,6 +21,7 @@
 #include "AttributeDetail.h"
 #include "AttributeListStorage.h"
 #include "IntegerSetDetail.h"
+#include "TypeDetail.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
@@ -44,11 +45,11 @@
 using namespace llvm;
 
 namespace {
-struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType *> {
+struct FunctionTypeKeyInfo : DenseMapInfo<FunctionTypeStorage *> {
   // Functions are uniqued based on their inputs and results.
-  using KeyTy = std::pair<ArrayRef<Type *>, ArrayRef<Type *>>;
-  using DenseMapInfo<FunctionType *>::getHashValue;
-  using DenseMapInfo<FunctionType *>::isEqual;
+  using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<Type>>;
+  using DenseMapInfo<FunctionTypeStorage *>::getHashValue;
+  using DenseMapInfo<FunctionTypeStorage *>::isEqual;
 
   static unsigned getHashValue(KeyTy key) {
     return hash_combine(
@@ -56,7 +57,7 @@
         hash_combine_range(key.second.begin(), key.second.end()));
   }
 
-  static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
+  static bool isEqual(const KeyTy &lhs, const FunctionTypeStorage *rhs) {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
     return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
@@ -109,65 +110,64 @@
   }
 };
 
-struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
+struct VectorTypeKeyInfo : DenseMapInfo<VectorTypeStorage *> {
   // Vectors are uniqued based on their element type and shape.
-  using KeyTy = std::pair<Type *, ArrayRef<int>>;
-  using DenseMapInfo<VectorType *>::getHashValue;
-  using DenseMapInfo<VectorType *>::isEqual;
+  using KeyTy = std::pair<Type, ArrayRef<int>>;
+  using DenseMapInfo<VectorTypeStorage *>::getHashValue;
+  using DenseMapInfo<VectorTypeStorage *>::isEqual;
 
   static unsigned getHashValue(KeyTy key) {
     return hash_combine(
-        DenseMapInfo<Type *>::getHashValue(key.first),
+        DenseMapInfo<Type>::getHashValue(key.first),
         hash_combine_range(key.second.begin(), key.second.end()));
   }
 
-  static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
+  static bool isEqual(const KeyTy &lhs, const VectorTypeStorage *rhs) {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
-    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+    return lhs == KeyTy(rhs->elementType, rhs->getShape());
   }
 };
 
-struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType *> {
+struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorTypeStorage *> {
   // Ranked tensors are uniqued based on their element type and shape.
-  using KeyTy = std::pair<Type *, ArrayRef<int>>;
-  using DenseMapInfo<RankedTensorType *>::getHashValue;
-  using DenseMapInfo<RankedTensorType *>::isEqual;
+  using KeyTy = std::pair<Type, ArrayRef<int>>;
+  using DenseMapInfo<RankedTensorTypeStorage *>::getHashValue;
+  using DenseMapInfo<RankedTensorTypeStorage *>::isEqual;
 
   static unsigned getHashValue(KeyTy key) {
     return hash_combine(
-        DenseMapInfo<Type *>::getHashValue(key.first),
+        DenseMapInfo<Type>::getHashValue(key.first),
         hash_combine_range(key.second.begin(), key.second.end()));
   }
 
-  static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
+  static bool isEqual(const KeyTy &lhs, const RankedTensorTypeStorage *rhs) {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
-    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+    return lhs == KeyTy(rhs->elementType, rhs->getShape());
   }
 };
 
-struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
+struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> {
   // MemRefs are uniqued based on their element type, shape, affine map
   // composition, and memory space.
-  using KeyTy =
-      std::tuple<Type *, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
-  using DenseMapInfo<MemRefType *>::getHashValue;
-  using DenseMapInfo<MemRefType *>::isEqual;
+  using KeyTy = std::tuple<Type, ArrayRef<int>, ArrayRef<AffineMap>, unsigned>;
+  using DenseMapInfo<MemRefTypeStorage *>::getHashValue;
+  using DenseMapInfo<MemRefTypeStorage *>::isEqual;
 
   static unsigned getHashValue(KeyTy key) {
     return hash_combine(
-        DenseMapInfo<Type *>::getHashValue(std::get<0>(key)),
+        DenseMapInfo<Type>::getHashValue(std::get<0>(key)),
         hash_combine_range(std::get<1>(key).begin(), std::get<1>(key).end()),
         hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
         std::get<3>(key));
   }
 
-  static bool isEqual(const KeyTy &lhs, const MemRefType *rhs) {
+  static bool isEqual(const KeyTy &lhs, const MemRefTypeStorage *rhs) {
     if (rhs == getEmptyKey() || rhs == getTombstoneKey())
       return false;
-    return lhs == std::make_tuple(rhs->getElementType(), rhs->getShape(),
-                                  rhs->getAffineMaps(), rhs->getMemorySpace());
+    return lhs == std::make_tuple(rhs->elementType, rhs->getShape(),
+                                  rhs->getAffineMaps(), rhs->memorySpace);
   }
 };
 
@@ -221,7 +221,7 @@
 };
 
 struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
-  using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
+  using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
   using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
   using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
 
@@ -239,7 +239,7 @@
 };
 
 struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
-  using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
+  using KeyTy = std::pair<VectorOrTensorType, StringRef>;
   using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
   using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
 
@@ -295,13 +295,14 @@
   llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
 
   // Uniquing table for 'other' types.
-  OtherType *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
-                        int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {nullptr};
+  OtherTypeStorage *otherTypes[int(Type::Kind::LAST_OTHER_TYPE) -
+                               int(Type::Kind::FIRST_OTHER_TYPE) + 1] = {
+      nullptr};
 
   // Uniquing table for 'float' types.
-  FloatType *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
-                        int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] = {
-      nullptr};
+  FloatTypeStorage *floatTypes[int(Type::Kind::LAST_FLOATING_POINT_TYPE) -
+                               int(Type::Kind::FIRST_FLOATING_POINT_TYPE) + 1] =
+      {nullptr};
 
   // Affine map uniquing.
   using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
@@ -324,26 +325,26 @@
   DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
 
   /// Integer type uniquing.
-  DenseMap<unsigned, IntegerType *> integers;
+  DenseMap<unsigned, IntegerTypeStorage *> integers;
 
   /// Function type uniquing.
-  using FunctionTypeSet = DenseSet<FunctionType *, FunctionTypeKeyInfo>;
+  using FunctionTypeSet = DenseSet<FunctionTypeStorage *, FunctionTypeKeyInfo>;
   FunctionTypeSet functions;
 
   /// Vector type uniquing.
-  using VectorTypeSet = DenseSet<VectorType *, VectorTypeKeyInfo>;
+  using VectorTypeSet = DenseSet<VectorTypeStorage *, VectorTypeKeyInfo>;
   VectorTypeSet vectors;
 
   /// Ranked tensor type uniquing.
   using RankedTensorTypeSet =
-      DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>;
+      DenseSet<RankedTensorTypeStorage *, RankedTensorTypeKeyInfo>;
   RankedTensorTypeSet rankedTensors;
 
   /// Unranked tensor type uniquing.
-  DenseMap<Type *, UnrankedTensorType *> unrankedTensors;
+  DenseMap<Type, UnrankedTensorTypeStorage *> unrankedTensors;
 
   /// MemRef type uniquing.
-  using MemRefTypeSet = DenseSet<MemRefType *, MemRefTypeKeyInfo>;
+  using MemRefTypeSet = DenseSet<MemRefTypeStorage *, MemRefTypeKeyInfo>;
   MemRefTypeSet memrefs;
 
   // Attribute uniquing.
@@ -355,13 +356,12 @@
   ArrayAttrSet arrayAttrs;
   DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
   DenseMap<IntegerSet, IntegerSetAttributeStorage *> integerSetAttrs;
-  DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
+  DenseMap<Type, TypeAttributeStorage *> typeAttrs;
   using AttributeListSet =
       DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
   AttributeListSet attributeLists;
   DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
-  DenseMap<std::pair<VectorOrTensorType *, Attribute>,
-           SplatElementsAttributeStorage *>
+  DenseMap<std::pair<Type, Attribute>, SplatElementsAttributeStorage *>
       splatElementsAttrs;
   using DenseElementsAttrSet =
       DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
@@ -369,7 +369,7 @@
   using OpaqueElementsAttrSet =
       DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
   OpaqueElementsAttrSet opaqueElementsAttrs;
-  DenseMap<std::tuple<Type *, Attribute, Attribute>,
+  DenseMap<std::tuple<Type, Attribute, Attribute>,
            SparseElementsAttributeStorage *>
       sparseElementsAttrs;
 
@@ -556,19 +556,20 @@
 // Type uniquing
 //===----------------------------------------------------------------------===//
 
-IntegerType *IntegerType::get(unsigned width, MLIRContext *context) {
+IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
+  assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
   auto &impl = context->getImpl();
 
   auto *&result = impl.integers[width];
   if (!result) {
-    result = impl.allocator.Allocate<IntegerType>();
-    new (result) IntegerType(width, context);
+    result = impl.allocator.Allocate<IntegerTypeStorage>();
+    new (result) IntegerTypeStorage{{Kind::Integer, context}, width};
   }
 
   return result;
 }
 
-FloatType *FloatType::get(Kind kind, MLIRContext *context) {
+FloatType FloatType::get(Kind kind, MLIRContext *context) {
   assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
          kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
   auto &impl = context->getImpl();
@@ -580,16 +581,16 @@
     return entry;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *ptr = impl.allocator.Allocate<FloatType>();
+  auto *ptr = impl.allocator.Allocate<FloatTypeStorage>();
 
   // Initialize the memory using placement new.
-  new (ptr) FloatType(kind, context);
+  new (ptr) FloatTypeStorage{{kind, context}};
 
   // Cache and return it.
   return entry = ptr;
 }
 
-OtherType *OtherType::get(Kind kind, MLIRContext *context) {
+OtherType OtherType::get(Kind kind, MLIRContext *context) {
   assert(kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE &&
          "Not an 'other' type kind");
   auto &impl = context->getImpl();
@@ -600,18 +601,17 @@
     return entry;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *ptr = impl.allocator.Allocate<OtherType>();
+  auto *ptr = impl.allocator.Allocate<OtherTypeStorage>();
 
   // Initialize the memory using placement new.
-  new (ptr) OtherType(kind, context);
+  new (ptr) OtherTypeStorage{{kind, context}};
 
   // Cache and return it.
   return entry = ptr;
 }
 
-FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
-                                ArrayRef<Type *> results,
-                                MLIRContext *context) {
+FunctionType FunctionType::get(ArrayRef<Type> inputs, ArrayRef<Type> results,
+                               MLIRContext *context) {
   auto &impl = context->getImpl();
 
   // Look to see if we already have this function type.
@@ -623,32 +623,34 @@
     return *existing.first;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *result = impl.allocator.Allocate<FunctionType>();
+  auto *result = impl.allocator.Allocate<FunctionTypeStorage>();
 
   // Copy the inputs and results into the bump pointer.
-  SmallVector<Type *, 16> types;
+  SmallVector<Type, 16> types;
   types.reserve(inputs.size() + results.size());
   types.append(inputs.begin(), inputs.end());
   types.append(results.begin(), results.end());
-  auto typesList = impl.copyInto(ArrayRef<Type *>(types));
+  auto typesList = impl.copyInto(ArrayRef<Type>(types));
 
   // Initialize the memory using placement new.
-  new (result)
-      FunctionType(typesList.data(), inputs.size(), results.size(), context);
+  new (result) FunctionTypeStorage{
+      {Kind::Function, context, static_cast<unsigned int>(inputs.size())},
+      static_cast<unsigned int>(results.size()),
+      typesList.data()};
 
   // Cache and return it.
   return *existing.first = result;
 }
 
-VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
+VectorType VectorType::get(ArrayRef<int> shape, Type elementType) {
   assert(!shape.empty() && "vector types must have at least one dimension");
-  assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
+  assert((elementType.isa<FloatType>() || elementType.isa<IntegerType>()) &&
          "vectors elements must be primitives");
   assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
     return i < 0;
   }) && "vector types must have static shape");
 
-  auto *context = elementType->getContext();
+  auto *context = elementType.getContext();
   auto &impl = context->getImpl();
 
   // Look to see if we already have this vector type.
@@ -660,21 +662,23 @@
     return *existing.first;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *result = impl.allocator.Allocate<VectorType>();
+  auto *result = impl.allocator.Allocate<VectorTypeStorage>();
 
   // Copy the shape into the bump pointer.
   shape = impl.copyInto(shape);
 
   // Initialize the memory using placement new.
-  new (result) VectorType(shape, elementType, context);
+  new (result) VectorTypeStorage{
+      {{Kind::Vector, context, static_cast<unsigned int>(shape.size())},
+       elementType},
+      shape.data()};
 
   // Cache and return it.
   return *existing.first = result;
 }
 
-RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
-                                        Type *elementType) {
-  auto *context = elementType->getContext();
+RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
+  auto *context = elementType.getContext();
   auto &impl = context->getImpl();
 
   // Look to see if we already have this ranked tensor type.
@@ -686,20 +690,23 @@
     return *existing.first;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *result = impl.allocator.Allocate<RankedTensorType>();
+  auto *result = impl.allocator.Allocate<RankedTensorTypeStorage>();
 
   // Copy the shape into the bump pointer.
   shape = impl.copyInto(shape);
 
   // Initialize the memory using placement new.
-  new (result) RankedTensorType(shape, elementType, context);
+  new (result) RankedTensorTypeStorage{
+      {{{Kind::RankedTensor, context, static_cast<unsigned int>(shape.size())},
+        elementType}},
+      shape.data()};
 
   // Cache and return it.
   return *existing.first = result;
 }
 
-UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
-  auto *context = elementType->getContext();
+UnrankedTensorType UnrankedTensorType::get(Type elementType) {
+  auto *context = elementType.getContext();
   auto &impl = context->getImpl();
 
   // Look to see if we already have this unranked tensor type.
@@ -710,17 +717,18 @@
     return result;
 
   // On the first use, we allocate them into the bump pointer.
-  result = impl.allocator.Allocate<UnrankedTensorType>();
+  result = impl.allocator.Allocate<UnrankedTensorTypeStorage>();
 
   // Initialize the memory using placement new.
-  new (result) UnrankedTensorType(elementType, context);
+  new (result) UnrankedTensorTypeStorage{
+      {{{Kind::UnrankedTensor, context}, elementType}}};
   return result;
 }
 
-MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
-                            ArrayRef<AffineMap> affineMapComposition,
-                            unsigned memorySpace) {
-  auto *context = elementType->getContext();
+MemRefType MemRefType::get(ArrayRef<int> shape, Type elementType,
+                           ArrayRef<AffineMap> affineMapComposition,
+                           unsigned memorySpace) {
+  auto *context = elementType.getContext();
   auto &impl = context->getImpl();
 
   // Drop the unbounded identity maps from the composition.
@@ -744,7 +752,7 @@
     return *existing.first;
 
   // On the first use, we allocate them into the bump pointer.
-  auto *result = impl.allocator.Allocate<MemRefType>();
+  auto *result = impl.allocator.Allocate<MemRefTypeStorage>();
 
   // Copy the shape into the bump pointer.
   shape = impl.copyInto(shape);
@@ -755,8 +763,13 @@
       impl.copyInto(ArrayRef<AffineMap>(affineMapComposition));
 
   // Initialize the memory using placement new.
-  new (result) MemRefType(shape, elementType, affineMapComposition, memorySpace,
-                          context);
+  new (result) MemRefTypeStorage{
+      {Kind::MemRef, context, static_cast<unsigned int>(shape.size())},
+      elementType,
+      shape.data(),
+      static_cast<unsigned int>(affineMapComposition.size()),
+      affineMapComposition.data(),
+      memorySpace};
   // Cache and return it.
   return *existing.first = result;
 }
@@ -895,7 +908,7 @@
   return result;
 }
 
-TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
+TypeAttr TypeAttr::get(Type type, MLIRContext *context) {
   auto *&result = context->getImpl().typeAttrs[type];
   if (result)
     return result;
@@ -1009,9 +1022,9 @@
   return *existing.first = result;
 }
 
-SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
+SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
                                          Attribute elt) {
-  auto &impl = type->getContext()->getImpl();
+  auto &impl = type.getContext()->getImpl();
 
   // Look to see if we already have this.
   auto *&result = impl.splatElementsAttrs[{type, elt}];
@@ -1030,14 +1043,14 @@
   return result;
 }
 
-DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
+DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
                                          ArrayRef<char> data) {
-  auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
+  auto bitsRequired = (long)type.getBitWidth() * type.getNumElements();
   (void)bitsRequired;
   assert((bitsRequired <= data.size() * 8L) &&
          "Input data bit size should be larger than that type requires");
 
-  auto &impl = type->getContext()->getImpl();
+  auto &impl = type.getContext()->getImpl();
 
   // Look to see if this constant is already defined.
   DenseElementsAttrInfo::KeyTy key({type, data});
@@ -1048,8 +1061,8 @@
     return *existing.first;
 
   // Otherwise, allocate a new one, unique it and return it.
-  auto *eltType = type->getElementType();
-  switch (eltType->getKind()) {
+  auto eltType = type.getElementType();
+  switch (eltType.getKind()) {
   case Type::Kind::BF16:
   case Type::Kind::F16:
   case Type::Kind::F32:
@@ -1064,7 +1077,7 @@
     return *existing.first = result;
   }
   case Type::Kind::Integer: {
-    auto width = ::cast<IntegerType>(eltType)->getWidth();
+    auto width = eltType.cast<IntegerType>().getWidth();
     auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
     auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
     std::uninitialized_copy(data.begin(), data.end(), copy);
@@ -1080,12 +1093,12 @@
   }
 }
 
-OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
+OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType type,
                                            StringRef bytes) {
-  assert(isValidTensorElementType(type->getElementType()) &&
+  assert(isValidTensorElementType(type.getElementType()) &&
          "Input element type should be a valid tensor element type");
 
-  auto &impl = type->getContext()->getImpl();
+  auto &impl = type.getContext()->getImpl();
 
   // Look to see if this constant is already defined.
   OpaqueElementsAttrInfo::KeyTy key({type, bytes});
@@ -1104,10 +1117,10 @@
   return *existing.first = result;
 }
 
-SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
+SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType type,
                                            DenseIntElementsAttr indices,
                                            DenseElementsAttr values) {
-  auto &impl = type->getContext()->getImpl();
+  auto &impl = type.getContext()->getImpl();
 
   // Look to see if we already have this.
   auto key = std::make_tuple(type, indices, values);
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index 2ed09b8..0722421 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -377,7 +377,7 @@
 }
 
 bool OpTrait::impl::verifySameOperandsAndResult(const Operation *op) {
-  auto *type = op->getResult(0)->getType();
+  auto type = op->getResult(0)->getType();
   for (unsigned i = 1, e = op->getNumResults(); i < e; ++i) {
     if (op->getResult(i)->getType() != type)
       return op->emitOpError(
@@ -393,19 +393,19 @@
 
 /// If this is a vector type, or a tensor type, return the scalar element type
 /// that it is built around, otherwise return the type unmodified.
-static Type *getTensorOrVectorElementType(Type *type) {
-  if (auto *vec = dyn_cast<VectorType>(type))
-    return vec->getElementType();
+static Type getTensorOrVectorElementType(Type type) {
+  if (auto vec = type.dyn_cast<VectorType>())
+    return vec.getElementType();
 
   // Look through tensor<vector<...>> to find the underlying element type.
-  if (auto *tensor = dyn_cast<TensorType>(type))
-    return getTensorOrVectorElementType(tensor->getElementType());
+  if (auto tensor = type.dyn_cast<TensorType>())
+    return getTensorOrVectorElementType(tensor.getElementType());
   return type;
 }
 
 bool OpTrait::impl::verifyResultsAreFloatLike(const Operation *op) {
   for (auto *result : op->getResults()) {
-    if (!isa<FloatType>(getTensorOrVectorElementType(result->getType())))
+    if (!getTensorOrVectorElementType(result->getType()).isa<FloatType>())
       return op->emitOpError("requires a floating point type");
   }
 
@@ -414,7 +414,7 @@
 
 bool OpTrait::impl::verifyResultsAreIntegerLike(const Operation *op) {
   for (auto *result : op->getResults()) {
-    if (!isa<IntegerType>(getTensorOrVectorElementType(result->getType())))
+    if (!getTensorOrVectorElementType(result->getType()).isa<IntegerType>())
       return op->emitOpError("requires an integer type");
   }
   return false;
@@ -436,7 +436,7 @@
 
 bool impl::parseBinaryOp(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 2> ops;
-  Type *type;
+  Type type;
   return parser->parseOperandList(ops, 2) ||
          parser->parseOptionalAttributeDict(result->attributes) ||
          parser->parseColonType(type) ||
@@ -448,7 +448,7 @@
   *p << op->getName() << ' ' << *op->getOperand(0) << ", "
      << *op->getOperand(1);
   p->printOptionalAttrDict(op->getAttrs());
-  *p << " : " << *op->getResult(0)->getType();
+  *p << " : " << op->getResult(0)->getType();
 }
 
 //===----------------------------------------------------------------------===//
@@ -456,14 +456,14 @@
 //===----------------------------------------------------------------------===//
 
 void impl::buildCastOp(Builder *builder, OperationState *result,
-                       SSAValue *source, Type *destType) {
+                       SSAValue *source, Type destType) {
   result->addOperands(source);
   result->addTypes(destType);
 }
 
 bool impl::parseCastOp(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType srcInfo;
-  Type *srcType, *dstType;
+  Type srcType, dstType;
   return parser->parseOperand(srcInfo) || parser->parseColonType(srcType) ||
          parser->resolveOperand(srcInfo, srcType, result->operands) ||
          parser->parseKeywordType("to", dstType) ||
@@ -472,5 +472,5 @@
 
 void impl::printCastOp(const Operation *op, OpAsmPrinter *p) {
   *p << op->getName() << ' ' << *op->getOperand(0) << " : "
-     << *op->getOperand(0)->getType() << " to " << *op->getResult(0)->getType();
+     << op->getOperand(0)->getType() << " to " << op->getResult(0)->getType();
 }
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index e9c46d6..698089a1 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -239,7 +239,7 @@
 /// Create a new OperationStmt with the specific fields.
 OperationStmt *OperationStmt::create(Location *location, OperationName name,
                                      ArrayRef<MLValue *> operands,
-                                     ArrayRef<Type *> resultTypes,
+                                     ArrayRef<Type> resultTypes,
                                      ArrayRef<NamedAttribute> attributes,
                                      MLIRContext *context) {
   auto byteSize = totalSizeToAlloc<StmtOperand, StmtResult>(operands.size(),
@@ -288,9 +288,9 @@
   // If we have a result or operand type, that is a constant time way to get
   // to the context.
   if (getNumResults())
-    return getResult(0)->getType()->getContext();
+    return getResult(0)->getType().getContext();
   if (getNumOperands())
-    return getOperand(0)->getType()->getContext();
+    return getOperand(0)->getType().getContext();
 
   // In the very odd case where we have no operands or results, fall back to
   // doing a find.
@@ -474,7 +474,7 @@
   if (operands.empty())
     return findFunction()->getContext();
 
-  return getOperand(0)->getType()->getContext();
+  return getOperand(0)->getType().getContext();
 }
 
 //===----------------------------------------------------------------------===//
@@ -501,7 +501,7 @@
     operands.push_back(remapOperand(opValue));
 
   if (auto *opStmt = dyn_cast<OperationStmt>(this)) {
-    SmallVector<Type *, 8> resultTypes;
+    SmallVector<Type, 8> resultTypes;
     resultTypes.reserve(opStmt->getNumResults());
     for (auto *result : opStmt->getResults())
       resultTypes.push_back(result->getType());
diff --git a/lib/IR/TypeDetail.h b/lib/IR/TypeDetail.h
new file mode 100644
index 0000000..c22e87a
--- /dev/null
+++ b/lib/IR/TypeDetail.h
@@ -0,0 +1,126 @@
+//===- TypeDetail.h - MLIR Affine Expr storage details ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of Type.
+//
+//===----------------------------------------------------------------------===//
+#ifndef TYPEDETAIL_H_
+#define TYPEDETAIL_H_
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
+
+namespace mlir {
+
+class AffineMap;
+class MLIRContext;
+
+namespace detail {
+
+/// Base storage class appearing in a Type.
+struct alignas(8) TypeStorage {
+  TypeStorage(Type::Kind kind, MLIRContext *context)
+      : context(context), kind(kind), subclassData(0) {}
+  TypeStorage(Type::Kind kind, MLIRContext *context, unsigned subclassData)
+      : context(context), kind(kind), subclassData(subclassData) {}
+
+  unsigned getSubclassData() const { return subclassData; }
+
+  void setSubclassData(unsigned val) {
+    subclassData = val;
+    // Ensure we don't have any accidental truncation.
+    assert(getSubclassData() == val && "Subclass data too large for field");
+  }
+
+  /// This refers to the MLIRContext in which this type was uniqued.
+  MLIRContext *const context;
+
+  /// Classification of the subclass, used for type checking.
+  Type::Kind kind : 8;
+
+  /// Space for subclasses to store data.
+  unsigned subclassData : 24;
+};
+
+struct IntegerTypeStorage : public TypeStorage {
+  unsigned width;
+};
+
+struct FloatTypeStorage : public TypeStorage {};
+
+struct OtherTypeStorage : public TypeStorage {};
+
+struct FunctionTypeStorage : public TypeStorage {
+  ArrayRef<Type> getInputs() const {
+    return ArrayRef<Type>(inputsAndResults, subclassData);
+  }
+  ArrayRef<Type> getResults() const {
+    return ArrayRef<Type>(inputsAndResults + subclassData, numResults);
+  }
+
+  unsigned numResults;
+  Type const *inputsAndResults;
+};
+
+struct VectorOrTensorTypeStorage : public TypeStorage {
+  Type elementType;
+};
+
+struct VectorTypeStorage : public VectorOrTensorTypeStorage {
+  ArrayRef<int> getShape() const {
+    return ArrayRef<int>(shapeElements, getSubclassData());
+  }
+
+  const int *shapeElements;
+};
+
+struct TensorTypeStorage : public VectorOrTensorTypeStorage {};
+
+struct RankedTensorTypeStorage : public TensorTypeStorage {
+  ArrayRef<int> getShape() const {
+    return ArrayRef<int>(shapeElements, getSubclassData());
+  }
+
+  const int *shapeElements;
+};
+
+struct UnrankedTensorTypeStorage : public TensorTypeStorage {};
+
+struct MemRefTypeStorage : public TypeStorage {
+  ArrayRef<int> getShape() const {
+    return ArrayRef<int>(shapeElements, getSubclassData());
+  }
+
+  ArrayRef<AffineMap> getAffineMaps() const {
+    return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
+  }
+
+  /// The type of each scalar element of the memref.
+  Type elementType;
+  /// An array of integers which stores the shape dimension sizes.
+  const int *shapeElements;
+  /// The number of affine maps in the 'affineMapList' array.
+  const unsigned numAffineMaps;
+  /// List of affine maps in the memref's layout/index map composition.
+  AffineMap const *affineMapList;
+  /// Memory space in which data referenced by memref resides.
+  const unsigned memorySpace;
+};
+
+} // namespace detail
+} // namespace mlir
+#endif // TYPEDETAIL_H_
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 0ad3f47..1a71695 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -16,10 +16,17 @@
 // =============================================================================
 
 #include "mlir/IR/Types.h"
+#include "TypeDetail.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/Support/raw_ostream.h"
+
 using namespace mlir;
+using namespace mlir::detail;
+
+Type::Kind Type::getKind() const { return type->kind; }
+
+MLIRContext *Type::getContext() const { return type->context; }
 
 unsigned Type::getBitWidth() const {
   switch (getKind()) {
@@ -32,34 +39,49 @@
   case Type::Kind::F64:
     return 64;
   case Type::Kind::Integer:
-    return cast<IntegerType>(this)->getWidth();
+    return cast<IntegerType>().getWidth();
   case Type::Kind::Vector:
   case Type::Kind::RankedTensor:
   case Type::Kind::UnrankedTensor:
-    return cast<VectorOrTensorType>(this)->getElementType()->getBitWidth();
+    return cast<VectorOrTensorType>().getElementType().getBitWidth();
     // TODO: Handle more types.
   default:
     llvm_unreachable("unexpected type");
   }
 }
 
-IntegerType::IntegerType(unsigned width, MLIRContext *context)
-    : Type(Kind::Integer, context), width(width) {
-  assert(width <= kMaxWidth && "admissible integer bitwidth exceeded");
+unsigned Type::getSubclassData() const { return type->getSubclassData(); }
+void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
+
+IntegerType::IntegerType(Type::ImplType *ptr) : Type(ptr) {}
+
+unsigned IntegerType::getWidth() const {
+  return static_cast<ImplType *>(type)->width;
 }
 
-FloatType::FloatType(Kind kind, MLIRContext *context) : Type(kind, context) {}
+FloatType::FloatType(Type::ImplType *ptr) : Type(ptr) {}
 
-OtherType::OtherType(Kind kind, MLIRContext *context) : Type(kind, context) {}
+OtherType::OtherType(Type::ImplType *ptr) : Type(ptr) {}
 
-FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
-                           unsigned numResults, MLIRContext *context)
-    : Type(Kind::Function, context, numInputs), numResults(numResults),
-      inputsAndResults(inputsAndResults) {}
+FunctionType::FunctionType(Type::ImplType *ptr) : Type(ptr) {}
 
-VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
-                                       Type *elementType, unsigned subClassData)
-    : Type(kind, context, subClassData), elementType(elementType) {}
+ArrayRef<Type> FunctionType::getInputs() const {
+  return static_cast<ImplType *>(type)->getInputs();
+}
+
+unsigned FunctionType::getNumResults() const {
+  return static_cast<ImplType *>(type)->numResults;
+}
+
+ArrayRef<Type> FunctionType::getResults() const {
+  return static_cast<ImplType *>(type)->getResults();
+}
+
+VectorOrTensorType::VectorOrTensorType(Type::ImplType *ptr) : Type(ptr) {}
+
+Type VectorOrTensorType::getElementType() const {
+  return static_cast<ImplType *>(type)->elementType;
+}
 
 unsigned VectorOrTensorType::getNumElements() const {
   switch (getKind()) {
@@ -103,11 +125,11 @@
 ArrayRef<int> VectorOrTensorType::getShape() const {
   switch (getKind()) {
   case Kind::Vector:
-    return cast<VectorType>(this)->getShape();
+    return cast<VectorType>().getShape();
   case Kind::RankedTensor:
-    return cast<RankedTensorType>(this)->getShape();
+    return cast<RankedTensorType>().getShape();
   case Kind::UnrankedTensor:
-    return cast<RankedTensorType>(this)->getShape();
+    return cast<RankedTensorType>().getShape();
   default:
     llvm_unreachable("not a VectorOrTensorType");
   }
@@ -118,35 +140,38 @@
   return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
 }
 
-VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
-                       MLIRContext *context)
-    : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
-      shapeElements(shape.data()) {}
+VectorType::VectorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
 
-TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
-    : VectorOrTensorType(kind, context, elementType) {
-  assert(isValidTensorElementType(elementType));
+ArrayRef<int> VectorType::getShape() const {
+  return static_cast<ImplType *>(type)->getShape();
 }
 
-RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
-                                   MLIRContext *context)
-    : TensorType(Kind::RankedTensor, elementType, context),
-      shapeElements(shape.data()) {
-  setSubclassData(shape.size());
+TensorType::TensorType(Type::ImplType *ptr) : VectorOrTensorType(ptr) {}
+
+RankedTensorType::RankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
+
+ArrayRef<int> RankedTensorType::getShape() const {
+  return static_cast<ImplType *>(type)->getShape();
 }
 
-UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
-    : TensorType(Kind::UnrankedTensor, elementType, context) {}
+UnrankedTensorType::UnrankedTensorType(Type::ImplType *ptr) : TensorType(ptr) {}
 
-MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
-                       ArrayRef<AffineMap> affineMapList, unsigned memorySpace,
-                       MLIRContext *context)
-    : Type(Kind::MemRef, context, shape.size()), elementType(elementType),
-      shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
-      affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
+MemRefType::MemRefType(Type::ImplType *ptr) : Type(ptr) {}
+
+ArrayRef<int> MemRefType::getShape() const {
+  return static_cast<ImplType *>(type)->getShape();
+}
+
+Type MemRefType::getElementType() const {
+  return static_cast<ImplType *>(type)->elementType;
+}
 
 ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
-  return ArrayRef<AffineMap>(affineMapList, numAffineMaps);
+  return static_cast<ImplType *>(type)->getAffineMaps();
+}
+
+unsigned MemRefType::getMemorySpace() const {
+  return static_cast<ImplType *>(type)->memorySpace;
 }
 
 unsigned MemRefType::getNumDynamicDims() const {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 7974c7c..ceb8931 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -182,19 +182,19 @@
   // as the results of their action.
 
   // Type parsing.
-  VectorType *parseVectorType();
+  VectorType parseVectorType();
   ParseResult parseXInDimensionList();
   ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
-  Type *parseTensorType();
-  Type *parseMemRefType();
-  Type *parseFunctionType();
-  Type *parseType();
-  ParseResult parseTypeListNoParens(SmallVectorImpl<Type *> &elements);
-  ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
+  Type parseTensorType();
+  Type parseMemRefType();
+  Type parseFunctionType();
+  Type parseType();
+  ParseResult parseTypeListNoParens(SmallVectorImpl<Type> &elements);
+  ParseResult parseTypeList(SmallVectorImpl<Type> &elements);
 
   // Attribute parsing.
   Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
-                                     FunctionType *type);
+                                     FunctionType type);
   Attribute parseAttribute();
 
   ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@@ -206,9 +206,9 @@
   AffineMap parseAffineMapReference();
   IntegerSet parseIntegerSetInline();
   IntegerSet parseIntegerSetReference();
-  DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
-  DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
-  VectorOrTensorType *parseVectorOrTensorType();
+  DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType type);
+  DenseElementsAttr parseDenseElementsAttr(Type eltType, bool isVector);
+  VectorOrTensorType parseVectorOrTensorType();
 
 private:
   // The Parser is subclassed and reinstantiated.  Do not add additional
@@ -299,7 +299,7 @@
 ///   float-type ::= `f16` | `bf16` | `f32` | `f64`
 ///   other-type ::= `index` | `tf_control`
 ///
-Type *Parser::parseType() {
+Type Parser::parseType() {
   switch (getToken().getKind()) {
   default:
     return (emitError("expected type"), nullptr);
@@ -368,7 +368,7 @@
 ///   vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
 ///   const-dimension-list ::= (integer-literal `x`)+
 ///
-VectorType *Parser::parseVectorType() {
+VectorType Parser::parseVectorType() {
   consumeToken(Token::kw_vector);
 
   if (parseToken(Token::less, "expected '<' in vector type"))
@@ -402,11 +402,11 @@
 
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
-  auto *elementType = parseType();
+  auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
+  if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
     return (emitError(typeLoc, "invalid vector element type"), nullptr);
 
   return VectorType::get(dimensions, elementType);
@@ -461,7 +461,7 @@
 ///   tensor-type ::= `tensor` `<` dimension-list element-type `>`
 ///   dimension-list ::= dimension-list-ranked | `*x`
 ///
-Type *Parser::parseTensorType() {
+Type Parser::parseTensorType() {
   consumeToken(Token::kw_tensor);
 
   if (parseToken(Token::less, "expected '<' in tensor type"))
@@ -485,7 +485,7 @@
 
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
-  auto *elementType = parseType();
+  auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
     return nullptr;
 
@@ -505,7 +505,7 @@
 ///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
 ///   memory-space ::= integer-literal /* | TODO: address-space-id */
 ///
-Type *Parser::parseMemRefType() {
+Type Parser::parseMemRefType() {
   consumeToken(Token::kw_memref);
 
   if (parseToken(Token::less, "expected '<' in memref type"))
@@ -517,12 +517,12 @@
 
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
-  auto *elementType = parseType();
+  auto elementType = parseType();
   if (!elementType)
     return nullptr;
 
-  if (!isa<IntegerType>(elementType) && !isa<FloatType>(elementType) &&
-      !isa<VectorType>(elementType))
+  if (!elementType.isa<IntegerType>() && !elementType.isa<FloatType>() &&
+      !elementType.isa<VectorType>())
     return (emitError(typeLoc, "invalid memref element type"), nullptr);
 
   // Parse semi-affine-map-composition.
@@ -581,10 +581,10 @@
 ///
 ///   function-type ::= type-list-parens `->` type-list
 ///
-Type *Parser::parseFunctionType() {
+Type Parser::parseFunctionType() {
   assert(getToken().is(Token::l_paren));
 
-  SmallVector<Type *, 4> arguments, results;
+  SmallVector<Type, 4> arguments, results;
   if (parseTypeList(arguments) ||
       parseToken(Token::arrow, "expected '->' in function type") ||
       parseTypeList(results))
@@ -598,7 +598,7 @@
 ///
 ///   type-list-no-parens ::=  type (`,` type)*
 ///
-ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type *> &elements) {
+ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
   auto parseElt = [&]() -> ParseResult {
     auto elt = parseType();
     elements.push_back(elt);
@@ -615,7 +615,7 @@
 ///   type-list-parens ::= `(` `)`
 ///                      | `(` type-list-no-parens `)`
 ///
-ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
+ParseResult Parser::parseTypeList(SmallVectorImpl<Type> &elements) {
   auto parseElt = [&]() -> ParseResult {
     auto elt = parseType();
     elements.push_back(elt);
@@ -639,8 +639,8 @@
 namespace {
 class TensorLiteralParser {
 public:
-  TensorLiteralParser(Parser &p, Type *eltTy)
-      : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy->getBitWidth()) {}
+  TensorLiteralParser(Parser &p, Type eltTy)
+      : p(p), eltTy(eltTy), currBitPos(0), bitsWidth(eltTy.getBitWidth()) {}
 
   ParseResult parse() { return parseList(shape); }
 
@@ -676,7 +676,7 @@
   }
 
   Parser &p;
-  Type *eltTy;
+  Type eltTy;
   size_t currBitPos;
   size_t bitsWidth;
   SmallVector<int, 4> shape;
@@ -698,7 +698,7 @@
     if (!result)
       return p.emitError("expected tensor element");
     // check result matches the element type.
-    switch (eltTy->getKind()) {
+    switch (eltTy.getKind()) {
     case Type::Kind::BF16:
     case Type::Kind::F16:
     case Type::Kind::F32:
@@ -779,7 +779,7 @@
 /// synthesizing a forward reference) or emit an error and return null on
 /// failure.
 Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
-                                           FunctionType *type) {
+                                           FunctionType type) {
   Identifier name = builder.getIdentifier(nameStr.drop_front());
 
   // See if the function has already been defined in the module.
@@ -902,10 +902,10 @@
     if (parseToken(Token::colon, "expected ':' and function type"))
       return nullptr;
     auto typeLoc = getToken().getLoc();
-    Type *type = parseType();
+    Type type = parseType();
     if (!type)
       return nullptr;
-    auto *fnType = dyn_cast<FunctionType>(type);
+    auto fnType = type.dyn_cast<FunctionType>();
     if (!fnType)
       return (emitError(typeLoc, "expected function type"), nullptr);
 
@@ -916,7 +916,7 @@
     consumeToken(Token::kw_opaque);
     if (parseToken(Token::less, "expected '<' after 'opaque'"))
       return nullptr;
-    auto *type = parseVectorOrTensorType();
+    auto type = parseVectorOrTensorType();
     if (!type)
       return nullptr;
     auto val = getToken().getStringValue();
@@ -937,7 +937,7 @@
     if (parseToken(Token::less, "expected '<' after 'splat'"))
       return nullptr;
 
-    auto *type = parseVectorOrTensorType();
+    auto type = parseVectorOrTensorType();
     if (!type)
       return nullptr;
     switch (getToken().getKind()) {
@@ -959,7 +959,7 @@
     if (parseToken(Token::less, "expected '<' after 'dense'"))
       return nullptr;
 
-    auto *type = parseVectorOrTensorType();
+    auto type = parseVectorOrTensorType();
     if (!type)
       return nullptr;
 
@@ -981,41 +981,41 @@
     if (parseToken(Token::less, "Expected '<' after 'sparse'"))
       return nullptr;
 
-    auto *type = parseVectorOrTensorType();
+    auto type = parseVectorOrTensorType();
     if (!type)
       return nullptr;
 
     switch (getToken().getKind()) {
     case Token::l_square: {
       /// Parse indices
-      auto *indicesEltType = builder.getIntegerType(32);
+      auto indicesEltType = builder.getIntegerType(32);
       auto indices =
-          parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
+          parseDenseElementsAttr(indicesEltType, type.isa<VectorType>());
 
       if (parseToken(Token::comma, "expected ','"))
         return nullptr;
 
       /// Parse values.
-      auto *valuesEltType = type->getElementType();
+      auto valuesEltType = type.getElementType();
       auto values =
-          parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
+          parseDenseElementsAttr(valuesEltType, type.isa<VectorType>());
 
       /// Sanity check.
-      auto *indicesType = indices.getType();
-      auto *valuesType = values.getType();
-      auto sameShape = (indicesType->getRank() == 1) ||
-                       (type->getRank() == indicesType->getDimSize(1));
+      auto indicesType = indices.getType();
+      auto valuesType = values.getType();
+      auto sameShape = (indicesType.getRank() == 1) ||
+                       (type.getRank() == indicesType.getDimSize(1));
       auto sameElementNum =
-          indicesType->getDimSize(0) == valuesType->getDimSize(0);
+          indicesType.getDimSize(0) == valuesType.getDimSize(0);
       if (!sameShape || !sameElementNum) {
         std::string str;
         llvm::raw_string_ostream s(str);
         s << "expected shape ([";
-        interleaveComma(type->getShape(), s);
+        interleaveComma(type.getShape(), s);
         s << "]); inferred shape of indices literal ([";
-        interleaveComma(indicesType->getShape(), s);
+        interleaveComma(indicesType.getShape(), s);
         s << "]); inferred shape of values literal ([";
-        interleaveComma(valuesType->getShape(), s);
+        interleaveComma(valuesType.getShape(), s);
         s << "])";
         return (emitError(s.str()), nullptr);
       }
@@ -1035,7 +1035,7 @@
             nullptr);
   }
   default: {
-    if (Type *type = parseType())
+    if (Type type = parseType())
       return builder.getTypeAttr(type);
     return nullptr;
   }
@@ -1051,12 +1051,12 @@
 ///
 /// This method returns a constructed dense elements attribute with the shape
 /// from the parsing result.
-DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
+DenseElementsAttr Parser::parseDenseElementsAttr(Type eltType, bool isVector) {
   TensorLiteralParser literalParser(*this, eltType);
   if (literalParser.parse())
     return nullptr;
 
-  VectorOrTensorType *type;
+  VectorOrTensorType type;
   if (isVector) {
     type = builder.getVectorType(literalParser.getShape(), eltType);
   } else {
@@ -1076,18 +1076,18 @@
 /// This method compares the shapes from the parsing result and that from the
 /// input argument. It returns a constructed dense elements attribute if both
 /// match.
-DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
-  auto *eltTy = type->getElementType();
+DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType type) {
+  auto eltTy = type.getElementType();
   TensorLiteralParser literalParser(*this, eltTy);
   if (literalParser.parse())
     return nullptr;
-  if (literalParser.getShape() != type->getShape()) {
+  if (literalParser.getShape() != type.getShape()) {
     std::string str;
     llvm::raw_string_ostream s(str);
     s << "inferred shape of elements literal ([";
     interleaveComma(literalParser.getShape(), s);
     s << "]) does not match type ([";
-    interleaveComma(type->getShape(), s);
+    interleaveComma(type.getShape(), s);
     s << "])";
     return (emitError(s.str()), nullptr);
   }
@@ -1100,8 +1100,8 @@
 ///   vector-or-tensor-type ::= vector-type | tensor-type
 ///
 /// This method also checks the type has static shape and ranked.
-VectorOrTensorType *Parser::parseVectorOrTensorType() {
-  auto *type = dyn_cast<VectorOrTensorType>(parseType());
+VectorOrTensorType Parser::parseVectorOrTensorType() {
+  auto type = parseType().dyn_cast<VectorOrTensorType>();
   if (!type) {
     return (emitError("expected elements literal has a tensor or vector type"),
             nullptr);
@@ -1110,7 +1110,7 @@
   if (parseToken(Token::comma, "expected ','"))
     return nullptr;
 
-  if (!type->hasStaticShape() || type->getRank() == -1) {
+  if (!type.hasStaticShape() || type.getRank() == -1) {
     return (emitError("tensor literals must be ranked and have static shape"),
             nullptr);
   }
@@ -1834,7 +1834,7 @@
 
   /// Given a reference to an SSA value and its type, return a reference. This
   /// returns null on failure.
-  SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type *type);
+  SSAValue *resolveSSAUse(SSAUseInfo useInfo, Type type);
 
   /// Register a definition of a value with the symbol table.
   ParseResult addDefinition(SSAUseInfo useInfo, SSAValue *value);
@@ -1845,11 +1845,11 @@
 
   template <typename ResultType>
   ResultType parseSSADefOrUseAndType(
-      const std::function<ResultType(SSAUseInfo, Type *)> &action);
+      const std::function<ResultType(SSAUseInfo, Type)> &action);
 
   SSAValue *parseSSAUseAndType() {
     return parseSSADefOrUseAndType<SSAValue *>(
-        [&](SSAUseInfo useInfo, Type *type) -> SSAValue * {
+        [&](SSAUseInfo useInfo, Type type) -> SSAValue * {
           return resolveSSAUse(useInfo, type);
         });
   }
@@ -1880,7 +1880,7 @@
   /// their first reference, to allow checking for use of undefined values.
   DenseMap<SSAValue *, SMLoc> forwardReferencePlaceholders;
 
-  SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type *type);
+  SSAValue *createForwardReferencePlaceholder(SMLoc loc, Type type);
 
   /// Return true if this is a forward reference.
   bool isForwardReferencePlaceholder(SSAValue *value) {
@@ -1891,7 +1891,7 @@
 
 /// Create and remember a new placeholder for a forward reference.
 SSAValue *FunctionParser::createForwardReferencePlaceholder(SMLoc loc,
-                                                            Type *type) {
+                                                            Type type) {
   // Forward references are always created as instructions, even in ML
   // functions, because we just need something with a def/use chain.
   //
@@ -1908,7 +1908,7 @@
 
 /// Given an unbound reference to an SSA value and its type, return the value
 /// it specifies.  This returns null on failure.
-SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type *type) {
+SSAValue *FunctionParser::resolveSSAUse(SSAUseInfo useInfo, Type type) {
   auto &entries = values[useInfo.name];
 
   // If we have already seen a value of this name, return it.
@@ -2057,14 +2057,14 @@
 ///   ssa-use-and-type ::= ssa-use `:` type
 template <typename ResultType>
 ResultType FunctionParser::parseSSADefOrUseAndType(
-    const std::function<ResultType(SSAUseInfo, Type *)> &action) {
+    const std::function<ResultType(SSAUseInfo, Type)> &action) {
 
   SSAUseInfo useInfo;
   if (parseSSAUse(useInfo) ||
       parseToken(Token::colon, "expected ':' and type for SSA operand"))
     return nullptr;
 
-  auto *type = parseType();
+  auto type = parseType();
   if (!type)
     return nullptr;
 
@@ -2101,7 +2101,7 @@
   if (valueIDs.empty())
     return ParseSuccess;
 
-  SmallVector<Type *, 4> types;
+  SmallVector<Type, 4> types;
   if (parseToken(Token::colon, "expected ':' in operand list") ||
       parseTypeListNoParens(types))
     return ParseFailure;
@@ -2209,14 +2209,14 @@
   auto type = parseType();
   if (!type)
     return nullptr;
-  auto fnType = dyn_cast<FunctionType>(type);
+  auto fnType = type.dyn_cast<FunctionType>();
   if (!fnType)
     return (emitError(typeLoc, "expected function type"), nullptr);
 
-  result.addTypes(fnType->getResults());
+  result.addTypes(fnType.getResults());
 
   // Check that we have the right number of types for the operands.
-  auto operandTypes = fnType->getInputs();
+  auto operandTypes = fnType.getInputs();
   if (operandTypes.size() != operandInfos.size()) {
     auto plural = "s"[operandInfos.size() == 1];
     return (emitError(typeLoc, "expected " + llvm::utostr(operandInfos.size()) +
@@ -2253,17 +2253,17 @@
     return parser.parseToken(Token::comma, "expected ','");
   }
 
-  bool parseColonType(Type *&result) override {
+  bool parseColonType(Type &result) override {
     return parser.parseToken(Token::colon, "expected ':'") ||
            !(result = parser.parseType());
   }
 
-  bool parseColonTypeList(SmallVectorImpl<Type *> &result) override {
+  bool parseColonTypeList(SmallVectorImpl<Type> &result) override {
     if (parser.parseToken(Token::colon, "expected ':'"))
       return true;
 
     do {
-      if (auto *type = parser.parseType())
+      if (auto type = parser.parseType())
         result.push_back(type);
       else
         return true;
@@ -2273,7 +2273,7 @@
   }
 
   /// Parse a keyword followed by a type.
-  bool parseKeywordType(const char *keyword, Type *&result) override {
+  bool parseKeywordType(const char *keyword, Type &result) override {
     if (parser.getTokenSpelling() != keyword)
       return parser.emitError("expected '" + Twine(keyword) + "'");
     parser.consumeToken();
@@ -2396,7 +2396,7 @@
   }
 
   /// Resolve a parse function name and a type into a function reference.
-  virtual bool resolveFunctionName(StringRef name, FunctionType *type,
+  virtual bool resolveFunctionName(StringRef name, FunctionType type,
                                    llvm::SMLoc loc, Function *&result) {
     result = parser.resolveFunctionReference(name, loc, type);
     return result == nullptr;
@@ -2410,7 +2410,7 @@
 
   llvm::SMLoc getNameLoc() const override { return nameLoc; }
 
-  bool resolveOperand(const OperandType &operand, Type *type,
+  bool resolveOperand(const OperandType &operand, Type type,
                       SmallVectorImpl<SSAValue *> &result) override {
     FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
                                               operand.location};
@@ -2559,11 +2559,11 @@
     return ParseSuccess;
 
   return parseCommaSeparatedList([&]() -> ParseResult {
-    auto type = parseSSADefOrUseAndType<Type *>(
-        [&](SSAUseInfo useInfo, Type *type) -> Type * {
+    auto type = parseSSADefOrUseAndType<Type>(
+        [&](SSAUseInfo useInfo, Type type) -> Type {
           BBArgument *arg = owner->addArgument(type);
           if (addDefinition(useInfo, arg))
-            return nullptr;
+            return {};
           return type;
         });
     return type ? ParseSuccess : ParseFailure;
@@ -2908,7 +2908,7 @@
                      " symbol count must match");
 
   // Resolve SSA uses.
-  Type *indexType = builder.getIndexType();
+  Type indexType = builder.getIndexType();
   for (unsigned i = 0, e = opInfo.size(); i != e; ++i) {
     SSAValue *sval = resolveSSAUse(opInfo[i], indexType);
     if (!sval)
@@ -3187,9 +3187,9 @@
   ParseResult parseAffineStructureDef();
 
   // Functions.
-  ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+  ParseResult parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
                                   SmallVectorImpl<StringRef> &argNames);
-  ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
+  ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
                                      SmallVectorImpl<StringRef> *argNames);
   ParseResult parseFunctionAttribute(SmallVectorImpl<NamedAttribute> &attrs);
   ParseResult parseExtFunc();
@@ -3248,7 +3248,7 @@
 /// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
 ///
 ParseResult
-ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+ModuleParser::parseMLArgumentList(SmallVectorImpl<Type> &argTypes,
                                   SmallVectorImpl<StringRef> &argNames) {
   consumeToken(Token::l_paren);
 
@@ -3284,7 +3284,7 @@
 ///   type-list)?
 ///
 ParseResult
-ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
+ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
                                      SmallVectorImpl<StringRef> *argNames) {
   if (getToken().isNot(Token::at_identifier))
     return emitError("expected a function identifier like '@foo'");
@@ -3295,7 +3295,7 @@
   if (getToken().isNot(Token::l_paren))
     return emitError("expected '(' in function signature");
 
-  SmallVector<Type *, 4> argTypes;
+  SmallVector<Type, 4> argTypes;
   ParseResult parseResult;
 
   if (argNames)
@@ -3307,7 +3307,7 @@
     return ParseFailure;
 
   // Parse the return type if present.
-  SmallVector<Type *, 4> results;
+  SmallVector<Type, 4> results;
   if (consumeIf(Token::arrow)) {
     if (parseTypeList(results))
       return ParseFailure;
@@ -3340,7 +3340,7 @@
   auto loc = getToken().getLoc();
 
   StringRef name;
-  FunctionType *type = nullptr;
+  FunctionType type;
   if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
     return ParseFailure;
 
@@ -3372,7 +3372,7 @@
   auto loc = getToken().getLoc();
 
   StringRef name;
-  FunctionType *type = nullptr;
+  FunctionType type;
   if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
     return ParseFailure;
 
@@ -3405,7 +3405,7 @@
   consumeToken(Token::kw_mlfunc);
 
   StringRef name;
-  FunctionType *type = nullptr;
+  FunctionType type;
   SmallVector<StringRef, 4> argNames;
 
   auto loc = getToken().getLoc();
diff --git a/lib/StandardOps/StandardOps.cpp b/lib/StandardOps/StandardOps.cpp
index b60d209..e2bdfd7 100644
--- a/lib/StandardOps/StandardOps.cpp
+++ b/lib/StandardOps/StandardOps.cpp
@@ -138,23 +138,23 @@
 //===----------------------------------------------------------------------===//
 
 void AllocOp::build(Builder *builder, OperationState *result,
-                    MemRefType *memrefType, ArrayRef<SSAValue *> operands) {
+                    MemRefType memrefType, ArrayRef<SSAValue *> operands) {
   result->addOperands(operands);
   result->types.push_back(memrefType);
 }
 
 void AllocOp::print(OpAsmPrinter *p) const {
-  MemRefType *type = getType();
+  MemRefType type = getType();
   *p << "alloc";
   // Print dynamic dimension operands.
   printDimAndSymbolList(operand_begin(), operand_end(),
-                        type->getNumDynamicDims(), p);
+                        type.getNumDynamicDims(), p);
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"map");
-  *p << " : " << *type;
+  *p << " : " << type;
 }
 
 bool AllocOp::parse(OpAsmParser *parser, OperationState *result) {
-  MemRefType *type;
+  MemRefType type;
 
   // Parse the dimension operands and optional symbol operands, followed by a
   // memref type.
@@ -170,7 +170,7 @@
   // Verification still checks that the total number of operands matches
   // the number of symbols in the affine map, plus the number of dynamic
   // dimensions in the memref.
-  if (numDimOperands != type->getNumDynamicDims()) {
+  if (numDimOperands != type.getNumDynamicDims()) {
     return parser->emitError(parser->getNameLoc(),
                              "dimension operand count does not equal memref "
                              "dynamic dimension count");
@@ -180,13 +180,13 @@
 }
 
 bool AllocOp::verify() const {
-  auto *memRefType = dyn_cast<MemRefType>(getResult()->getType());
+  auto memRefType = getResult()->getType().dyn_cast<MemRefType>();
   if (!memRefType)
     return emitOpError("result must be a memref");
 
   unsigned numSymbols = 0;
-  if (!memRefType->getAffineMaps().empty()) {
-    AffineMap affineMap = memRefType->getAffineMaps()[0];
+  if (!memRefType.getAffineMaps().empty()) {
+    AffineMap affineMap = memRefType.getAffineMaps()[0];
     // Store number of symbols used in affine map (used in subsequent check).
     numSymbols = affineMap.getNumSymbols();
     // TODO(zinenko): this check does not belong to AllocOp, or any other op but
@@ -195,10 +195,10 @@
     // Remove when we can emit errors directly from *Type::get(...) functions.
     //
     // Verify that the layout affine map matches the rank of the memref.
-    if (affineMap.getNumDims() != memRefType->getRank())
+    if (affineMap.getNumDims() != memRefType.getRank())
       return emitOpError("affine map dimension count must equal memref rank");
   }
-  unsigned numDynamicDims = memRefType->getNumDynamicDims();
+  unsigned numDynamicDims = memRefType.getNumDynamicDims();
   // Check that the total number of operands matches the number of symbols in
   // the affine map, plus the number of dynamic dimensions specified in the
   // memref type.
@@ -208,7 +208,7 @@
   }
   // Verify that all operands are of type Index.
   for (auto *operand : getOperands()) {
-    if (!operand->getType()->isIndex())
+    if (!operand->getType().isIndex())
       return emitOpError("requires operands to be of type Index");
   }
   return false;
@@ -239,13 +239,13 @@
     // Ok, we have one or more constant operands.  Collect the non-constant ones
     // and keep track of the resultant memref type to build.
     SmallVector<int, 4> newShapeConstants;
-    newShapeConstants.reserve(memrefType->getRank());
+    newShapeConstants.reserve(memrefType.getRank());
     SmallVector<SSAValue *, 4> newOperands;
     SmallVector<SSAValue *, 4> droppedOperands;
 
     unsigned dynamicDimPos = 0;
-    for (unsigned dim = 0, e = memrefType->getRank(); dim < e; ++dim) {
-      int dimSize = memrefType->getDimSize(dim);
+    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
+      int dimSize = memrefType.getDimSize(dim);
       // If this is already static dimension, keep it.
       if (dimSize != -1) {
         newShapeConstants.push_back(dimSize);
@@ -267,10 +267,10 @@
     }
 
     // Create new memref type (which will have fewer dynamic dimensions).
-    auto *newMemRefType = MemRefType::get(
-        newShapeConstants, memrefType->getElementType(),
-        memrefType->getAffineMaps(), memrefType->getMemorySpace());
-    assert(newOperands.size() == newMemRefType->getNumDynamicDims());
+    auto newMemRefType = MemRefType::get(
+        newShapeConstants, memrefType.getElementType(),
+        memrefType.getAffineMaps(), memrefType.getMemorySpace());
+    assert(newOperands.size() == newMemRefType.getNumDynamicDims());
 
     // Create and insert the alloc op for the new memref.
     auto newAlloc =
@@ -297,13 +297,13 @@
                    ArrayRef<SSAValue *> operands) {
   result->addOperands(operands);
   result->addAttribute("callee", builder->getFunctionAttr(callee));
-  result->addTypes(callee->getType()->getResults());
+  result->addTypes(callee->getType().getResults());
 }
 
 bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
   StringRef calleeName;
   llvm::SMLoc calleeLoc;
-  FunctionType *calleeType = nullptr;
+  FunctionType calleeType;
   SmallVector<OpAsmParser::OperandType, 4> operands;
   Function *callee = nullptr;
   if (parser->parseFunctionName(calleeName, calleeLoc) ||
@@ -312,8 +312,8 @@
       parser->parseOptionalAttributeDict(result->attributes) ||
       parser->parseColonType(calleeType) ||
       parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
-      parser->addTypesToList(calleeType->getResults(), result->types) ||
-      parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
+      parser->addTypesToList(calleeType.getResults(), result->types) ||
+      parser->resolveOperands(operands, calleeType.getInputs(), calleeLoc,
                               result->operands))
     return true;
 
@@ -328,7 +328,7 @@
   p->printOperands(getOperands());
   *p << ')';
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
-  *p << " : " << *getCallee()->getType();
+  *p << " : " << getCallee()->getType();
 }
 
 bool CallOp::verify() const {
@@ -338,20 +338,20 @@
     return emitOpError("requires a 'callee' function attribute");
 
   // Verify that the operand and result types match the callee.
-  auto *fnType = fnAttr.getValue()->getType();
-  if (fnType->getNumInputs() != getNumOperands())
+  auto fnType = fnAttr.getValue()->getType();
+  if (fnType.getNumInputs() != getNumOperands())
     return emitOpError("incorrect number of operands for callee");
 
-  for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
-    if (getOperand(i)->getType() != fnType->getInput(i))
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+    if (getOperand(i)->getType() != fnType.getInput(i))
       return emitOpError("operand type mismatch");
   }
 
-  if (fnType->getNumResults() != getNumResults())
+  if (fnType.getNumResults() != getNumResults())
     return emitOpError("incorrect number of results for callee");
 
-  for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
-    if (getResult(i)->getType() != fnType->getResult(i))
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+    if (getResult(i)->getType() != fnType.getResult(i))
       return emitOpError("result type mismatch");
   }
 
@@ -364,14 +364,14 @@
 
 void CallIndirectOp::build(Builder *builder, OperationState *result,
                            SSAValue *callee, ArrayRef<SSAValue *> operands) {
-  auto *fnType = cast<FunctionType>(callee->getType());
+  auto fnType = callee->getType().cast<FunctionType>();
   result->operands.push_back(callee);
   result->addOperands(operands);
-  result->addTypes(fnType->getResults());
+  result->addTypes(fnType.getResults());
 }
 
 bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
-  FunctionType *calleeType = nullptr;
+  FunctionType calleeType;
   OpAsmParser::OperandType callee;
   llvm::SMLoc operandsLoc;
   SmallVector<OpAsmParser::OperandType, 4> operands;
@@ -382,9 +382,9 @@
          parser->parseOptionalAttributeDict(result->attributes) ||
          parser->parseColonType(calleeType) ||
          parser->resolveOperand(callee, calleeType, result->operands) ||
-         parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
+         parser->resolveOperands(operands, calleeType.getInputs(), operandsLoc,
                                  result->operands) ||
-         parser->addTypesToList(calleeType->getResults(), result->types);
+         parser->addTypesToList(calleeType.getResults(), result->types);
 }
 
 void CallIndirectOp::print(OpAsmPrinter *p) const {
@@ -395,29 +395,29 @@
   p->printOperands(++operandRange.begin(), operandRange.end());
   *p << ')';
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
-  *p << " : " << *getCallee()->getType();
+  *p << " : " << getCallee()->getType();
 }
 
 bool CallIndirectOp::verify() const {
   // The callee must be a function.
-  auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
+  auto fnType = getCallee()->getType().dyn_cast<FunctionType>();
   if (!fnType)
     return emitOpError("callee must have function type");
 
   // Verify that the operand and result types match the callee.
-  if (fnType->getNumInputs() != getNumOperands() - 1)
+  if (fnType.getNumInputs() != getNumOperands() - 1)
     return emitOpError("incorrect number of operands for callee");
 
-  for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
-    if (getOperand(i + 1)->getType() != fnType->getInput(i))
+  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) {
+    if (getOperand(i + 1)->getType() != fnType.getInput(i))
       return emitOpError("operand type mismatch");
   }
 
-  if (fnType->getNumResults() != getNumResults())
+  if (fnType.getNumResults() != getNumResults())
     return emitOpError("incorrect number of results for callee");
 
-  for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
-    if (getResult(i)->getType() != fnType->getResult(i))
+  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) {
+    if (getResult(i)->getType() != fnType.getResult(i))
       return emitOpError("result type mismatch");
   }
 
@@ -434,19 +434,19 @@
 }
 
 void DeallocOp::print(OpAsmPrinter *p) const {
-  *p << "dealloc " << *getMemRef() << " : " << *getMemRef()->getType();
+  *p << "dealloc " << *getMemRef() << " : " << getMemRef()->getType();
 }
 
 bool DeallocOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType memrefInfo;
-  MemRefType *type;
+  MemRefType type;
 
   return parser->parseOperand(memrefInfo) || parser->parseColonType(type) ||
          parser->resolveOperand(memrefInfo, type, result->operands);
 }
 
 bool DeallocOp::verify() const {
-  if (!isa<MemRefType>(getMemRef()->getType()))
+  if (!getMemRef()->getType().isa<MemRefType>())
     return emitOpError("operand must be a memref");
   return false;
 }
@@ -472,13 +472,13 @@
 void DimOp::print(OpAsmPrinter *p) const {
   *p << "dim " << *getOperand() << ", " << getIndex();
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"index");
-  *p << " : " << *getOperand()->getType();
+  *p << " : " << getOperand()->getType();
 }
 
 bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType operandInfo;
   IntegerAttr indexAttr;
-  Type *type;
+  Type type;
 
   return parser->parseOperand(operandInfo) || parser->parseComma() ||
          parser->parseAttribute(indexAttr, "index", result->attributes) ||
@@ -496,15 +496,15 @@
     return emitOpError("requires an integer attribute named 'index'");
   uint64_t index = (uint64_t)indexAttr.getValue();
 
-  auto *type = getOperand()->getType();
-  if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
-    if (index >= tensorType->getRank())
+  auto type = getOperand()->getType();
+  if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
+    if (index >= tensorType.getRank())
       return emitOpError("index is out of range");
-  } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
-    if (index >= memrefType->getRank())
+  } else if (auto memrefType = type.dyn_cast<MemRefType>()) {
+    if (index >= memrefType.getRank())
       return emitOpError("index is out of range");
 
-  } else if (isa<UnrankedTensorType>(type)) {
+  } else if (type.isa<UnrankedTensorType>()) {
     // ok, assumed to be in-range.
   } else {
     return emitOpError("requires an operand with tensor or memref type");
@@ -516,12 +516,12 @@
 Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
                               MLIRContext *context) const {
   // Constant fold dim when the size along the index referred to is a constant.
-  auto *opType = getOperand()->getType();
+  auto opType = getOperand()->getType();
   int indexSize = -1;
-  if (auto *tensorType = dyn_cast<RankedTensorType>(opType)) {
-    indexSize = tensorType->getShape()[getIndex()];
-  } else if (auto *memrefType = dyn_cast<MemRefType>(opType)) {
-    indexSize = memrefType->getShape()[getIndex()];
+  if (auto tensorType = opType.dyn_cast<RankedTensorType>()) {
+    indexSize = tensorType.getShape()[getIndex()];
+  } else if (auto memrefType = opType.dyn_cast<MemRefType>()) {
+    indexSize = memrefType.getShape()[getIndex()];
   }
 
   if (indexSize >= 0)
@@ -544,9 +544,9 @@
   p->printOperands(getTagIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getSrcMemRef()->getType();
-  *p << ", " << *getDstMemRef()->getType();
-  *p << ", " << *getTagMemRef()->getType();
+  *p << " : " << getSrcMemRef()->getType();
+  *p << ", " << getDstMemRef()->getType();
+  *p << ", " << getTagMemRef()->getType();
 }
 
 // Parse DmaStartOp.
@@ -566,8 +566,8 @@
   OpAsmParser::OperandType tagMemrefInfo;
   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
 
-  SmallVector<Type *, 3> types;
-  auto *indexType = parser->getBuilder().getIndexType();
+  SmallVector<Type, 3> types;
+  auto indexType = parser->getBuilder().getIndexType();
 
   // Parse and resolve the following list of operands:
   // *) source memref followed by its indices (in square brackets).
@@ -601,12 +601,12 @@
     return true;
 
   // Check that source/destination index list size matches associated rank.
-  if (srcIndexInfos.size() != cast<MemRefType>(types[0])->getRank() ||
-      dstIndexInfos.size() != cast<MemRefType>(types[1])->getRank())
+  if (srcIndexInfos.size() != types[0].cast<MemRefType>().getRank() ||
+      dstIndexInfos.size() != types[1].cast<MemRefType>().getRank())
     return parser->emitError(parser->getNameLoc(),
                              "memref rank not equal to indices count");
 
-  if (tagIndexInfos.size() != cast<MemRefType>(types[2])->getRank())
+  if (tagIndexInfos.size() != types[2].cast<MemRefType>().getRank())
     return parser->emitError(parser->getNameLoc(),
                              "tag memref rank not equal to indices count");
 
@@ -632,7 +632,7 @@
   p->printOperands(getTagIndices());
   *p << "], ";
   p->printOperand(getNumElements());
-  *p << " : " << *getTagMemRef()->getType();
+  *p << " : " << getTagMemRef()->getType();
 }
 
 // Parse DmaWaitOp.
@@ -642,8 +642,8 @@
 bool DmaWaitOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType tagMemrefInfo;
   SmallVector<OpAsmParser::OperandType, 2> tagIndexInfos;
-  Type *type;
-  auto *indexType = parser->getBuilder().getIndexType();
+  Type type;
+  auto indexType = parser->getBuilder().getIndexType();
   OpAsmParser::OperandType numElementsInfo;
 
   // Parse tag memref, its indices, and dma size.
@@ -657,7 +657,7 @@
       parser->resolveOperand(numElementsInfo, indexType, result->operands))
     return true;
 
-  if (tagIndexInfos.size() != cast<MemRefType>(type)->getRank())
+  if (tagIndexInfos.size() != type.cast<MemRefType>().getRank())
     return parser->emitError(parser->getNameLoc(),
                              "tag memref rank not equal to indices count");
 
@@ -678,10 +678,10 @@
 void ExtractElementOp::build(Builder *builder, OperationState *result,
                              SSAValue *aggregate,
                              ArrayRef<SSAValue *> indices) {
-  auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
+  auto aggregateType = aggregate->getType().cast<VectorOrTensorType>();
   result->addOperands(aggregate);
   result->addOperands(indices);
-  result->types.push_back(aggregateType->getElementType());
+  result->types.push_back(aggregateType.getElementType());
 }
 
 void ExtractElementOp::print(OpAsmPrinter *p) const {
@@ -689,13 +689,13 @@
   p->printOperands(getIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getAggregate()->getType();
+  *p << " : " << getAggregate()->getType();
 }
 
 bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType aggregateInfo;
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  VectorOrTensorType *type;
+  VectorOrTensorType type;
 
   auto affineIntTy = parser->getBuilder().getIndexType();
   return parser->parseOperand(aggregateInfo) ||
@@ -705,26 +705,26 @@
          parser->parseColonType(type) ||
          parser->resolveOperand(aggregateInfo, type, result->operands) ||
          parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
-         parser->addTypeToList(type->getElementType(), result->types);
+         parser->addTypeToList(type.getElementType(), result->types);
 }
 
 bool ExtractElementOp::verify() const {
   if (getNumOperands() == 0)
     return emitOpError("expected an aggregate to index into");
 
-  auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
+  auto aggregateType = getAggregate()->getType().dyn_cast<VectorOrTensorType>();
   if (!aggregateType)
     return emitOpError("first operand must be a vector or tensor");
 
-  if (getType() != aggregateType->getElementType())
+  if (getType() != aggregateType.getElementType())
     return emitOpError("result type must match element type of aggregate");
 
   for (auto *idx : getIndices())
-    if (!idx->getType()->isIndex())
+    if (!idx->getType().isIndex())
       return emitOpError("index to extract_element must have 'index' type");
 
   // Verify the # indices match if we have a ranked type.
-  auto aggregateRank = aggregateType->getRank();
+  auto aggregateRank = aggregateType.getRank();
   if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
     return emitOpError("incorrect number of indices for extract_element");
 
@@ -737,10 +737,10 @@
 
 void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
                    ArrayRef<SSAValue *> indices) {
-  auto *memrefType = cast<MemRefType>(memref->getType());
+  auto memrefType = memref->getType().cast<MemRefType>();
   result->addOperands(memref);
   result->addOperands(indices);
-  result->types.push_back(memrefType->getElementType());
+  result->types.push_back(memrefType.getElementType());
 }
 
 void LoadOp::print(OpAsmPrinter *p) const {
@@ -748,13 +748,13 @@
   p->printOperands(getIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getMemRefType();
+  *p << " : " << getMemRefType();
 }
 
 bool LoadOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType memrefInfo;
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  MemRefType *type;
+  MemRefType type;
 
   auto affineIntTy = parser->getBuilder().getIndexType();
   return parser->parseOperand(memrefInfo) ||
@@ -764,25 +764,25 @@
          parser->parseColonType(type) ||
          parser->resolveOperand(memrefInfo, type, result->operands) ||
          parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
-         parser->addTypeToList(type->getElementType(), result->types);
+         parser->addTypeToList(type.getElementType(), result->types);
 }
 
 bool LoadOp::verify() const {
   if (getNumOperands() == 0)
     return emitOpError("expected a memref to load from");
 
-  auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
+  auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
   if (!memRefType)
     return emitOpError("first operand must be a memref");
 
-  if (getType() != memRefType->getElementType())
+  if (getType() != memRefType.getElementType())
     return emitOpError("result type must match element type of memref");
 
-  if (memRefType->getRank() != getNumOperands() - 1)
+  if (memRefType.getRank() != getNumOperands() - 1)
     return emitOpError("incorrect number of indices for load");
 
   for (auto *idx : getIndices())
-    if (!idx->getType()->isIndex())
+    if (!idx->getType().isIndex())
       return emitOpError("index to load must have 'index' type");
 
   // TODO: Verify we have the right number of indices.
@@ -804,31 +804,31 @@
 //===----------------------------------------------------------------------===//
 
 bool MemRefCastOp::verify() const {
-  auto *opType = dyn_cast<MemRefType>(getOperand()->getType());
-  auto *resType = dyn_cast<MemRefType>(getType());
+  auto opType = getOperand()->getType().dyn_cast<MemRefType>();
+  auto resType = getType().dyn_cast<MemRefType>();
   if (!opType || !resType)
     return emitOpError("requires input and result types to be memrefs");
 
   if (opType == resType)
     return emitOpError("requires the input and result type to be different");
 
-  if (opType->getElementType() != resType->getElementType())
+  if (opType.getElementType() != resType.getElementType())
     return emitOpError(
         "requires input and result element types to be the same");
 
-  if (opType->getAffineMaps() != resType->getAffineMaps())
+  if (opType.getAffineMaps() != resType.getAffineMaps())
     return emitOpError("requires input and result mappings to be the same");
 
-  if (opType->getMemorySpace() != resType->getMemorySpace())
+  if (opType.getMemorySpace() != resType.getMemorySpace())
     return emitOpError(
         "requires input and result memory spaces to be the same");
 
   // They must have the same rank, and any specified dimensions must match.
-  if (opType->getRank() != resType->getRank())
+  if (opType.getRank() != resType.getRank())
     return emitOpError("requires input and result ranks to match");
 
-  for (unsigned i = 0, e = opType->getRank(); i != e; ++i) {
-    int opDim = opType->getDimSize(i), resultDim = resType->getDimSize(i);
+  for (unsigned i = 0, e = opType.getRank(); i != e; ++i) {
+    int opDim = opType.getDimSize(i), resultDim = resType.getDimSize(i);
     if (opDim != -1 && resultDim != -1 && opDim != resultDim)
       return emitOpError("requires static dimensions to match");
   }
@@ -923,14 +923,14 @@
   p->printOperands(getIndices());
   *p << ']';
   p->printOptionalAttrDict(getAttrs());
-  *p << " : " << *getMemRefType();
+  *p << " : " << getMemRefType();
 }
 
 bool StoreOp::parse(OpAsmParser *parser, OperationState *result) {
   OpAsmParser::OperandType storeValueInfo;
   OpAsmParser::OperandType memrefInfo;
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  MemRefType *memrefType;
+  MemRefType memrefType;
 
   auto affineIntTy = parser->getBuilder().getIndexType();
   return parser->parseOperand(storeValueInfo) || parser->parseComma() ||
@@ -939,7 +939,7 @@
                                   OpAsmParser::Delimiter::Square) ||
          parser->parseOptionalAttributeDict(result->attributes) ||
          parser->parseColonType(memrefType) ||
-         parser->resolveOperand(storeValueInfo, memrefType->getElementType(),
+         parser->resolveOperand(storeValueInfo, memrefType.getElementType(),
                                 result->operands) ||
          parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
          parser->resolveOperands(indexInfo, affineIntTy, result->operands);
@@ -950,19 +950,19 @@
     return emitOpError("expected a value to store and a memref");
 
   // Second operand is a memref type.
-  auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType());
+  auto memRefType = getMemRef()->getType().dyn_cast<MemRefType>();
   if (!memRefType)
     return emitOpError("second operand must be a memref");
 
   // First operand must have same type as memref element type.
-  if (getValueToStore()->getType() != memRefType->getElementType())
+  if (getValueToStore()->getType() != memRefType.getElementType())
     return emitOpError("first operand must have same type memref element type");
 
-  if (getNumOperands() != 2 + memRefType->getRank())
+  if (getNumOperands() != 2 + memRefType.getRank())
     return emitOpError("store index operand count not equal to memref rank");
 
   for (auto *idx : getIndices())
-    if (!idx->getType()->isIndex())
+    if (!idx->getType().isIndex())
       return emitOpError("index to load must have 'index' type");
 
   // TODO: Verify we have the right number of indices.
@@ -1046,31 +1046,31 @@
 //===----------------------------------------------------------------------===//
 
 bool TensorCastOp::verify() const {
-  auto *opType = dyn_cast<TensorType>(getOperand()->getType());
-  auto *resType = dyn_cast<TensorType>(getType());
+  auto opType = getOperand()->getType().dyn_cast<TensorType>();
+  auto resType = getType().dyn_cast<TensorType>();
   if (!opType || !resType)
     return emitOpError("requires input and result types to be tensors");
 
   if (opType == resType)
     return emitOpError("requires the input and result type to be different");
 
-  if (opType->getElementType() != resType->getElementType())
+  if (opType.getElementType() != resType.getElementType())
     return emitOpError(
         "requires input and result element types to be the same");
 
   // If the source or destination are unranked, then the cast is valid.
-  auto *opRType = dyn_cast<RankedTensorType>(opType);
-  auto *resRType = dyn_cast<RankedTensorType>(resType);
+  auto opRType = opType.dyn_cast<RankedTensorType>();
+  auto resRType = resType.dyn_cast<RankedTensorType>();
   if (!opRType || !resRType)
     return false;
 
   // If they are both ranked, they have to have the same rank, and any specified
   // dimensions must match.
-  if (opRType->getRank() != resRType->getRank())
+  if (opRType.getRank() != resRType.getRank())
     return emitOpError("requires input and result ranks to match");
 
-  for (unsigned i = 0, e = opRType->getRank(); i != e; ++i) {
-    int opDim = opRType->getDimSize(i), resultDim = resRType->getDimSize(i);
+  for (unsigned i = 0, e = opRType.getRank(); i != e; ++i) {
+    int opDim = opRType.getDimSize(i), resultDim = resRType.getDimSize(i);
     if (opDim != -1 && resultDim != -1 && opDim != resultDim)
       return emitOpError("requires static dimensions to match");
   }
diff --git a/lib/Transforms/ConstantFold.cpp b/lib/Transforms/ConstantFold.cpp
index 81994dd..15dd89b 100644
--- a/lib/Transforms/ConstantFold.cpp
+++ b/lib/Transforms/ConstantFold.cpp
@@ -31,7 +31,7 @@
   SmallVector<SSAValue *, 8> existingConstants;
   // Operation statements that were folded and that need to be erased.
   std::vector<OperationStmt *> opStmtsToErase;
-  using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
+  using ConstantFactoryType = std::function<SSAValue *(Attribute, Type)>;
 
   bool foldOperation(Operation *op,
                      SmallVectorImpl<SSAValue *> &existingConstants,
@@ -106,7 +106,7 @@
     for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
       auto &inst = *instIt++;
 
-      auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
+      auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
         builder.setInsertionPoint(&inst);
         return builder.create<ConstantOp>(inst.getLoc(), value, type);
       };
@@ -134,7 +134,7 @@
 
 // Override the walker's operation statement visit for constant folding.
 void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
-  auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
+  auto constantFactory = [&](Attribute value, Type type) -> SSAValue * {
     MLFuncBuilder builder(stmt);
     return builder.create<ConstantOp>(stmt->getLoc(), value, type);
   };
diff --git a/lib/Transforms/PipelineDataTransfer.cpp b/lib/Transforms/PipelineDataTransfer.cpp
index d96d65b..9042181 100644
--- a/lib/Transforms/PipelineDataTransfer.cpp
+++ b/lib/Transforms/PipelineDataTransfer.cpp
@@ -77,23 +77,23 @@
   bInner.setInsertionPoint(forStmt, forStmt->begin());
 
   // Doubles the shape with a leading dimension extent of 2.
-  auto doubleShape = [&](MemRefType *oldMemRefType) -> MemRefType * {
+  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
     // Add the leading dimension in the shape for the double buffer.
-    ArrayRef<int> shape = oldMemRefType->getShape();
+    ArrayRef<int> shape = oldMemRefType.getShape();
     SmallVector<int, 4> shapeSizes(shape.begin(), shape.end());
     shapeSizes.insert(shapeSizes.begin(), 2);
 
-    auto *newMemRefType =
-        bInner.getMemRefType(shapeSizes, oldMemRefType->getElementType(), {},
-                             oldMemRefType->getMemorySpace());
+    auto newMemRefType =
+        bInner.getMemRefType(shapeSizes, oldMemRefType.getElementType(), {},
+                             oldMemRefType.getMemorySpace());
     return newMemRefType;
   };
 
-  auto *newMemRefType = doubleShape(cast<MemRefType>(oldMemRef->getType()));
+  auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>());
 
   // Create and place the alloc at the top level.
   MLFuncBuilder topBuilder(forStmt->getFunction());
-  auto *newMemRef = cast<MLValue>(
+  auto newMemRef = cast<MLValue>(
       topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
           ->getResult());
 
diff --git a/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index cdf5b71..4ec8942 100644
--- a/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -78,7 +78,7 @@
   /// As part of canonicalization, we move constants to the top of the entry
   /// block of the current function and de-duplicate them.  This keeps track of
   /// constants we have done this for.
-  DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
+  DenseMap<std::pair<Attribute, Type>, Operation *> uniquedConstants;
 };
 }; // end anonymous namespace
 
diff --git a/lib/Transforms/Utils/Utils.cpp b/lib/Transforms/Utils/Utils.cpp
index edd8ce8..ad9d6dc 100644
--- a/lib/Transforms/Utils/Utils.cpp
+++ b/lib/Transforms/Utils/Utils.cpp
@@ -52,9 +52,9 @@
                                     MLValue *newMemRef,
                                     ArrayRef<MLValue *> extraIndices,
                                     AffineMap indexRemap) {
-  unsigned newMemRefRank = cast<MemRefType>(newMemRef->getType())->getRank();
+  unsigned newMemRefRank = newMemRef->getType().cast<MemRefType>().getRank();
   (void)newMemRefRank; // unused in opt mode
-  unsigned oldMemRefRank = cast<MemRefType>(oldMemRef->getType())->getRank();
+  unsigned oldMemRefRank = oldMemRef->getType().cast<MemRefType>().getRank();
   (void)newMemRefRank;
   if (indexRemap) {
     assert(indexRemap.getNumInputs() == oldMemRefRank);
@@ -64,8 +64,8 @@
   }
 
   // Assert same elemental type.
-  assert(cast<MemRefType>(oldMemRef->getType())->getElementType() ==
-         cast<MemRefType>(newMemRef->getType())->getElementType());
+  assert(oldMemRef->getType().cast<MemRefType>().getElementType() ==
+         newMemRef->getType().cast<MemRefType>().getElementType());
 
   // Check if memref was used in a non-deferencing context.
   for (const StmtOperand &use : oldMemRef->getUses()) {
@@ -139,7 +139,7 @@
                     opStmt->operand_end());
 
     // Result types don't change. Both memref's are of the same elemental type.
-    SmallVector<Type *, 8> resultTypes;
+    SmallVector<Type, 8> resultTypes;
     resultTypes.reserve(opStmt->getNumResults());
     for (const auto *result : opStmt->getResults())
       resultTypes.push_back(result->getType());
diff --git a/lib/Transforms/Vectorize.cpp b/lib/Transforms/Vectorize.cpp
index d7a1f53..511afa9 100644
--- a/lib/Transforms/Vectorize.cpp
+++ b/lib/Transforms/Vectorize.cpp
@@ -202,15 +202,15 @@
 /// sizes specified by vectorSize. The MemRef lives in the same memory space as
 /// tmpl. The MemRef should be promoted to a closer memory address space in a
 /// later pass.
-static MemRefType *getVectorizedMemRefType(MemRefType *tmpl,
-                                           ArrayRef<int> vectorSizes) {
-  auto *elementType = tmpl->getElementType();
-  assert(!dyn_cast<VectorType>(elementType) &&
+static MemRefType getVectorizedMemRefType(MemRefType tmpl,
+                                          ArrayRef<int> vectorSizes) {
+  auto elementType = tmpl.getElementType();
+  assert(!elementType.dyn_cast<VectorType>() &&
          "Can't vectorize an already vector type");
-  assert(tmpl->getAffineMaps().empty() &&
+  assert(tmpl.getAffineMaps().empty() &&
          "Unsupported non-implicit identity map");
   return MemRefType::get({1}, VectorType::get(vectorSizes, elementType), {},
-                         tmpl->getMemorySpace());
+                         tmpl.getMemorySpace());
 }
 
 /// Creates an unaligned load with the following semantics:
@@ -258,7 +258,7 @@
   operands.insert(operands.end(), dstMemRef);
   operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
   using functional::map;
-  std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
+  std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
     return v->getType();
   };
   auto types = map(getType, operands);
@@ -310,7 +310,7 @@
   operands.insert(operands.end(), dstMemRef);
   operands.insert(operands.end(), dstIndices.begin(), dstIndices.end());
   using functional::map;
-  std::function<Type *(SSAValue *)> getType = [](SSAValue *v) -> Type * {
+  std::function<Type(SSAValue *)> getType = [](SSAValue *v) -> Type {
     return v->getType();
   };
   auto types = map(getType, operands);
@@ -348,8 +348,9 @@
 template <typename LoadOrStoreOpPointer>
 static MLValue *materializeVector(MLValue *iv, LoadOrStoreOpPointer memoryOp,
                                   ArrayRef<int> vectorSize) {
-  auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
-  auto *vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
+  auto memRefType =
+      memoryOp->getMemRef()->getType().template cast<MemRefType>();
+  auto vectorMemRefType = getVectorizedMemRefType(memRefType, vectorSize);
 
   // Materialize a MemRef with 1 vector.
   auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());