allow two immediates

This lets us move the calculation of extract's shift up to Builder time.

Most immediates live in y.imm now, with extract using both y.imm (mask)
and z.imm (shift), and pack using z.imm for its shift because it needs
both x and y regs.

Change-Id: I0081382f4d4c02198cae5819294b3adeea7341bb
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/218249
Commit-Queue: Mike Klein <mtklein@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: Mike Klein <mtklein@google.com>
diff --git a/src/core/SkVM.cpp b/src/core/SkVM.cpp
index b71e264..f8784b4 100644
--- a/src/core/SkVM.cpp
+++ b/src/core/SkVM.cpp
@@ -9,12 +9,12 @@
 #include "src/core/SkOpts.h"
 #include "src/core/SkVM.h"
 #include <string.h>
+#if defined(SK_BUILD_FOR_WIN)
+    #include <intrin.h>
+#endif
 
 namespace skvm {
 
-    // We reserve the last ID as a sentinel meaning none, n/a, null, nil, etc.
-    static const ID NA = ~0;
-
     Program::Program(std::vector<Instruction> instructions, int regs)
         : fInstructions(std::move(instructions))
         , fRegs(regs)
@@ -94,14 +94,11 @@
                 inst.op,
                 lookup_register(id),
                 lookup_register(inst.x),
-                lookup_register(inst.y),
+               {lookup_register(inst.y)},
                {lookup_register(inst.z)},
             };
-            // If the z argument is the N/A sentinel, copy in the immediate instead.
-            // (No Op uses both 3 arguments and an immediate.)
-            if (inst.z == NA) {
-                pinst.z.imm = inst.imm;
-            }
+            if (inst.y == NA) { pinst.y.imm = inst.immy; }
+            if (inst.z == NA) { pinst.z.imm = inst.immz; }
             program.push_back(pinst);
         }
 
@@ -111,19 +108,19 @@
     // Most instructions produce a value and return it by ID,
     // the value-producing instruction's own index in the program vector.
 
-    ID Builder::push(Op op, ID x=NA, ID y=NA, ID z=NA, int imm=0) {
-        Instruction inst{op, /*life=*/NA, x, y, z, imm};
+    ID Builder::push(Op op, ID x, ID y, ID z, int immy, int immz) {
+        Instruction inst{op, /*life=*/NA, x, y, z, immy, immz};
 
         // Simple peepholes that come up fairly often.
-        if (op == Op::extract && imm == (int)0xff000000) { inst = { Op::shr, NA, x,NA,NA, 24 }; }
+        if (op == Op::extract && immy == (int)0xff000000) { inst = {Op::shr,NA, x,NA,NA, 24,0}; }
 
         auto is_zero = [&](ID id) {
-            return fProgram[id].op  == Op::splat
-                && fProgram[id].imm == 0;
+            return fProgram[id].op   == Op::splat
+                && fProgram[id].immy == 0;
         };
 
         // x*y+0 --> x*y
-        if (op == Op::mad_f32 && is_zero(z)) { inst = { Op::mul_f32, NA, x,y,NA, 0 }; }
+        if (op == Op::mad_f32 && is_zero(z)) { inst = {Op::mul_f32,NA, x,y,NA, 0,0}; }
 
 
         // Basic common subexpression elimination:
@@ -175,15 +172,27 @@
 
     I32 Builder::mul_unorm8(I32 x, I32 y) { return {this->push(Op::mul_unorm8, x.id, y.id)}; }
 
-    I32 Builder::extract(I32 x, int mask) { return {this->push(Op::extract, x.id,NA,NA, mask)}; }
-    I32 Builder::pack(I32 x, I32 y, int bits) { return {this->push(Op::pack, x.id,y.id,NA, bits)}; }
+    I32 Builder::extract(I32 x, int mask) {
+        SkASSERT(mask != 0);
+    #if defined(SK_BUILD_FOR_WIN)
+        unsigned long shift;
+        _BitScanForward(&shift, mask);
+    #else
+        const int shift = __builtin_ctz(mask);
+    #endif
+        return {this->push(Op::extract, x.id,NA,NA, mask, shift)};
+    }
+
+    I32 Builder::pack(I32 x, I32 y, int bits) {
+        return {this->push(Op::pack, x.id,y.id,NA, 0,bits)};
+    }
 
     F32 Builder::to_f32(I32 x) { return {this->push(Op::to_f32, x.id)}; }
     I32 Builder::to_i32(F32 x) { return {this->push(Op::to_i32, x.id)}; }
 
     // ~~~~ Program::dump() and co. ~~~~ //
 
-    struct Reg { ID id; };
+    struct R { ID id; };
     struct Shift { int bits; };
     struct Mask  { int bits; };
     struct Splat { int bits; };
@@ -197,7 +206,7 @@
         o->writeDecAsText(a.ix);
         write(o, ")");
     }
-    static void write(SkWStream* o, Reg r) {
+    static void write(SkWStream* o, R r) {
         write(o, "r");
         o->writeDecAsText(r.id);
     }
@@ -231,43 +240,43 @@
         for (const Instruction& inst : fInstructions) {
             Op  op = inst.op;
             ID   d = inst.d,
-                 x = inst.x,
-                 y = inst.y;
-            auto z = inst.z;
+                 x = inst.x;
+            auto y = inst.y,
+                 z = inst.z;
             switch (op) {
-                case Op::store8:  write(o, "store8" , Arg{z.imm}, Reg{x}); break;
-                case Op::store32: write(o, "store32", Arg{z.imm}, Reg{x}); break;
+                case Op::store8:  write(o, "store8" , Arg{y.imm}, R{x}); break;
+                case Op::store32: write(o, "store32", Arg{y.imm}, R{x}); break;
 
-                case Op::load8:  write(o, Reg{d}, "= load8" , Arg{z.imm}); break;
-                case Op::load32: write(o, Reg{d}, "= load32", Arg{z.imm}); break;
+                case Op::load8:  write(o, R{d}, "= load8" , Arg{y.imm}); break;
+                case Op::load32: write(o, R{d}, "= load32", Arg{y.imm}); break;
 
-                case Op::splat:  write(o, Reg{d}, "= splat", Splat{z.imm}); break;
+                case Op::splat:  write(o, R{d}, "= splat", Splat{y.imm}); break;
 
-                case Op::add_f32: write(o, Reg{d}, "= add_f32", Reg{x}, Reg{y}           ); break;
-                case Op::sub_f32: write(o, Reg{d}, "= sub_f32", Reg{x}, Reg{y}           ); break;
-                case Op::mul_f32: write(o, Reg{d}, "= mul_f32", Reg{x}, Reg{y}           ); break;
-                case Op::div_f32: write(o, Reg{d}, "= div_f32", Reg{x}, Reg{y}           ); break;
-                case Op::mad_f32: write(o, Reg{d}, "= mad_f32", Reg{x}, Reg{y}, Reg{z.id}); break;
+                case Op::add_f32: write(o, R{d}, "= add_f32", R{x}, R{y.id}           ); break;
+                case Op::sub_f32: write(o, R{d}, "= sub_f32", R{x}, R{y.id}           ); break;
+                case Op::mul_f32: write(o, R{d}, "= mul_f32", R{x}, R{y.id}           ); break;
+                case Op::div_f32: write(o, R{d}, "= div_f32", R{x}, R{y.id}           ); break;
+                case Op::mad_f32: write(o, R{d}, "= mad_f32", R{x}, R{y.id}, R{z.id}); break;
 
-                case Op::add_i32: write(o, Reg{d}, "= add_i32", Reg{x}, Reg{y}); break;
-                case Op::sub_i32: write(o, Reg{d}, "= sub_i32", Reg{x}, Reg{y}); break;
-                case Op::mul_i32: write(o, Reg{d}, "= mul_i32", Reg{x}, Reg{y}); break;
+                case Op::add_i32: write(o, R{d}, "= add_i32", R{x}, R{y.id}); break;
+                case Op::sub_i32: write(o, R{d}, "= sub_i32", R{x}, R{y.id}); break;
+                case Op::mul_i32: write(o, R{d}, "= mul_i32", R{x}, R{y.id}); break;
 
-                case Op::bit_and: write(o, Reg{d}, "= bit_and", Reg{x}, Reg{y}); break;
-                case Op::bit_or : write(o, Reg{d}, "= bit_or" , Reg{x}, Reg{y}); break;
-                case Op::bit_xor: write(o, Reg{d}, "= bit_xor", Reg{x}, Reg{y}); break;
+                case Op::bit_and: write(o, R{d}, "= bit_and", R{x}, R{y.id}); break;
+                case Op::bit_or : write(o, R{d}, "= bit_or" , R{x}, R{y.id}); break;
+                case Op::bit_xor: write(o, R{d}, "= bit_xor", R{x}, R{y.id}); break;
 
-                case Op::shl: write(o, Reg{d}, "= shl", Reg{x}, Shift{z.imm}); break;
-                case Op::shr: write(o, Reg{d}, "= shr", Reg{x}, Shift{z.imm}); break;
-                case Op::sra: write(o, Reg{d}, "= sra", Reg{x}, Shift{z.imm}); break;
+                case Op::shl: write(o, R{d}, "= shl", R{x}, Shift{y.imm}); break;
+                case Op::shr: write(o, R{d}, "= shr", R{x}, Shift{y.imm}); break;
+                case Op::sra: write(o, R{d}, "= sra", R{x}, Shift{y.imm}); break;
 
-                case Op::mul_unorm8: write(o, Reg{d}, "= mul_unorm8", Reg{x}, Reg{y}); break;
+                case Op::mul_unorm8: write(o, R{d}, "= mul_unorm8", R{x}, R{y.id}); break;
 
-                case Op::extract: write(o, Reg{d}, "= extract", Reg{x}, Mask{z.imm}); break;
-                case Op::pack: write(o, Reg{d}, "= pack", Reg{x}, Reg{y}, Shift{z.imm}); break;
+                case Op::extract: write(o, R{d}, "= extract", R{x}, Mask{y.imm}); break;
+                case Op::pack: write(o, R{d}, "= pack", R{x}, R{y.id}, Shift{z.imm}); break;
 
-                case Op::to_f32: write(o, Reg{d}, "= to_f32", Reg{x}); break;
-                case Op::to_i32: write(o, Reg{d}, "= to_i32", Reg{x}); break;
+                case Op::to_f32: write(o, R{d}, "= to_f32", R{x}); break;
+                case Op::to_i32: write(o, R{d}, "= to_i32", R{x}); break;
             }
             write(o, "\n");
         }
diff --git a/src/core/SkVM.h b/src/core/SkVM.h
index 5e220d6..bc035c9 100644
--- a/src/core/SkVM.h
+++ b/src/core/SkVM.h
@@ -33,10 +33,10 @@
 
     class Program {
     public:
-        struct Instruction {   // d = op(x,y, z.id/z.imm)
+        struct Instruction {   // d = op(x, y.id/y.imm, z.id/z.imm)
             Op op;
-            ID d,x,y;
-            union { ID id; int imm; } z;
+            ID d,x;
+            union { ID id; int imm; } y,z;
         };
 
         Program(std::vector<Instruction>, int regs);
@@ -117,11 +117,14 @@
         I32 to_i32(F32 x);
 
     private:
+        // We reserve the last ID as a sentinel meaning none, n/a, null, nil, etc.
+        static const ID NA = ~0;
+
         struct Instruction {
-            Op  op;      // v* = op(x,y,z,imm), where * == index of this Instruction.
-            ID  life;    // ID of last instruction using this instruction's result.
-            ID  x,y,z;   // Enough arguments for mad().
-            int imm;     // Immediate bit pattern, shift count, or argument index.
+            Op  op;         // v* = op(x,y,z,imm), where * == index of this Instruction.
+            ID  life;       // ID of last instruction using this instruction's result.
+            ID  x,y,z;      // Enough arguments for mad().
+            int immy,immz;  // Immediate bit patterns, shift counts, argument indexes.
 
             bool operator==(const Instruction& o) const {
                 return op   == o.op
@@ -129,7 +132,8 @@
                     && x    == o.x
                     && y    == o.y
                     && z    == o.z
-                    && imm  == o.imm;
+                    && immy == o.immy
+                    && immz == o.immz;
             }
         };
 
@@ -144,14 +148,15 @@
                      ^ Hash(inst.x)
                      ^ Hash(inst.y)
                      ^ Hash(inst.z)
-                     ^ Hash(inst.imm);
+                     ^ Hash(inst.immy)
+                     ^ Hash(inst.immz);
             }
         };
 
         std::unordered_map<Instruction, ID, InstructionHash> fIndex;
         std::vector<Instruction>                             fProgram;
 
-        ID push(Op, ID, ID, ID, int);
+        ID push(Op, ID x, ID y=NA, ID z=NA, int immy=0, int immz=0);
     };
 
     // TODO: comparison operations, if_then_else
diff --git a/src/opts/SkVM_opts.h b/src/opts/SkVM_opts.h
index 3549779..c824184 100644
--- a/src/opts/SkVM_opts.h
+++ b/src/opts/SkVM_opts.h
@@ -9,9 +9,6 @@
 #define SkVM_opts_DEFINED
 
 #include "src/core/SkVM.h"
-#if defined(SK_BUILD_FOR_WIN)
-    #include <intrin.h>
-#endif
 
 namespace SK_OPTS_NS {
 
@@ -76,11 +73,11 @@
             for (int i = 0; i < ninsts; i++) {
                 skvm::Program::Instruction inst = insts[i];
 
-                // d = op(x, y, z.id/z.imm)
+                // d = op(x, y.id/z.imm, z.id/z.imm)
                 ID   d = inst.d,
-                     x = inst.x,
-                     y = inst.y;
-                auto z = inst.z;
+                     x = inst.x;
+                auto y = inst.y,
+                     z = inst.z;
 
                 // Ops that interact with memory need to know whether we're stride=1 or stride=K,
                 // but all non-memory ops can run the same code no matter the stride.
@@ -88,57 +85,48 @@
 
                 #define STRIDE_1(op) case 2*(int)op
                 #define STRIDE_K(op) case 2*(int)op + 1
-                    STRIDE_1(Op::store8 ): memcpy(arg(z.imm), &r(x).i32, 1); break;
-                    STRIDE_1(Op::store32): memcpy(arg(z.imm), &r(x).i32, 4); break;
+                    STRIDE_1(Op::store8 ): memcpy(arg(y.imm), &r(x).i32, 1); break;
+                    STRIDE_1(Op::store32): memcpy(arg(y.imm), &r(x).i32, 4); break;
 
-                    STRIDE_K(Op::store8 ): skvx::cast<uint8_t>(r(x).i32).store(arg(z.imm)); break;
-                    STRIDE_K(Op::store32):                    (r(x).i32).store(arg(z.imm)); break;
+                    STRIDE_K(Op::store8 ): skvx::cast<uint8_t>(r(x).i32).store(arg(y.imm)); break;
+                    STRIDE_K(Op::store32):                    (r(x).i32).store(arg(y.imm)); break;
 
-                    STRIDE_1(Op::load8 ): r(d).i32 = 0; memcpy(&r(d).i32, arg(z.imm), 1); break;
-                    STRIDE_1(Op::load32): r(d).i32 = 0; memcpy(&r(d).i32, arg(z.imm), 4); break;
+                    STRIDE_1(Op::load8 ): r(d).i32 = 0; memcpy(&r(d).i32, arg(y.imm), 1); break;
+                    STRIDE_1(Op::load32): r(d).i32 = 0; memcpy(&r(d).i32, arg(y.imm), 4); break;
 
-                    STRIDE_K(Op::load8 ): r(d).i32 = skvx::cast<int>(U8 ::Load(arg(z.imm))); break;
-                    STRIDE_K(Op::load32): r(d).i32 =                 I32::Load(arg(z.imm)) ; break;
+                    STRIDE_K(Op::load8 ): r(d).i32 = skvx::cast<int>(U8 ::Load(arg(y.imm))); break;
+                    STRIDE_K(Op::load32): r(d).i32 =                 I32::Load(arg(y.imm)) ; break;
                 #undef STRIDE_1
                 #undef STRIDE_K
 
                     // Ops that don't interact with memory should never care about the stride.
                 #define CASE(op) case 2*(int)op: /*fallthrough*/ case 2*(int)op+1
-                    CASE(Op::splat): r(d).i32 = z.imm; break;
+                    CASE(Op::splat): r(d).i32 = y.imm; break;
 
-                    CASE(Op::add_f32): r(d).f32 = r(x).f32 + r(y).f32; break;
-                    CASE(Op::sub_f32): r(d).f32 = r(x).f32 - r(y).f32; break;
-                    CASE(Op::mul_f32): r(d).f32 = r(x).f32 * r(y).f32; break;
-                    CASE(Op::div_f32): r(d).f32 = r(x).f32 / r(y).f32; break;
+                    CASE(Op::add_f32): r(d).f32 = r(x).f32 + r(y.id).f32; break;
+                    CASE(Op::sub_f32): r(d).f32 = r(x).f32 - r(y.id).f32; break;
+                    CASE(Op::mul_f32): r(d).f32 = r(x).f32 * r(y.id).f32; break;
+                    CASE(Op::div_f32): r(d).f32 = r(x).f32 / r(y.id).f32; break;
 
-                    CASE(Op::mad_f32): r(d).f32 = r(x).f32 * r(y).f32 + r(z.id).f32; break;
+                    CASE(Op::mad_f32): r(d).f32 = r(x).f32 * r(y.id).f32 + r(z.id).f32; break;
 
-                    CASE(Op::add_i32): r(d).i32 = r(x).i32 + r(y).i32; break;
-                    CASE(Op::sub_i32): r(d).i32 = r(x).i32 - r(y).i32; break;
-                    CASE(Op::mul_i32): r(d).i32 = r(x).i32 * r(y).i32; break;
+                    CASE(Op::add_i32): r(d).i32 = r(x).i32 + r(y.id).i32; break;
+                    CASE(Op::sub_i32): r(d).i32 = r(x).i32 - r(y.id).i32; break;
+                    CASE(Op::mul_i32): r(d).i32 = r(x).i32 * r(y.id).i32; break;
 
-                    CASE(Op::bit_and): r(d).i32 = r(x).i32 & r(y).i32; break;
-                    CASE(Op::bit_or ): r(d).i32 = r(x).i32 | r(y).i32; break;
-                    CASE(Op::bit_xor): r(d).i32 = r(x).i32 ^ r(y).i32; break;
+                    CASE(Op::bit_and): r(d).i32 = r(x).i32 & r(y.id).i32; break;
+                    CASE(Op::bit_or ): r(d).i32 = r(x).i32 | r(y.id).i32; break;
+                    CASE(Op::bit_xor): r(d).i32 = r(x).i32 ^ r(y.id).i32; break;
 
-                    CASE(Op::shl): r(d).i32 =                 r(x).i32 << z.imm ; break;
-                    CASE(Op::sra): r(d).i32 =                 r(x).i32 >> z.imm ; break;
-                    CASE(Op::shr): r(d).i32 = skvx::cast<int>(r(x).u32 >> z.imm); break;
+                    CASE(Op::shl): r(d).i32 = r(x).i32 << y.imm; break;
+                    CASE(Op::sra): r(d).i32 = r(x).i32 >> y.imm; break;
+                    CASE(Op::shr): r(d).u32 = r(x).u32 >> y.imm; break;
 
-                    CASE(Op::mul_unorm8): r(d).i32 = (r(x).i32 * r(y).i32 + 255) / 256; break;
+                    CASE(Op::mul_unorm8): r(d).i32 = (r(x).i32 * r(y.id).i32 + 255) / 256; break;
 
-                    CASE(Op::extract): {
-                        SkASSERT(z.imm != 0);
-                    #if defined(SK_BUILD_FOR_WIN)
-                        unsigned long shift;
-                        _BitScanForward(&shift, z.imm);
-                    #else
-                        const int shift = __builtin_ctz(z.imm);
-                    #endif
-                        r(d).i32 = skvx::cast<int>( (r(x).u32 & z.imm) >> shift );
-                    } break;
+                    CASE(Op::extract): r(d).u32 = (r(x).u32 & y.imm) >> z.imm; break;
 
-                    CASE(Op::pack): r(d).i32 = r(x).i32 | (r(y).i32 << z.imm); break;
+                    CASE(Op::pack): r(d).i32 = r(x).i32 | (r(y.id).i32 << z.imm); break;
 
                     CASE(Op::to_f32): r(d).f32 = skvx::cast<float>(r(x).i32); break;
                     CASE(Op::to_i32): r(d).i32 = skvx::cast<int>  (r(x).f32); break;