Add mulHi to SkNx

Add mulHi to base SkNx, and specialize implementations for Sk4u for
neon and sse.

Add casts for converting from uint8_t by 4 to uint32_t by 4.

Cq-Include-Trybots: skia.primary:Test-Debian9-Clang-GCE-CPU-AVX2-x86_64-Release-SKNX_NO_SIMD
Change-Id: I29a32e2ad9812a47fff841ceca334e562362836f
Reviewed-on: https://skia-review.googlesource.com/57960
Reviewed-by: Mike Klein <mtklein@chromium.org>
Commit-Queue: Herb Derby <herb@google.com>
diff --git a/src/core/SkNx.h b/src/core/SkNx.h
index 65e3fcb..6957cb0 100644
--- a/src/core/SkNx.h
+++ b/src/core/SkNx.h
@@ -119,10 +119,13 @@
     AI SkNx saturatedAdd(const SkNx& y) const {
         return { fLo.saturatedAdd(y.fLo), fHi.saturatedAdd(y.fHi) };
     }
+
+    AI SkNx mulHi(const SkNx& m) const {
+        return { fLo.mulHi(m.fLo), fHi.mulHi(m.fHi) };
+    }
     AI SkNx thenElse(const SkNx& t, const SkNx& e) const {
         return { fLo.thenElse(t.fLo, e.fLo), fHi.thenElse(t.fHi, e.fHi) };
     }
-
     AI static SkNx Min(const SkNx& x, const SkNx& y) {
         return { Half::Min(x.fLo, y.fLo), Half::Min(x.fHi, y.fHi) };
     }
@@ -214,6 +217,12 @@
         return sum < fVal ? std::numeric_limits<T>::max() : sum;
     }
 
+    AI SkNx mulHi(const SkNx& m) const {
+        static_assert(std::is_unsigned<T>::value, "");
+        static_assert(sizeof(T) <= 4, "");
+        return static_cast<T>((static_cast<uint64_t>(fVal) * m.fVal) >> (sizeof(T)*8));
+    }
+
     AI SkNx thenElse(const SkNx& t, const SkNx& e) const { return fVal != 0 ? t : e; }
 
 private:
diff --git a/src/opts/SkNx_neon.h b/src/opts/SkNx_neon.h
index b906a02..32be78f 100644
--- a/src/opts/SkNx_neon.h
+++ b/src/opts/SkNx_neon.h
@@ -497,6 +497,13 @@
     AI static SkNx Min(const SkNx& a, const SkNx& b) { return vminq_u32(a.fVec, b.fVec); }
     // TODO as needed
 
+    AI SkNx mulHi(const SkNx& m) const {
+        uint64x2_t hi = vmull_u32(vget_high_u32(fVec), vget_high_u32(m.fVec));
+        uint64x2_t lo = vmull_u32( vget_low_u32(fVec),  vget_low_u32(m.fVec));
+
+        return { vcombine_u32(vshrn_n_u64(lo,32), vshrn_n_u64(hi,32)) };
+    }
+
     AI SkNx thenElse(const SkNx& t, const SkNx& e) const {
         return vbslq_u32(fVec, t.fVec, e.fVec);
     }
@@ -529,9 +536,13 @@
     return vqmovn_u16(vcombine_u16(_16, _16));
 }
 
-template<> AI /*static*/ Sk4i SkNx_cast<int32_t, uint8_t>(const Sk4b& src) {
+template<> AI /*static*/ Sk4u SkNx_cast<uint32_t, uint8_t>(const Sk4b& src) {
     uint16x8_t _16 = vmovl_u8(src.fVec);
-    return vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(_16)));
+    return vmovl_u16(vget_low_u16(_16));
+}
+
+template<> AI /*static*/ Sk4i SkNx_cast<int32_t, uint8_t>(const Sk4b& src) {
+    return vreinterpretq_s32_u32(SkNx_cast<uint32_t>(src).fVec);
 }
 
 template<> AI /*static*/ Sk4f SkNx_cast<float, uint8_t>(const Sk4b& src) {
diff --git a/src/opts/SkNx_sse.h b/src/opts/SkNx_sse.h
index 469aefb..d4d4781 100644
--- a/src/opts/SkNx_sse.h
+++ b/src/opts/SkNx_sse.h
@@ -287,10 +287,16 @@
     #endif
     }
 
+    AI SkNx mulHi(SkNx m) const {
+        SkNx v20{_mm_mul_epu32(m.fVec, fVec)};
+        SkNx v31{_mm_mul_epu32(_mm_srli_si128(m.fVec, 4), _mm_srli_si128(fVec, 4))};
+
+        return SkNx{v20[1], v31[1], v20[3], v31[3]};
+    }
+
     __m128i fVec;
 };
 
-
 template <>
 class SkNx<4, uint16_t> {
 public:
@@ -568,7 +574,7 @@
 #endif
 }
 
-template<> AI /*static*/ Sk4i SkNx_cast<int32_t, uint8_t>(const Sk4b& src) {
+template<> AI /*static*/ Sk4u SkNx_cast<uint32_t, uint8_t>(const Sk4b& src) {
 #if SK_CPU_SSE_LEVEL >= SK_CPU_SSE_LEVEL_SSSE3
     const int _ = ~0;
     return _mm_shuffle_epi8(src.fVec, _mm_setr_epi8(0,_,_,_, 1,_,_,_, 2,_,_,_, 3,_,_,_));
@@ -578,6 +584,10 @@
 #endif
 }
 
+template<> AI /*static*/ Sk4i SkNx_cast<int32_t, uint8_t>(const Sk4b& src) {
+    return SkNx_cast<uint32_t>(src).fVec;
+}
+
 template<> AI /*static*/ Sk4f SkNx_cast<float, uint8_t>(const Sk4b& src) {
     return _mm_cvtepi32_ps(SkNx_cast<int32_t>(src).fVec);
 }
diff --git a/tests/SkNxTest.cpp b/tests/SkNxTest.cpp
index 240d7e0..afa6750 100644
--- a/tests/SkNxTest.cpp
+++ b/tests/SkNxTest.cpp
@@ -165,6 +165,20 @@
     }
 }
 
+DEF_TEST(SkNi_mulHi, r) {
+    // First 8 primes.
+    Sk4u a{ 0x00020000, 0x00030000, 0x00050000, 0x00070000 };
+    Sk4u b{ 0x000b0000, 0x000d0000, 0x00110000, 0x00130000 };
+
+    Sk4u q{22, 39, 85, 133};
+
+    Sk4u c = a.mulHi(b);
+    REPORTER_ASSERT(r, c[0] == q[0]);
+    REPORTER_ASSERT(r, c[1] == q[1]);
+    REPORTER_ASSERT(r, c[2] == q[2]);
+    REPORTER_ASSERT(r, c[3] == q[3]);
+}
+
 DEF_TEST(Sk4px_muldiv255round, r) {
     for (int a = 0; a < (1<<8); a++) {
     for (int b = 0; b < (1<<8); b++) {