Rolling forward "[tf.SparseTensor] Optimize `SparseTensor::IndicesValid()` for "small" 2D tensors."

The previous version had an error in its pointer arithmetic that caused it to skip the first row of the index array and to read one row past the end of that array. The bug was caught by an MSAN continuous test.
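
For illustration only (a hypothetical sketch, not the reverted code): the fast
path walks the rows of an N x 2 int64 index matrix through an int32 pointer,
so the correct stride is four int32 elements per row; starting each row's
pointer one row too far both skips row 0 and reads one row past the end of the
buffer. The names below (base, kRows, etc.) are invented for the sketch, and
the buggy variant is only described in a comment:

    // Hypothetical sketch; kInt32ElementsPerRow mirrors the new code.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      constexpr std::size_t kRows = 3;
      const std::vector<int64_t> indices = {0, 1,   // row 0
                                            2, 3,   // row 1
                                            4, 5};  // row 2
      const int32_t* base = reinterpret_cast<const int32_t*>(indices.data());
      constexpr std::size_t kInt32ElementsPerRow = 4;  // 2 int64s == 4 int32s

      // Correct: row n starts at base + n * kInt32ElementsPerRow.
      for (std::size_t n = 0; n < kRows; ++n) {
        const int32_t* row_ptr = base + n * kInt32ElementsPerRow;
        std::printf("row %zu starts at int32 offset %td\n", n, row_ptr - base);
      }

      // An offset such as base + (n + 1) * kInt32ElementsPerRow would never
      // visit row 0 and would dereference memory one row past the end of
      // `indices` on the last iteration, the kind of error MSAN reports.
      return 0;
    }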

PiperOrigin-RevId: 290114585
Change-Id: If5fd2b560f97ce625bb175062f32bc1db053c99f
diff --git a/tensorflow/core/util/sparse/sparse_tensor.cc b/tensorflow/core/util/sparse/sparse_tensor.cc
index 1eb9cb9..e58bd95 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.cc
+++ b/tensorflow/core/util/sparse/sparse_tensor.cc
@@ -108,6 +108,84 @@
   DCHECK_EQ(shape.size(), dims_) << "Shape rank must be SparseTensor rank.";
 }
 
+// Optimized version of `IndicesValid()` with the following requirements:
+// * The sparse tensor is two-dimensional.
+// * The tensor's indices are in the "standard" (lexicographic) order.
+// * All of the tensor's indices fit within the range of a signed int32.
+//
+// Returns true if the indices are valid, otherwise false.
+// NOTE(mrry): If this method returns false, call IndicesValidHelper<true>()
+// to obtain a meaningful error message.
+bool SparseTensor::IndicesValid32BitFastPath() const {
+  const auto ix_t = ix_.matrix<int64>();
+  const int64* const shape_ptr = shape_.data();
+
+  DCHECK_EQ(shape_.size(), 2);
+  DCHECK_EQ(order_[0], 0);
+  DCHECK_EQ(order_[1], 1);
+  DCHECK_LE(shape_ptr[0], std::numeric_limits<int32>::max());
+  DCHECK_LE(shape_ptr[1], std::numeric_limits<int32>::max());
+
+  const int32 max_rows = static_cast<int32>(shape_ptr[0]);
+  const int32 max_cols = static_cast<int32>(shape_ptr[1]);
+
+  // We maintain separate bools for each validation predicate to enable
+  // vectorization across loop iterations.
+  bool row_zeros_valid = true;
+  bool row_in_range_valid = true;
+  bool col_zeros_valid = true;
+  bool col_in_range_valid = true;
+  bool order_valid = true;
+
+  int64 prev_index = -1;
+
+  // Base pointer to the indices matrix, used to compute each row's pointer.
+  // Each row has two int64 elements, but we use an int32 pointer to access
+  // the low and high 32 bits of each element separately. This means that our
+  // stride per row is 4 elements.
+  const int32* const index_base_ptr =
+      reinterpret_cast<const int32*>(ix_t.data());
+  const size_t kInt32ElementsPerRow = 4;
+
+  for (std::size_t n = 0; n < ix_t.dimension(0); ++n) {
+    const int32* const index_ptr = index_base_ptr + n * kInt32ElementsPerRow;
+
+    // Unpack the values on the current row of the indices matrix.
+#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+    const int32 row_zeros = index_ptr[0];
+    const int32 row_32 = index_ptr[1];
+    const int32 col_zeros = index_ptr[2];
+    const int32 col_32 = index_ptr[3];
+#else
+    const int32 row_32 = index_ptr[0];
+    const int32 row_zeros = index_ptr[1];
+    const int32 col_32 = index_ptr[2];
+    const int32 col_zeros = index_ptr[3];
+#endif
+
+    // Validate that the high 32 bits of the row and column indices are zero.
+    row_zeros_valid = row_zeros_valid & (row_zeros == 0);
+    col_zeros_valid = col_zeros_valid & (col_zeros == 0);
+
+    // Validate that the low 32 bits of the row and column indices are within
+    // range of the shape.
+    row_in_range_valid =
+        row_in_range_valid & (row_32 >= 0) & (row_32 < max_rows);
+    col_in_range_valid =
+        col_in_range_valid & (col_32 >= 0) & (col_32 < max_cols);
+
+    // Interpret the row and column as a concatenated 64-bit integer, and
+    // validate that the concatenated indices are in strictly increasing order.
+    const int64 concatenated_index =
+        (static_cast<int64>(row_32) << 32) + col_32;
+    order_valid = order_valid & (concatenated_index > prev_index);
+    prev_index = concatenated_index;
+  }
+
+  return row_zeros_valid & row_in_range_valid & col_zeros_valid &
+         col_in_range_valid & order_valid;
+}
+
 template <bool standard_order>
 Status SparseTensor::IndicesValidHelper() const {
   const auto ix_t = ix_.matrix<int64>();
@@ -174,6 +252,12 @@
   }
 
   if (standard_order) {
+    if (shape_.size() == 2 && shape_[0] <= std::numeric_limits<int32>::max() &&
+        shape_[1] <= std::numeric_limits<int32>::max()) {
+      if (IndicesValid32BitFastPath()) {
+        return Status::OK();
+      }
+    }
     return IndicesValidHelper<true>();
   } else {
     return IndicesValidHelper<false>();
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 1de1374..03ae4fe 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -201,6 +201,8 @@
     return vec;
   }
 
+  bool IndicesValid32BitFastPath() const;
+
   template <bool standard_order>
   Status IndicesValidHelper() const;
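
Aside, not part of the patch: the fast path in sparse_tensor.cc leans on two
low-level tricks, (a) reading each int64 index as two int32 words whose order
depends on the host byte order, and (b) packing the (row, col) pair into a
single int64 so that one comparison per row checks lexicographic ordering. A
minimal, self-contained sketch of both, using <cstdint> types in place of
TensorFlow's int32/int64 aliases:

    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      // (a) An int64 whose value fits in 32 bits is two int32 words in
      // memory: the value word and a zero word. Which comes first depends on
      // the host byte order, hence the __BYTE_ORDER__ check in the patch.
      const int64_t row = 7;
      int32_t words[2];
      std::memcpy(words, &row, sizeof(row));
    #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
      const int32_t row_zeros = words[0];
      const int32_t row_32 = words[1];
    #else
      const int32_t row_32 = words[0];
      const int32_t row_zeros = words[1];
    #endif
      assert(row_32 == 7 && row_zeros == 0);

      // (b) For non-negative 32-bit values, packing (row, col) into one
      // int64 preserves lexicographic order, so strictly increasing packed
      // values imply strictly increasing index rows.
      const auto pack = [](int32_t r, int32_t c) {
        return (static_cast<int64_t>(r) << 32) + c;
      };
      assert(pack(0, 5) < pack(1, 0));  // the row dominates
      assert(pack(2, 3) < pack(2, 4));  // the column breaks ties
      std::printf("ok\n");
      return 0;
    }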