Merge pull request #28205 from anuj-rawat:maskedload_gemmrhspack

PiperOrigin-RevId: 246193951
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h b/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h
index 4559ac3..ad20027 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions-inl.h
@@ -21,6 +21,37 @@
 
 namespace internal {
 
+// OptimizedPacketLoadOverTwoColumns<TensorEvaluatorType, PacketType, IndexType>
+// provides a `value` that is true if TensorEvaluatorType has `PacketType
+// partialPacket(IndexType, unpacket_traits<PacketType>::mask_t) const` and if
+// PacketType supports masked loads. In that case we can split the packet over
+// two columns, use a partial load for each part, and combine the parts to get
+// the required packet. This class is used to pick the correct implementation
+// of the loadPacketStandard function.
+template <typename TensorEvaluatorType, typename PacketType, typename IndexType>
+class OptimizedPacketLoadOverTwoColumns {
+ public:
+  template <typename TensorEvaluatorT, typename PacketT, typename IndexT>
+  static auto functionExistsSfinae(
+      typename std::enable_if<
+          unpacket_traits<PacketT>::masked_load_available &&
+          std::is_same<
+              PacketT,
+              decltype(std::declval<const TensorEvaluatorT>().partialPacket(
+                  std::declval<IndexT>(),
+                  std::declval<typename unpacket_traits<PacketT>::mask_t>()))>::
+              value>::type*) -> std::true_type;
+
+  template <typename TensorEvaluatorT, typename PacketT, typename IndexT>
+  static auto functionExistsSfinae(...) -> std::false_type;
+
+  typedef decltype(
+      functionExistsSfinae<TensorEvaluatorType, PacketType, IndexType>(
+          nullptr)) status;
+
+  static const bool value = status::value;
+};
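+
+// For example, the trait is used below (via std::enable_if / SFINAE) to select
+// between the two loadPacketStandard() overloads; a sketch of the optimized
+// overload's declaration:
+//
+//   template <typename PacketT, typename TensorEvaluatorT>
+//   typename std::enable_if<OptimizedPacketLoadOverTwoColumns<
+//                               TensorEvaluatorT, PacketT, Index>::value,
+//                           PacketT>::type
+//   loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
+//                      Index otherIndex) const;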
+
 // WARNING: Most of the code here implicitly assumes that the matrix is in
 // ColMajor layout. This is guaranteed by the tensor contraction (see
 // TensorContraction.h).
@@ -347,13 +378,146 @@
     if (nonStandardPatches()) {
       return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
     }
-    return loadPacketStandard(patchId, rowIndex, colIndex, otherIndex);
+    typedef decltype(m_impl) TensorEvaluatorT;
+    return loadPacketStandard<Packet, TensorEvaluatorT>(patchId, rowIndex,
+                                                        colIndex, otherIndex);
   }
 
+  // Helper function to load a 'partial' packet - this is the single-column
+  // part of a packet that is split across two columns. In the 'partial'
+  // packet, the elements corresponding to the column (specified through
+  // colOffset) are loaded and the remaining elements are zero-filled. This
+  // function is called from loadPacketStandardFromTwoColumns(). This code path
+  // is exercised only when the packet type supports masked loads and the
+  // partial packet load is available in the TensorEvaluator.
   EIGEN_DEVICE_FUNC
-  EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index rowIndex,
-                                                Index colIndex,
-                                                Index otherIndex) const {
+  EIGEN_ALWAYS_INLINE Packet loadPartialPacketStandard(
+      Index rowIndex, Index colIndex, Index otherIndex, Index patchId,
+      const Index span[], const Index patchOffsets[], Index colOffset) const {
+    const Index inputCol = colIndex + colOffset;
+    const Index rowOffsets[2] = {patchOffsets[0] - colOffset * m_colStride,
+                                 patchOffsets[1] - colOffset * m_colStride};
+    const Index inputRows[2] = {rowIndex + rowOffsets[0],
+                                rowIndex + rowOffsets[1]};
+
+    if (inputRows[0] >= m_inputRows || inputRows[1] < 0 ||
+        inputCol >= m_inputCols || inputCol < 0) {
+      // Partial packet is all zeros
+      return internal::pset1<Packet>(Scalar(0));
+    } else if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
+      // Starting from inputIndex - span[0], we need to load the elements at
+      // indices span[0] up to (and including) span[1].
+      const Index depth = patchId - patchOffsets[0] * patchDepth();
+      const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
+                               inputCol * m_colInputStride + otherIndex;
+      // Determine the mask corresponding to the partial packet: where a mask
+      // bit is 1 the element is loaded, otherwise 0 is written to that lane.
+      const Index packetSize = internal::unpacket_traits<Packet>::size;
+      typename internal::unpacket_traits<Packet>::mask_t mask_t_max =
+          +std::numeric_limits<
+              typename internal::unpacket_traits<Packet>::mask_t>::max();
+      typename internal::unpacket_traits<Packet>::mask_t umask =
+          (mask_t_max >> (packetSize - span[1] - 1)) ^
+          (mask_t_max >> (packetSize - span[0]));
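+      // For illustration (assuming an 8-lane packet with an 8-bit mask_t):
+      // with span = {2, 5},
+      //   mask_t_max >> (packetSize - span[1] - 1) = 0b00111111
+      //   mask_t_max >> (packetSize - span[0])     = 0b00000011
+      //   umask (their XOR)                        = 0b00111100
+      // i.e. only elements 2..5 of the packet are loaded by partialPacket().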
+      return m_impl.partialPacket(inputIndex - span[0], umask);
+    } else {
+      // Using the slow path for this partial packet.
+      // We need to load the elements at indices span[0] up to (and including)
+      // span[1]. We split this load into 3 parts:
+      //   0 : span[0]-1            - zeros are loaded for these indices
+      //   span[0] : span[1]        - elements are loaded for these indices
+      //   span[1]+1 : packetSize-1 - zeros are loaded for these indices
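+      // For example (illustrative), with packetSize = 8 and span = {2, 5} this
+      // produces:
+      //   values = {0, 0, c(patchId), c(patchId+1), c(patchId+2), c(patchId+3),
+      //             0, 0}
+      // where c(p) = loadCoeff(p, rowIndex, colIndex, otherIndex).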
+      const Index packetSize = internal::unpacket_traits<Packet>::size;
+      EIGEN_ALIGN_MAX
+      typename internal::remove_const<Scalar>::type values[packetSize];
+      for (int i = 0; i < span[0]; ++i) values[i] = Scalar(0);
+      for (int i = span[0]; i < span[1] + 1; ++i)
+        values[i] =
+            loadCoeff(patchId - span[0] + i, rowIndex, colIndex, otherIndex);
+      for (int i = span[1] + 1; i < packetSize; ++i) values[i] = Scalar(0);
+      return internal::pload<Packet>(values);
+    }
+  }
+
+  // Helper function to load a packet that is split across two columns. This
+  // function is called from loadPacketStandard(), when needed, provided the
+  // packet type supports masked loads and the TensorEvaluator provides the
+  // partial packet load.
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromTwoColumns(
+      Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
+      const Index patchOffsets[], const Index colOffsets[]) const {
+    eigen_assert(colOffsets[1] == colOffsets[0] + 1);
+    const Index packetSize = internal::unpacket_traits<Packet>::size;
+
+    // The packet to load will be split into 2 parts, each spanning a single
+    // column. First determine where to split.
+    const Index patchIdSplit =
+        ((colOffsets[1] * m_colStride) * m_rowInputStride) - 1;
+    const Index patchOffsetSplit = patchIdSplit / m_fastDimZero;
+
+    // patchIds[i]:          patchId corresponding to partial packet i
+    // spans[i]:             Start and end indices corresponding to the elements
+    //                       to be loaded for partial packet i
+    // patchOffsets2Cols[i]: patchOffsets corresponding to partial packet i
+    const Index patchIds[2] = {patchId, patchIdSplit + 1};
+    const Index spans[2][2] = {{0, patchIdSplit - patchId},
+                               {patchIdSplit - patchId + 1, packetSize - 1}};
+    const Index patchOffsets2Cols[2][2] = {
+        {patchOffsets[0], patchOffsetSplit},
+        {patchOffsetSplit + 1, patchOffsets[1]}};
+
+    // Load the two partial packets and bitwise-OR them to form the required
+    // packet; this is valid because each partial packet carries zeros in the
+    // lanes that belong to the other column.
+    return internal::por<Packet>(
+        loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[0],
+                                  spans[0], patchOffsets2Cols[0],
+                                  colOffsets[0]),
+        loadPartialPacketStandard(rowIndex, colIndex, otherIndex, patchIds[1],
+                                  spans[1], patchOffsets2Cols[1],
+                                  colOffsets[1]));
+  }
+
+  // Helper function to load a packet whose elements all lie in a single
+  // column. This function is called from loadPacketStandard() when needed.
+  EIGEN_DEVICE_FUNC
+  EIGEN_ALWAYS_INLINE Packet loadPacketStandardFromSingleColumn(
+      Index patchId, Index rowIndex, Index colIndex, Index otherIndex,
+      const Index patchOffsets[], const Index colOffsets[],
+      const Index inputCols[]) const {
+    eigen_assert(colOffsets[0] == colOffsets[1]);
+    const Index rowOffsets[2] = {patchOffsets[0] - colOffsets[0] * m_colStride,
+                                 patchOffsets[1] - colOffsets[1] * m_colStride};
+    eigen_assert(rowOffsets[0] <= rowOffsets[1]);
+    const Index inputRows[2] = {rowIndex + rowOffsets[0],
+                                rowIndex + rowOffsets[1]};
+
+    if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
+      // all zeros
+      return internal::pset1<Packet>(Scalar(0));
+    }
+
+    if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
+      // no padding
+      const Index depth = patchId - patchOffsets[0] * patchDepth();
+      const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
+                               inputCols[0] * m_colInputStride + otherIndex;
+      return m_impl.template packet<Unaligned>(inputIndex);
+    }
+    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
+  }
+
+  // Load a standard packet from a patch specified by the "within patch offset"
+  // (patchId) and the precomputed indices of the first element of the patch.
+  // This overload is selected when partial packet loading is not available for
+  // the TensorEvaluator or when the packet type does not support masked loads.
+  template <typename PacketT, typename TensorEvaluatorT>
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+      typename std::enable_if<!OptimizedPacketLoadOverTwoColumns<
+                                  TensorEvaluatorT, PacketT, Index>::value,
+                              PacketT>::type
+      loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
+                         Index otherIndex) const {
     const Index packetSize = internal::unpacket_traits<Packet>::size;
     EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
     eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
@@ -362,43 +526,78 @@
 
     if ((patchDepth() % packetSize) == 0) {
       return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
-    } else {
-      // Offsets and input calculation here are identical to
-      // loadCoeffStandard(...), but repeated twice.
+    }
 
-      const Index patchOffsets[2] = {
-          patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
+    // Offsets and input calculation here are identical to
+    // loadCoeffStandard(...), but repeated twice.
+    const Index patchOffsets[2] = {patchId / m_fastDimZero,
+                                   (patchId + packetSize - 1) / m_fastDimZero};
+    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
+                                 patchOffsets[1] / m_fastColStride};
+    const Index inputCols[2] = {colIndex + colOffsets[0],
+                                colIndex + colOffsets[1]};
 
-      const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
-                                   patchOffsets[1] / m_fastColStride};
-      const Index inputCols[2] = {colIndex + colOffsets[0],
-                                  colIndex + colOffsets[1]};
-      if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
-        // all zeros
-        return internal::pset1<Packet>(Scalar(0));
-      }
+    if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
+      // all zeros
+      return internal::pset1<Packet>(Scalar(0));
+    }
+    if (inputCols[0] == inputCols[1]) {
+      return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
+                                                otherIndex, patchOffsets,
+                                                colOffsets, inputCols);
+    }
+    return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
+  }
 
-      if (inputCols[0] == inputCols[1]) {
-        const Index rowOffsets[2] = {
-            patchOffsets[0] - colOffsets[0] * m_colStride,
-            patchOffsets[1] - colOffsets[1] * m_colStride};
-        eigen_assert(rowOffsets[0] <= rowOffsets[1]);
-        const Index inputRows[2] = {rowIndex + rowOffsets[0],
-                                    rowIndex + rowOffsets[1]};
+  // Load a standard packet from a patch specified by the "within patch offset"
+  // (patchId) and the precomputed indices of the first element of the patch.
+  // This overload is selected when partial packet loading is available for the
+  // TensorEvaluator and the packet type supports masked loads.
+  // The only difference from the other overload is the handling of a packet
+  // that is split across two columns: instead of falling back to the slow
+  // (element-by-element) load, we load two packets - each containing the
+  // elements from one of the columns, with the remaining elements zeroed - and
+  // then combine these two packets to generate the required packet. The idea
+  // is to enable fast (masked) loads of these 'partial' packets whenever
+  // possible.
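+  // For example (illustrative only), with an 8-element packet whose first 3
+  // elements fall in column c and the remaining 5 in column c + 1:
+  //   partial packet 0: [x0 x1 x2  0  0  0  0  0]   (loaded from column c)
+  //   partial packet 1: [ 0  0  0 x3 x4 x5 x6 x7]   (loaded from column c + 1)
+  //   por            -> [x0 x1 x2 x3 x4 x5 x6 x7]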
+  template <typename PacketT, typename TensorEvaluatorT>
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
+      typename std::enable_if<OptimizedPacketLoadOverTwoColumns<
+                                  TensorEvaluatorT, PacketT, Index>::value,
+                              PacketT>::type
+      loadPacketStandard(Index patchId, Index rowIndex, Index colIndex,
+                         Index otherIndex) const {
+    const Index packetSize = internal::unpacket_traits<PacketT>::size;
+    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+    eigen_assert(patchId < patchDepth() * patchRows() * m_patch_cols);
 
-        if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
-          // all zeros
-          return internal::pset1<Packet>(Scalar(0));
-        }
+    eigen_assert(!nonStandardPatches());
 
-        if (inputRows[0] >= 0 && inputRows[1] < m_inputRows) {
-          // no padding
-          const Index depth = patchId - patchOffsets[0] * patchDepth();
-          const Index inputIndex = depth + inputRows[0] * m_rowInputStride +
-                                   inputCols[0] * m_colInputStride + otherIndex;
-          return m_impl.template packet<Unaligned>(inputIndex);
-        }
-      }
+    if ((patchDepth() % packetSize) == 0) {
+      return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
+    }
+
+    // Offsets and input calculation here are identical to
+    // loadCoeffStandard(...), but repeated twice.
+    const Index patchOffsets[2] = {patchId / m_fastDimZero,
+                                   (patchId + packetSize - 1) / m_fastDimZero};
+    const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
+                                 patchOffsets[1] / m_fastColStride};
+    const Index inputCols[2] = {colIndex + colOffsets[0],
+                                colIndex + colOffsets[1]};
+
+    if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
+      // all zeros
+      return internal::pset1<PacketT>(Scalar(0));
+    }
+    if (inputCols[0] == inputCols[1]) {
+      return loadPacketStandardFromSingleColumn(patchId, rowIndex, colIndex,
+                                                otherIndex, patchOffsets,
+                                                colOffsets, inputCols);
+    }
+    if (inputCols[1] == inputCols[0] + 1) {
+      return loadPacketStandardFromTwoColumns(
+          patchId, rowIndex, colIndex, otherIndex, patchOffsets, colOffsets);
     }
     return packetWithPossibleZero(patchId, rowIndex, colIndex, otherIndex);
   }
@@ -591,8 +790,9 @@
   }
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
   loadPacketStandard(Index i) const {
-    return m_base_mapper.loadPacketStandard(i + m_depth_offset, m_rowIndex,
-                                            m_colIndex, m_otherIndex);
+    typedef decltype(m_base_mapper.m_impl) TensorEvaluatorT;
+    return m_base_mapper.template loadPacketStandard<Packet, TensorEvaluatorT>(
+        i + m_depth_offset, m_rowIndex, m_colIndex, m_otherIndex);
   }
   template <typename Packet>
   EIGEN_DEVICE_FUNC bool aligned(Index) const {
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
index 8df6782..f0d7bda 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
@@ -2,8 +2,8 @@
 #define CXX11_SRC_FIXEDPOINT_PACKETMATHAVX2_H_
 #ifdef _MSC_VER
 
-#include <immintrin.h>
 #include <emmintrin.h>
+#include <immintrin.h>
 #include <smmintrin.h>
 
 #endif
@@ -178,37 +178,67 @@
 struct unpacket_traits<Packet32q8i> {
   typedef QInt8 type;
   typedef Packet16q8i half;
-  enum { size = 32, alignment = Aligned32, vectorizable = true };
+  enum {
+    size = 32,
+    alignment = Aligned32,
+    vectorizable = true,
+    masked_load_available = false
+  };
 };
 template <>
 struct unpacket_traits<Packet16q8i> {
   typedef QInt8 type;
   typedef Packet16q8i half;
-  enum { size = 16, alignment = Aligned32, vectorizable = true };
+  enum {
+    size = 16,
+    alignment = Aligned32,
+    vectorizable = true,
+    masked_load_available = false
+  };
 };
 template <>
 struct unpacket_traits<Packet16q16i> {
   typedef QInt16 type;
   typedef Packet8q16i half;
-  enum { size = 16, alignment = Aligned32, vectorizable = true };
+  enum {
+    size = 16,
+    alignment = Aligned32,
+    vectorizable = true,
+    masked_load_available = false
+  };
 };
 template <>
 struct unpacket_traits<Packet8q16i> {
   typedef QInt16 type;
   typedef Packet8q16i half;
-  enum { size = 8, alignment = Aligned32, vectorizable = true };
+  enum {
+    size = 8,
+    alignment = Aligned32,
+    vectorizable = true,
+    masked_load_available = false
+  };
 };
 template <>
 struct unpacket_traits<Packet32q8u> {
   typedef QUInt8 type;
   typedef Packet16q8u half;
-  enum { size = 32, alignment = Aligned32, vectorizable = true };
+  enum {
+    size = 32,
+    alignment = Aligned32,
+    vectorizable = true,
+    masked_load_available = false
+  };
 };
 template <>
 struct unpacket_traits<Packet8q32i> {
   typedef QInt32 type;
   typedef Packet4q32i half;
-  enum { size = 8, alignment = Aligned32, vectorizable = true };
+  enum {
+    size = 8,
+    alignment = Aligned32,
+    vectorizable = true,
+    masked_load_available = false
+  };
 };
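+
+// Note: the quantized packet types above do not support masked loads, so they
+// always take the non-optimized loadPacketStandard() path. A hypothetical
+// packet type that did support masked loads (sketch only; names are
+// illustrative, not a real specialization in this file) would also expose a
+// mask type and set the flag to true:
+//
+//   template <>
+//   struct unpacket_traits<HypotheticalMaskedPacket> {
+//     typedef float type;
+//     typedef HypotheticalMaskedPacketHalf half;
+//     typedef uint8_t mask_t;  // one mask bit per packet lane
+//     enum {
+//       size = 8,
+//       alignment = Aligned32,
+//       vectorizable = true,
+//       masked_load_available = true
+//     };
+//   };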
 
 // Unaligned load
@@ -232,7 +262,7 @@
   EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(
       reinterpret_cast<const __m256i*>(from));
 }
-template<>
+template <>
 EIGEN_STRONG_INLINE Packet8q16i ploadu<Packet8q16i>(const QInt16* from) {
   EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_si128(
       reinterpret_cast<const __m128i*>(from));
@@ -283,8 +313,8 @@
 }
 template <>
 EIGEN_STRONG_INLINE void pstoreu<QInt8>(QInt8* to, const Packet16q8i& from) {
-  EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(
-      reinterpret_cast<__m128i*>(to), from.val);
+  EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to),
+                                               from.val);
 }
 template <>
 EIGEN_STRONG_INLINE void pstoreu<QUInt8>(QUInt8* to, const Packet32q8u& from) {
@@ -298,8 +328,8 @@
 }
 template <>
 EIGEN_STRONG_INLINE void pstoreu<QInt16>(QInt16* to, const Packet8q16i& from) {
-  EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(
-      reinterpret_cast<__m128i*>(to), from.val);
+  EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(reinterpret_cast<__m128i*>(to),
+                                               from.val);
 }
 template <>
 EIGEN_STRONG_INLINE void pstoreu<QInt32>(QInt32* to, const Packet8q32i& from) {
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
index 84750c1..1c6c62d 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
@@ -127,25 +127,25 @@
 struct unpacket_traits<Packet64q8i> {
   typedef QInt8 type;
   typedef Packet32q8i half;
-  enum { size = 64, alignment = Aligned64 };
+  enum { size = 64, alignment = Aligned64, masked_load_available = false };
 };
 template <>
 struct unpacket_traits<Packet32q16i> {
   typedef QInt16 type;
   typedef Packet16q16i half;
-  enum { size = 32, alignment = Aligned64 };
+  enum { size = 32, alignment = Aligned64, masked_load_available = false };
 };
 template <>
 struct unpacket_traits<Packet64q8u> {
   typedef QUInt8 type;
   typedef Packet32q8u half;
-  enum { size = 64, alignment = Aligned64 };
+  enum { size = 64, alignment = Aligned64, masked_load_available = false };
 };
 template <>
 struct unpacket_traits<Packet16q32i> {
   typedef QInt32 type;
   typedef Packet8q32i half;
-  enum { size = 16, alignment = Aligned64 };
+  enum { size = 16, alignment = Aligned64, masked_load_available = false };
 };
 
 // Unaligned load