Introduce IR support for MLIRContext, primitive types, function types, and
vector types.

tensors and memref types are still TODO, and would be a good starter project
for someone.

PiperOrigin-RevId: 201782748
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 6407e9c..54c9166 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -16,14 +16,37 @@
 // =============================================================================
 
 #include "mlir/IR/Function.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/STLExtras.h"
 #include "llvm/Support/raw_ostream.h"
 using namespace mlir;
 
-Function::Function(StringRef name) : name(name.str()) {
+Function::Function(StringRef name, FunctionType *type)
+  : name(name.str()), type(type) {
 }
 
 void Function::print(raw_ostream &os) {
-  os << "extfunc @" << name << "()\n";
+  os << "extfunc @" << name << '(';
+  interleave(type->getInputs(),
+             [&](Type *eltType) { os << *eltType; },
+             [&]() { os << ", "; });
+  os << ')';
+
+  switch (type->getResults().size()) {
+  case 0: break;
+  case 1:
+    os << " -> " << *type->getResults()[0];
+    break;
+  default:
+    os << " -> (";
+    interleave(type->getResults(),
+               [&](Type *eltType) { os << *eltType; },
+               [&]() { os << ", "; });
+    os << ')';
+    break;
+  }
+
+  os << "\n";
 }
 
 void Function::dump() {
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
new file mode 100644
index 0000000..a2befc3
--- /dev/null
+++ b/lib/IR/MLIRContext.cpp
@@ -0,0 +1,200 @@
+//===- MLIRContext.cpp - MLIR Type Classes --------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Allocator.h"
+using namespace mlir;
+using namespace llvm;
+
+namespace {
+struct FunctionTypeKeyInfo : DenseMapInfo<FunctionType*> {
+  // Functions are uniqued based on their inputs and results.
+  using KeyTy = std::pair<ArrayRef<Type*>, ArrayRef<Type*>>;
+  using DenseMapInfo<FunctionType*>::getHashValue;
+  using DenseMapInfo<FunctionType*>::isEqual;
+
+  static unsigned getHashValue(KeyTy key) {
+    return hash_combine(hash_combine_range(key.first.begin(), key.first.end()),
+                        hash_combine_range(key.second.begin(),
+                                           key.second.end()));
+  }
+
+  static bool isEqual(const KeyTy &lhs, const FunctionType *rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == KeyTy(rhs->getInputs(), rhs->getResults());
+  }
+};
+struct VectorTypeKeyInfo : DenseMapInfo<VectorType*> {
+  // Vectors are uniqued based on their element type and shape.
+  using KeyTy = std::pair<Type*, ArrayRef<unsigned>>;
+  using DenseMapInfo<VectorType*>::getHashValue;
+  using DenseMapInfo<VectorType*>::isEqual;
+
+  static unsigned getHashValue(KeyTy key) {
+    return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
+                        hash_combine_range(key.second.begin(),
+                                           key.second.end()));
+  }
+
+  static bool isEqual(const KeyTy &lhs, const VectorType *rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+  }
+};
+} // end anonymous namespace.
+
+
+namespace mlir {
+/// This is the implementation of the MLIRContext class, using the pImpl idiom.
+/// This class is completely private to this file, so everything is public.
+class MLIRContextImpl {
+public:
+  /// We put immortal objects into this allocator.
+  llvm::BumpPtrAllocator allocator;
+
+  // Primitive type uniquing.
+  PrimitiveType *primitives[int(TypeKind::LAST_PRIMITIVE_TYPE)+1] = { nullptr };
+
+  /// Function type uniquing.
+  using FunctionTypeSet = DenseSet<FunctionType*, FunctionTypeKeyInfo>;
+  FunctionTypeSet functions;
+
+  /// Vector type uniquing.
+  using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
+  VectorTypeSet vectors;
+
+
+public:
+  /// Copy the specified array of elements into memory managed by our bump
+  /// pointer allocator.  This assumes the elements are all PODs.
+  template<typename T>
+  ArrayRef<T> copyInto(ArrayRef<T> elements) {
+    auto result = allocator.Allocate<T>(elements.size());
+    std::uninitialized_copy(elements.begin(), elements.end(), result);
+    return ArrayRef<T>(result, elements.size());
+  }
+};
+} // end namespace mlir
+
+MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+}
+
+MLIRContext::~MLIRContext() {
+}
+
+
+PrimitiveType::PrimitiveType(TypeKind kind, MLIRContext *context)
+  : Type(kind, context) {
+
+}
+
+PrimitiveType *PrimitiveType::get(TypeKind kind, MLIRContext *context) {
+  assert(kind <= TypeKind::LAST_PRIMITIVE_TYPE && "Not a primitive type kind");
+  auto &impl = context->getImpl();
+
+  // We normally have these types.
+  if (impl.primitives[(int)kind])
+    return impl.primitives[(int)kind];
+
+  // On the first use, we allocate them into the bump pointer.
+  auto *ptr = impl.allocator.Allocate<PrimitiveType>();
+
+  // Initialize the memory using placement new.
+  new(ptr) PrimitiveType(kind, context);
+
+  // Cache and return it.
+  return impl.primitives[(int)kind] = ptr;
+}
+
+FunctionType::FunctionType(Type *const *inputsAndResults, unsigned numInputs,
+                           unsigned numResults, MLIRContext *context)
+  : Type(TypeKind::Function, context, numInputs),
+    numResults(numResults), inputsAndResults(inputsAndResults) {
+}
+
+FunctionType *FunctionType::get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
+                                MLIRContext *context) {
+  auto &impl = context->getImpl();
+
+  // Look to see if we already have this function type.
+  FunctionTypeKeyInfo::KeyTy key(inputs, results);
+  auto existing = impl.functions.insert_as(nullptr, key);
+
+  // If we already have it, return that value.
+  if (!existing.second)
+    return *existing.first;
+
+  // On the first use, we allocate them into the bump pointer.
+  auto *result = impl.allocator.Allocate<FunctionType>();
+
+  // Copy the inputs and results into the bump pointer.
+  SmallVector<Type*, 16> types;
+  types.reserve(inputs.size()+results.size());
+  types.append(inputs.begin(), inputs.end());
+  types.append(results.begin(), results.end());
+  auto typesList = impl.copyInto(ArrayRef<Type*>(types));
+
+  // Initialize the memory using placement new.
+  new (result) FunctionType(typesList.data(), inputs.size(), results.size(),
+                            context);
+
+  // Cache and return it.
+  return *existing.first = result;
+}
+
+
+
+VectorType::VectorType(ArrayRef<unsigned> shape, PrimitiveType *elementType,
+                       MLIRContext *context)
+  : Type(TypeKind::Vector, context, shape.size()),
+    shapeElements(shape.data()), elementType(elementType) {
+}
+
+
+VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
+  assert(!shape.empty() && "vector types must have at least one dimension");
+  assert(isa<PrimitiveType>(elementType) &&
+         "vectors elements must be primitives");
+
+  auto *context = elementType->getContext();
+  auto &impl = context->getImpl();
+
+  // Look to see if we already have this vector type.
+  VectorTypeKeyInfo::KeyTy key(elementType, shape);
+  auto existing = impl.vectors.insert_as(nullptr, key);
+
+  // If we already have it, return that value.
+  if (!existing.second)
+    return *existing.first;
+
+  // On the first use, we allocate them into the bump pointer.
+  auto *result = impl.allocator.Allocate<VectorType>();
+
+  // Copy the shape into the bump pointer.
+  shape = impl.copyInto(shape);
+
+  // Initialize the memory using placement new.
+  new (result) VectorType(shape, cast<PrimitiveType>(elementType), context);
+
+  // Cache and return it.
+  return *existing.first = result;
+}
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
new file mode 100644
index 0000000..5485995
--- /dev/null
+++ b/lib/IR/Types.cpp
@@ -0,0 +1,68 @@
+//===- Types.cpp - MLIR Type Classes --------------------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Types.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/raw_ostream.h"
+#include "mlir/Support/STLExtras.h"
+using namespace mlir;
+
+void Type::print(raw_ostream &os) const {
+  switch (getKind()) {
+  case TypeKind::I1:   os << "i1"; return;
+  case TypeKind::I8:   os << "i8"; return;
+  case TypeKind::I16:  os << "i16"; return;
+  case TypeKind::I32:  os << "i32"; return;
+  case TypeKind::I64:  os << "i64"; return;
+  case TypeKind::Int:  os << "int"; return;
+  case TypeKind::BF16: os << "bf16"; return;
+  case TypeKind::F16:  os << "f16"; return;
+  case TypeKind::F32:  os << "f32"; return;
+  case TypeKind::F64:  os << "f64"; return;
+  case TypeKind::Function: {
+    auto *func = cast<FunctionType>(this);
+    os << '(';
+    interleave(func->getInputs(),
+               [&](Type *type) { os << *type; },
+               [&]() { os << ", "; });
+    os << ") -> ";
+    auto results = func->getResults();
+    if (results.size() == 1)
+      os << *results[0];
+    else {
+      os << '(';
+      interleave(results,
+                 [&](Type *type) { os << *type; },
+                 [&]() { os << ", "; });
+      os << ")";
+    }
+    return;
+  }
+  case TypeKind::Vector: {
+    auto *v = cast<VectorType>(this);
+    os << "vector<";
+    for (auto dim : v->getShape())
+      os << dim << 'x';
+    os << *v->getElementType() << '>';
+    return;
+  }
+  }
+}
+
+void Type::dump() const {
+  print(llvm::errs());
+}