[MLIR] Extend vectorization to 2+-D patterns

This CL adds support for vectorization using more interesting 2-D and 3-D
patterns. Note in particular that we match some fairly complex imperfectly
nested 2-D patterns with a minimal change to the implementation: we just add
a bit of recursion to traverse the matched patterns and vectorize the loops.
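A minimal sketch of that recursion, assuming a hypothetical `PatternMatch`
result type and a `vectorizeLoop` helper (the real matcher and rewriting live
in the vectorizer pass; the names and traversal order here are illustrative
only):
```
#include <vector>

class ForStmt;  // loop statement captured by the matcher

// Hypothetical stand-in for a matcher result: the loop it captured plus the
// matches found inside that loop (this is what makes the pattern imperfectly
// nested).
struct PatternMatch {
  ForStmt *loop;
  std::vector<PatternMatch> nestedMatches;
};

// Hypothetical helper that rewrites one loop: adjusts bounds/step by the
// vector size and turns its loads/stores into vector transfers.
bool vectorizeLoop(ForStmt &loop, unsigned vectorSize);

// The recursion referred to above: walk the nested matches, then vectorize
// the loop captured at this level (the order shown is one possible choice).
static bool vectorizeMatch(const PatternMatch &match, unsigned vectorSize) {
  for (const PatternMatch &inner : match.nestedMatches)
    if (!vectorizeMatch(inner, vectorSize))
      return false;
  return vectorizeLoop(*match.loop, vectorSize);
}
```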

For instance, vectorizing the following loop by 128:
```
for %i3 = 0 to %0 {
  %7 = affine_apply (d0) -> (d0)(%i3)
  %8 = load %arg0[%c0_0, %7] : memref<?x?xf32>
}
```

Currently generates:
```
#map0 = ()[s0] -> (s0 + 127)
#map1 = (d0) -> (d0)
for %i3 = 0 to #map0()[%0] step 128 {
  %9 = affine_apply #map1(%i3)
  %10 = alloc() : memref<1xvector<128xf32>>
  %11 = "n_d_unaligned_load"(%arg0, %c0_0, %9, %10, %c0) :
    (memref<?x?xf32>, index, index, memref<1xvector<128xf32>>, index) ->
    (memref<?x?xf32>, index, index, memref<1xvector<128xf32>>, index)
  %12 = load %10[%c0] : memref<1xvector<128xf32>>
}
```

The above is subject to evolution.
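On the analysis side, the LoopAnalysis change below splits the vectorizability
check into a generic isVectorizableLoopWithCond driven by a per-statement
predicate, exposed through two entry points. A minimal usage sketch, assuming
a candidate `loop` produced by the pattern matcher and the corresponding
LoopAnalysis header (the surrounding driver code is hypothetical):
```
#include "mlir/Analysis/LoopAnalysis.h"  // assumed header for the entry points

using namespace mlir;

// Sketch only: checks that every load/store in the loop body is contiguous
// along the requested memref dimension; the plain isVectorizableLoop overload
// currently accepts any load/store (see the TODO in the diff below).
static bool canVectorizeAlongFastestDim(const ForStmt &loop) {
  return isVectorizableLoopAlongFastestVaryingMemRefDim(
      loop, /*fastestVaryingDim=*/0);
}
```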

PiperOrigin-RevId: 219629745
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 1904a63..ce0bc6c 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -119,12 +119,12 @@
 }
 
 bool mlir::isAccessInvariant(const MLValue &input, MemRefType memRefType,
-                             ArrayRef<MLValue *> indices, unsigned dim) {
+                             ArrayRef<const MLValue *> indices, unsigned dim) {
   assert(indices.size() == memRefType.getRank());
   assert(dim < indices.size());
   auto layoutMap = memRefType.getAffineMaps();
   assert(memRefType.getAffineMaps().size() <= 1);
-  // TODO(ntv): remove dependency on Builder once we support non-identity
+  // TODO(ntv): remove dependence on Builder once we support non-identity
   // layout map.
   Builder b(memRefType.getContext());
   assert(layoutMap.empty() ||
@@ -132,7 +132,8 @@
   (void)layoutMap;
 
   SmallVector<OperationStmt *, 4> affineApplyOps;
-  getReachableAffineApplyOps({indices[dim]}, affineApplyOps);
+  getReachableAffineApplyOps({const_cast<MLValue *>(indices[dim])},
+                             affineApplyOps);
 
   if (affineApplyOps.empty()) {
     // Pointer equality test because of MLValue pointer semantics.
@@ -168,7 +169,7 @@
                                LoadOrStoreOpPointer memoryOp,
                                unsigned fastestVaryingDim) {
   using namespace functional;
-  auto indices = map([](SSAValue *val) { return dyn_cast<MLValue>(val); },
+  auto indices = map([](const SSAValue *val) { return dyn_cast<MLValue>(val); },
                      memoryOp->getIndices());
   auto memRefType = memoryOp->getMemRefType();
   for (unsigned d = 0, numIndices = indices.size(); d < numIndices; ++d) {
@@ -188,7 +189,11 @@
   return memRefType.getElementType().template isa<VectorType>();
 }
 
-bool mlir::isVectorizableLoop(const ForStmt &loop, unsigned fastestVaryingDim) {
+using VectorizableStmtFun =
+    std::function<bool(const ForStmt &, const OperationStmt &)>;
+
+static bool isVectorizableLoopWithCond(const ForStmt &loop,
+                                       VectorizableStmtFun isVectorizableStmt) {
   if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
     return false;
   }
@@ -214,15 +219,32 @@
     if (vector) {
       return false;
     }
-    bool contiguous = load ? isContiguousAccess(loop, load, fastestVaryingDim)
-                           : isContiguousAccess(loop, store, fastestVaryingDim);
-    if (!contiguous) {
+    if (!isVectorizableStmt(loop, *op)) {
       return false;
     }
   }
   return true;
 }
 
+bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
+    const ForStmt &loop, unsigned fastestVaryingDim) {
+  VectorizableStmtFun fun(
+      [fastestVaryingDim](const ForStmt &loop, const OperationStmt &op) {
+        auto load = op.dyn_cast<LoadOp>();
+        auto store = op.dyn_cast<StoreOp>();
+        return load ? isContiguousAccess(loop, load, fastestVaryingDim)
+                    : isContiguousAccess(loop, store, fastestVaryingDim);
+      });
+  return isVectorizableLoopWithCond(loop, fun);
+}
+
+bool mlir::isVectorizableLoop(const ForStmt &loop) {
+  VectorizableStmtFun fun(
+      // TODO: implement me
+      [](const ForStmt &loop, const OperationStmt &op) { return true; });
+  return isVectorizableLoopWithCond(loop, fun);
+}
+
 /// Checks whether SSA dominance would be violated if a for stmt's body
 /// statements are shifted by the specified shifts. This method checks if a
 /// 'def' and all its uses have the same shift factor.