API change to allow for NormalSource selection at the user level.

This CL's base is the CL for CPU handling: https://codereview.chromium.org/2050773002/

BUG=skia:
GOLD_TRYBOT_URL= https://gold.skia.org/search?issue=2063793002

Review-Url: https://codereview.chromium.org/2063793002
diff --git a/gm/lightingshader.cpp b/gm/lightingshader.cpp
index ae6a24f..3b1c224 100644
--- a/gm/lightingshader.cpp
+++ b/gm/lightingshader.cpp
@@ -7,7 +7,9 @@
 
 #include "gm.h"
 
+#include "SkBitmapProcShader.h"
 #include "SkLightingShader.h"
+#include "SkNormalSource.h"
 #include "SkPoint3.h"
 #include "SkShader.h"
 
@@ -49,7 +51,9 @@
         SkLights::Builder builder;
 
         builder.add(SkLights::Light(SkColor3f::Make(1.0f, 1.0f, 1.0f),
-                                    SkVector3::Make(1.0f, 0.0f, 0.0f)));
+                                    SkVector3::Make(SK_ScalarRoot2Over2,
+                                                    0.0f,
+                                                    SK_ScalarRoot2Over2)));
         builder.add(SkLights::Light(SkColor3f::Make(0.2f, 0.2f, 0.2f)));
 
         fLights = builder.finish();
@@ -95,12 +99,13 @@
 
         const SkMatrix& ctm = canvas->getTotalMatrix();
 
-        // TODO: correctly pull out the pure rotation
-        SkVector invNormRotation = { ctm[SkMatrix::kMScaleX], ctm[SkMatrix::kMSkewY] };
-
         SkPaint paint;
-        paint.setShader(SkLightingShader::Make(fDiffuse, fNormalMaps[mapType], fLights,
-                                               invNormRotation, &matrix, &matrix));
+        sk_sp<SkShader> normalMap = SkMakeBitmapShader(fNormalMaps[mapType],
+                SkShader::kClamp_TileMode, SkShader::kClamp_TileMode, &matrix, nullptr);
+        sk_sp<SkNormalSource> normalSource = SkNormalSource::MakeFromNormalMap(std::move(normalMap),
+                                                                               ctm);
+        paint.setShader(SkLightingShader::Make(fDiffuse, fLights, &matrix,
+                                               std::move(normalSource)));
 
         canvas->drawRect(r, paint);
     }
diff --git a/samplecode/SampleLighting.cpp b/samplecode/SampleLighting.cpp
index a27aa9d..5949f49 100755
--- a/samplecode/SampleLighting.cpp
+++ b/samplecode/SampleLighting.cpp
@@ -8,8 +8,10 @@
 #include "SampleCode.h"
 #include "Resources.h"
 
+#include "SkBitmapProcShader.h"
 #include "SkCanvas.h"
 #include "SkLightingShader.h"
+#include "SkNormalSource.h"
 #include "SkPoint3.h"
 
 static sk_sp<SkLights> create_lights(SkScalar angle, SkScalar blue) {
@@ -64,9 +66,12 @@
 
         sk_sp<SkLights> lights(create_lights(fLightAngle, fColorFactor));
         SkPaint paint;
-        paint.setShader(SkLightingShader::Make(fDiffuseBitmap, fNormalBitmap,
-                                               std::move(lights), SkVector::Make(1.0f, 0.0f),
-                                               nullptr, nullptr));
+        sk_sp<SkShader> normalMap = SkMakeBitmapShader(fNormalBitmap,
+            SkShader::kClamp_TileMode, SkShader::kClamp_TileMode, nullptr, nullptr);
+        sk_sp<SkNormalSource> normalSource = SkNormalSource::MakeFromNormalMap(
+                std::move(normalMap), SkMatrix::I());
+        paint.setShader(SkLightingShader::Make(fDiffuseBitmap, std::move(lights), nullptr,
+                                               std::move(normalSource)));
         paint.setColor(SK_ColorBLACK);
 
         SkRect r = SkRect::MakeWH((SkScalar)fDiffuseBitmap.width(),
diff --git a/samplecode/SampleLitAtlas.cpp b/samplecode/SampleLitAtlas.cpp
index f1d67e5..ba42ed8 100644
--- a/samplecode/SampleLitAtlas.cpp
+++ b/samplecode/SampleLitAtlas.cpp
@@ -7,13 +7,15 @@
 
 #include "SampleCode.h"
 #include "SkAnimTimer.h"
-#include "SkView.h"
+#include "SkBitmapProcShader.h"
 #include "SkCanvas.h"
 #include "SkDrawable.h"
 #include "SkLightingShader.h"
 #include "SkLights.h"
+#include "SkNormalSource.h"
 #include "SkRandom.h"
 #include "SkRSXform.h"
+#include "SkView.h"
 
 #include "sk_tool_utils.h"
 
@@ -128,12 +130,12 @@
             SkMatrix m;
             m.setRSXform(xforms[i]);
 
-            // TODO: correctly pull out the pure rotation
-            SkVector invNormRotation = { m[SkMatrix::kMScaleX], m[SkMatrix::kMSkewY] };
-            SkASSERT(SkScalarNearlyEqual(invNormRotation.lengthSqd(), SK_Scalar1));
-
-            paint.setShader(SkLightingShader::Make(fAtlas, fAtlas, fLights,
-                                                   invNormRotation, &diffMat, &normalMat));
+            sk_sp<SkShader> normalMap = SkMakeBitmapShader(fAtlas, SkShader::kClamp_TileMode,
+                    SkShader::kClamp_TileMode, &normalMat, nullptr);
+            sk_sp<SkNormalSource> normalSource = SkNormalSource::MakeFromNormalMap(
+                    std::move(normalMap), m);
+            paint.setShader(SkLightingShader::Make(fAtlas, fLights, &diffMat,
+                                                   std::move(normalSource)));
 
             canvas->save();
                 canvas->setMatrix(m);
diff --git a/src/core/SkLightingShader.cpp b/src/core/SkLightingShader.cpp
index f32aa9f..b40f4a7 100644
--- a/src/core/SkLightingShader.cpp
+++ b/src/core/SkLightingShader.cpp
@@ -45,34 +45,18 @@
     /** Create a new lighting shader that uses the provided normal map and
         lights to light the diffuse bitmap.
         @param diffuse           the diffuse bitmap
-        @param normal            the normal map
         @param lights            the lights applied to the normal map
-        @param invNormRotation   rotation applied to the normal map's normals
         @param diffLocalM        the local matrix for the diffuse coordinates
-        @param normLocalM        the local matrix for the normal coordinates
-        @param normalSource      the normal source for GPU computations
+        @param normalSource      the source of normals for lighting computation
     */
-    SkLightingShaderImpl(const SkBitmap& diffuse, const SkBitmap& normal,
+    SkLightingShaderImpl(const SkBitmap& diffuse,
                          const sk_sp<SkLights> lights,
-                         const SkVector& invNormRotation,
-                         const SkMatrix* diffLocalM, const SkMatrix* normLocalM,
+                         const SkMatrix* diffLocalM,
                          sk_sp<SkNormalSource> normalSource)
         : INHERITED(diffLocalM)
         , fDiffuseMap(diffuse)
-        , fNormalMap(normal)
         , fLights(std::move(lights))
-        , fInvNormRotation(invNormRotation) {
-
-        if (normLocalM) {
-            fNormLocalMatrix = *normLocalM;
-        } else {
-            fNormLocalMatrix.reset();
-        }
-        // Pre-cache so future calls to fNormLocalMatrix.getType() are threadsafe.
-        (void)fNormLocalMatrix.getType();
-
-        fNormalSource = std::move(normalSource);
-    }
+        , fNormalSource(std::move(normalSource)) {}
 
     bool isOpaque() const override;
 
@@ -117,13 +101,8 @@
 
 private:
     SkBitmap        fDiffuseMap;
-    SkBitmap        fNormalMap;
-
     sk_sp<SkLights> fLights;
 
-    SkMatrix        fNormLocalMatrix;
-    SkVector        fInvNormRotation;
-
     sk_sp<SkNormalSource> fNormalSource;
 
     friend class SkLightingShader;
@@ -327,8 +306,6 @@
     // we assume diffuse and normal maps have same width and height
     // TODO: support different sizes, will be addressed when diffuse maps are factored out of
     //       SkLightingShader in a future CL
-    SkASSERT(fDiffuseMap.width() == fNormalMap.width() &&
-             fDiffuseMap.height() == fNormalMap.height());
     SkMatrix diffM;
 
     if (!make_mat(fDiffuseMap, this->getLocalMatrix(), localMatrix, &diffM)) {
@@ -500,26 +477,12 @@
         diffLocalM.reset();
     }
 
-    SkMatrix normLocalM;
-    bool hasNormLocalM = buf.readBool();
-    if (hasNormLocalM) {
-        buf.readMatrix(&normLocalM);
-    } else {
-        normLocalM.reset();
-    }
-
     SkBitmap diffuse;
     if (!buf.readBitmap(&diffuse)) {
         return nullptr;
     }
     diffuse.setImmutable();
 
-    SkBitmap normal;
-    if (!buf.readBitmap(&normal)) {
-        return nullptr;
-    }
-    normal.setImmutable();
-
     int numLights = buf.readInt();
 
     SkLights::Builder builder;
@@ -545,28 +508,16 @@
 
     sk_sp<SkLights> lights(builder.finish());
 
-    SkVector invNormRotation = {1,0};
-    if (!buf.isVersionLT(SkReadBuffer::kLightingShaderWritesInvNormRotation)) {
-        invNormRotation = buf.readPoint();
-    }
-
     sk_sp<SkNormalSource> normalSource(buf.readFlattenable<SkNormalSource>());
 
-    return sk_make_sp<SkLightingShaderImpl>(diffuse, normal, std::move(lights), invNormRotation,
-                                            &diffLocalM, &normLocalM, std::move(normalSource));
+    return sk_make_sp<SkLightingShaderImpl>(diffuse, std::move(lights), &diffLocalM,
+                                            std::move(normalSource));
 }
 
 void SkLightingShaderImpl::flatten(SkWriteBuffer& buf) const {
     this->INHERITED::flatten(buf);
 
-    bool hasNormLocalM = !fNormLocalMatrix.isIdentity();
-    buf.writeBool(hasNormLocalM);
-    if (hasNormLocalM) {
-        buf.writeMatrix(fNormLocalMatrix);
-    }
-
     buf.writeBitmap(fDiffuseMap);
-    buf.writeBitmap(fNormalMap);
 
     buf.writeInt(fLights->numLights());
     for (int l = 0; l < fLights->numLights(); ++l) {
@@ -580,7 +531,6 @@
             buf.writeScalarArray(&light.dir().fX, 3);
         }
     }
-    buf.writePoint(fInvNormRotation);
 
     buf.writeFlattenable(fNormalSource.get());
 }
@@ -625,27 +575,21 @@
 
 ///////////////////////////////////////////////////////////////////////////////
 
-sk_sp<SkShader> SkLightingShader::Make(const SkBitmap& diffuse, const SkBitmap& normal,
+sk_sp<SkShader> SkLightingShader::Make(const SkBitmap& diffuse,
                                        sk_sp<SkLights> lights,
-                                       const SkVector& invNormRotation,
-                                       const SkMatrix* diffLocalM, const SkMatrix* normLocalM) {
-    if (diffuse.isNull() || SkBitmapProcShader::BitmapIsTooBig(diffuse) ||
-        normal.isNull()  || SkBitmapProcShader::BitmapIsTooBig(normal) ||
-        diffuse.width()  != normal.width() ||
-        diffuse.height() != normal.height()) {
+                                       const SkMatrix* diffLocalM,
+                                       sk_sp<SkNormalSource> normalSource) {
+    if (diffuse.isNull() || SkBitmapProcShader::BitmapIsTooBig(diffuse)) {
         return nullptr;
     }
-    SkASSERT(SkScalarNearlyEqual(invNormRotation.lengthSqd(), SK_Scalar1));
 
-    // TODO: support other tile modes
-    sk_sp<SkShader> mapShader = SkMakeBitmapShader(normal, SkShader::kClamp_TileMode,
-                                                   SkShader::kClamp_TileMode, normLocalM, nullptr);
+    if (!normalSource) {
+        // TODO: Use a default implementation of normalSource instead
+        return nullptr;
+    }
 
-    sk_sp<SkNormalSource> normalSource = SkNormalSource::MakeFromNormalMap(mapShader,
-                                                                           invNormRotation);
-
-    return sk_make_sp<SkLightingShaderImpl>(diffuse, normal, std::move(lights),
-            invNormRotation, diffLocalM, normLocalM, std::move(normalSource));
+    return sk_make_sp<SkLightingShaderImpl>(diffuse, std::move(lights), diffLocalM,
+                                            std::move(normalSource));
 }
 
 ///////////////////////////////////////////////////////////////////////////////
diff --git a/src/core/SkLightingShader.h b/src/core/SkLightingShader.h
index e21b942..f103823 100644
--- a/src/core/SkLightingShader.h
+++ b/src/core/SkLightingShader.h
@@ -13,45 +13,33 @@
 
 class SkBitmap;
 class SkMatrix;
+class SkNormalSource;
 
 class SK_API SkLightingShader {
 public:
-    /** Returns a shader that lights the diffuse and normal maps with a set of lights.
+    /** Returns a shader that lights the diffuse map using the normals and a set of lights.
 
         It returns a shader with a reference count of 1.
         The caller should decrement the shader's reference count when done with the shader.
         It is an error for count to be < 2.
         @param  diffuse     the diffuse bitmap
-        @param  normal      the normal map
         @param  lights       the lights applied to the normal map
-        @param  invNormRotation rotation applied to the normal map's normals
         @param  diffLocalMatrix the local matrix for the diffuse map (transform from
                                 texture coordinates to shape source coordinates). nullptr is
                                 interpreted as an identity matrix.
-        @param  normLocalMatrix the local matrix for the normal map (transform from
-                                texture coordinates to shape source coordinates). nullptr is
-                                interpreted as an identity matrix.
+        @param  normalSource the source for the normals
 
         nullptr will be returned if:
-            either 'diffuse' or 'normal' are empty
-            either 'diffuse' or 'normal' are too big (> 65535 on a side)
-            'diffuse' and 'normal' aren't the same size
+            'diffuse' is empty
+            'diffuse' is too big (> 65535 on any side)
 
         The lighting equation is currently:
             result = LightColor * DiffuseColor * (Normal * LightDir) + AmbientColor
 
-        The normal map is currently assumed to be an 8888 image where the normal at a texel
-        is retrieved by:
-            N.x = R-127;
-            N.y = G-127;
-            N.z = B-127;
-            N.normalize();
-        The +Z axis is thus encoded in RGB as (127, 127, 255) while the -Z axis is
-        (127, 127, 0).
     */
-    static sk_sp<SkShader> Make(const SkBitmap& diffuse, const SkBitmap& normal,
-                                sk_sp<SkLights> lights, const SkVector& invNormRotation,
-                                const SkMatrix* diffLocalMatrix, const SkMatrix* normLocalMatrix);
+    static sk_sp<SkShader> Make(const SkBitmap& diffuse, sk_sp<SkLights> lights,
+                                const SkMatrix* diffLocalMatrix,
+                                sk_sp<SkNormalSource> normalSource);
 
     SK_DECLARE_FLATTENABLE_REGISTRAR_GROUP()
 };
diff --git a/src/core/SkNormalSource.cpp b/src/core/SkNormalSource.cpp
index 2f52530..52bb4ad 100644
--- a/src/core/SkNormalSource.cpp
+++ b/src/core/SkNormalSource.cpp
@@ -8,6 +8,7 @@
 #include "SkError.h"
 #include "SkErrorInternals.h"
 #include "SkLightingShader.h"
+#include "SkMatrix.h"
 #include "SkNormalSource.h"
 #include "SkReadBuffer.h"
 #include "SkWriteBuffer.h"
@@ -19,9 +20,9 @@
 
 class NormalMapSourceImpl : public SkNormalSource {
 public:
-    NormalMapSourceImpl(sk_sp<SkShader> mapShader, const SkVector &normRotation)
+    NormalMapSourceImpl(sk_sp<SkShader> mapShader, const SkMatrix& invCTM)
         : fMapShader(std::move(mapShader))
-        , fNormRotation(normRotation) {}
+        , fInvCTM(invCTM) {}
 
 #if SK_SUPPORT_GPU
     sk_sp<GrFragmentProcessor> asFragmentProcessor(GrContext*,
@@ -58,7 +59,7 @@
     };
 
     sk_sp<SkShader> fMapShader;
-    SkVector        fNormRotation;
+    SkMatrix        fInvCTM; // Inverse of the canvas total matrix, used for rotating normals.
 
     friend class SkNormalSource;
 
@@ -78,8 +79,8 @@
 
 class NormalMapFP : public GrFragmentProcessor {
 public:
-    NormalMapFP(sk_sp<GrFragmentProcessor> mapFP, const SkVector& normRotation)
-        : fNormRotation(normRotation) {
+    NormalMapFP(sk_sp<GrFragmentProcessor> mapFP, const SkMatrix& invCTM)
+        : fInvCTM(invCTM) {
         this->registerChildProcessor(mapFP);
 
         this->initClassID<NormalMapFP>();
@@ -87,34 +88,46 @@
 
     class GLSLNormalMapFP : public GrGLSLFragmentProcessor {
     public:
-        GLSLNormalMapFP() {
-            fNormRotation.set(0.0f, 0.0f);
-        }
+        GLSLNormalMapFP()
+            : fColumnMajorInvCTM22{0.0f} {}
 
         void emitCode(EmitArgs& args) override {
-
             GrGLSLFragmentBuilder* fragBuilder = args.fFragBuilder;
             GrGLSLUniformHandler* uniformHandler = args.fUniformHandler;
 
             // add uniform
             const char* xformUniName = nullptr;
-            fXformUni = uniformHandler->addUniform(kFragment_GrShaderFlag,
-                                                   kVec2f_GrSLType, kDefault_GrSLPrecision,
-                                                   "Xform", &xformUniName);
+            fXformUni = uniformHandler->addUniform(kFragment_GrShaderFlag, kMat22f_GrSLType,
+                                                   kDefault_GrSLPrecision, "Xform", &xformUniName);
 
             SkString dstNormalColorName("dstNormalColor");
             this->emitChild(0, nullptr, &dstNormalColorName, args);
-            fragBuilder->codeAppendf("vec3 normal = %s.rgb - vec3(0.5);",
+            fragBuilder->codeAppendf("vec3 normal = normalize(%s.rgb - vec3(0.5));",
                                      dstNormalColorName.c_str());
 
             // TODO: inverse map the light direction vectors in the vertex shader rather than
             // transforming all the normals here!
-            fragBuilder->codeAppendf(
-                    "mat3 m = mat3(%s.x, -%s.y, 0.0, %s.y, %s.x, 0.0, 0.0, 0.0, 1.0);",
-                    xformUniName, xformUniName, xformUniName, xformUniName);
 
-            fragBuilder->codeAppend("normal = normalize(m*normal);");
-            fragBuilder->codeAppendf("%s = vec4(normal, 0);", args.fOutputColor);
+            // If there's no x & y components, return (0, 0, +/- 1) instead to avoid division by 0
+            fragBuilder->codeAppend( "if (abs(normal.z) > 0.9999) {");
+            fragBuilder->codeAppendf("    %s = normalize(vec4(0.0, 0.0, normal.z, 0.0));",
+                    args.fOutputColor);
+            // Else, Normalizing the transformed X and Y, while keeping constant both Z and the
+            // vector's angle in the XY plane. This maintains the "slope" for the surface while
+            // appropriately rotating the normal for any anisotropic scaling that occurs.
+            // Here, we call scaling factor the number that must divide the transformed X and Y so
+            // that the normal's length remains equal to 1.
+            fragBuilder->codeAppend( "} else {");
+            fragBuilder->codeAppendf("    vec2 transformed = %s * normal.xy;",
+                    xformUniName);
+            fragBuilder->codeAppend( "    float scalingFactorSquared = "
+                                                 "( (transformed.x * transformed.x) "
+                                                   "+ (transformed.y * transformed.y) )"
+                                                 "/(1.0 - (normal.z * normal.z));");
+            fragBuilder->codeAppendf("    %s = vec4(transformed*inversesqrt(scalingFactorSquared),"
+                                                   "normal.z, 0.0);",
+                    args.fOutputColor);
+            fragBuilder->codeAppend( "}");
         }
 
         static void GenKey(const GrProcessor& proc, const GrGLSLCaps&,
@@ -126,15 +139,16 @@
         void onSetData(const GrGLSLProgramDataManager& pdman, const GrProcessor& proc) override {
             const NormalMapFP& normalMapFP = proc.cast<NormalMapFP>();
 
-            const SkVector& normRotation = normalMapFP.normRotation();
-            if (normRotation != fNormRotation) {
-                pdman.set2fv(fXformUni, 1, &normRotation.fX);
-                fNormRotation = normRotation;
-            }
+            const SkMatrix& invCTM = normalMapFP.invCTM();
+            fColumnMajorInvCTM22[0] = invCTM.get(SkMatrix::kMScaleX);
+            fColumnMajorInvCTM22[1] = invCTM.get(SkMatrix::kMSkewY);
+            fColumnMajorInvCTM22[2] = invCTM.get(SkMatrix::kMSkewX);
+            fColumnMajorInvCTM22[3] = invCTM.get(SkMatrix::kMScaleY);
+            pdman.setMatrix2f(fXformUni, fColumnMajorInvCTM22);
         }
 
     private:
-        SkVector fNormRotation;
+        float fColumnMajorInvCTM22[4];
         GrGLSLProgramDataManager::UniformHandle fXformUni;
     };
 
@@ -148,17 +162,17 @@
         inout->setToUnknown(GrInvariantOutput::ReadInput::kWillNot_ReadInput);
     }
 
-    const SkVector& normRotation() const { return fNormRotation; }
+    const SkMatrix& invCTM() const { return fInvCTM; }
 
 private:
     GrGLSLFragmentProcessor* onCreateGLSLInstance() const override { return new GLSLNormalMapFP; }
 
     bool onIsEqual(const GrFragmentProcessor& proc) const override {
         const NormalMapFP& normalMapFP = proc.cast<NormalMapFP>();
-        return fNormRotation == normalMapFP.fNormRotation;
+        return fInvCTM == normalMapFP.fInvCTM;
     }
 
-    SkVector fNormRotation;
+    SkMatrix fInvCTM;
 };
 
 sk_sp<GrFragmentProcessor> NormalMapSourceImpl::asFragmentProcessor(
@@ -171,7 +185,7 @@
     sk_sp<GrFragmentProcessor> mapFP = fMapShader->asFragmentProcessor(context, viewM,
             localMatrix, filterQuality, gammaTreatment);
 
-    return sk_make_sp<NormalMapFP>(std::move(mapFP), fNormRotation);
+    return sk_make_sp<NormalMapFP>(std::move(mapFP), fInvCTM);
 }
 
 #endif // SK_SUPPORT_GPU
@@ -239,11 +253,28 @@
                          SkIntToScalar(SkGetPackedB32(tmpNormalColors[i])) - 127.0f);
             tempNorm.normalize();
 
-            output[i].fX = fSource.fNormRotation.fX * tempNorm.fX +
-                           fSource.fNormRotation.fY * tempNorm.fY;
-            output[i].fY = -fSource.fNormRotation.fY * tempNorm.fX +
-                           fSource.fNormRotation.fX * tempNorm.fY;
-            output[i].fZ = tempNorm.fZ;
+            if (!SkScalarNearlyEqual(SkScalarAbs(tempNorm.fZ), 1.0f)) {
+                SkVector transformed = fSource.fInvCTM.mapVector(tempNorm.fX, tempNorm.fY);
+
+                // Normalizing the transformed X and Y, while keeping constant both Z and the
+                // vector's angle in the XY plane. This maintains the "slope" for the surface while
+                // appropriately rotating the normal for any anisotropic scaling that occurs.
+                // Here, we call scaling factor the number that must divide the transformed X and Y
+                // so that the normal's length remains equal to 1.
+                SkScalar scalingFactorSquared =
+                        (SkScalarSquare(transformed.fX) + SkScalarSquare(transformed.fY))
+                        / (1.0f - SkScalarSquare(tempNorm.fZ));
+                SkScalar invScalingFactor = SkScalarInvert(SkScalarSqrt(scalingFactorSquared));
+
+                output[i].fX = transformed.fX * invScalingFactor;
+                output[i].fY = transformed.fY * invScalingFactor;
+                output[i].fZ = tempNorm.fZ;
+            } else {
+                output[i] = {0.0f, 0.0f, tempNorm.fZ};
+                output[i].normalize();
+            }
+
+            SkASSERT(SkScalarNearlyEqual(output[i].length(), 1.0f))
         }
 
         output += n;
@@ -258,31 +289,29 @@
 
     sk_sp<SkShader> mapShader = buf.readFlattenable<SkShader>();
 
-    SkVector normRotation = {1,0};
-    if (!buf.isVersionLT(SkReadBuffer::kLightingShaderWritesInvNormRotation)) {
-        normRotation = buf.readPoint();
-    }
+    SkMatrix invCTM;
+    buf.readMatrix(&invCTM);
 
-    return sk_make_sp<NormalMapSourceImpl>(std::move(mapShader), normRotation);
+    return sk_make_sp<NormalMapSourceImpl>(std::move(mapShader), invCTM);
 }
 
 void NormalMapSourceImpl::flatten(SkWriteBuffer& buf) const {
     this->INHERITED::flatten(buf);
 
     buf.writeFlattenable(fMapShader.get());
-    buf.writePoint(fNormRotation);
+    buf.writeMatrix(fInvCTM);
 }
 
 ////////////////////////////////////////////////////////////////////////////
 
-sk_sp<SkNormalSource> SkNormalSource::MakeFromNormalMap(sk_sp<SkShader> map,
-                                                        const SkVector &normRotation) {
-    SkASSERT(SkScalarNearlyEqual(normRotation.lengthSqd(), SK_Scalar1));
-    if (!map) {
+sk_sp<SkNormalSource> SkNormalSource::MakeFromNormalMap(sk_sp<SkShader> map, const SkMatrix& ctm) {
+    SkMatrix invCTM;
+
+    if (!ctm.invert(&invCTM) || !map) {
         return nullptr;
     }
 
-    return sk_make_sp<NormalMapSourceImpl>(std::move(map), normRotation);
+    return sk_make_sp<NormalMapSourceImpl>(std::move(map), invCTM);
 }
 
 ////////////////////////////////////////////////////////////////////////////
diff --git a/src/core/SkNormalSource.h b/src/core/SkNormalSource.h
index 0d0c672..7517770 100644
--- a/src/core/SkNormalSource.h
+++ b/src/core/SkNormalSource.h
@@ -10,8 +10,7 @@
 
 #include "SkFlattenable.h"
 
-/** Abstract class that generates or reads in normals for use by SkLightingShader. Not to be
-    used as part of the API yet. Used internally by SkLightingShader.
+/** Abstract class that generates or reads in normals for use by SkLightingShader.
 */
 class SK_API SkNormalSource : public SkFlattenable {
 public:
@@ -49,11 +48,9 @@
     virtual size_t providerSize(const SkShader::ContextRec&) const = 0;
 
     /** Returns a normal source that provides normals sourced from the the normal map argument.
-        Not to be used as part of the API yet. Used internally by SkLightingShader.
 
-        @param  map              a shader that outputs the normal map
-        @param  normRotation     rotation applied to the normal map's normals, in the
-                                 [cos a, sin a] form.
+        @param  map  a shader that outputs the normal map
+        @param  ctm  the current canvas' total matrix, used to rotate normals when necessary.
 
         nullptr will be returned if 'map' is null
 
@@ -66,8 +63,7 @@
         The +Z axis is thus encoded in RGB as (127, 127, 255) while the -Z axis is
         (127, 127, 0).
     */
-    static sk_sp<SkNormalSource> MakeFromNormalMap(sk_sp<SkShader> map,
-                                                   const SkVector& normRotation);
+    static sk_sp<SkNormalSource> MakeFromNormalMap(sk_sp<SkShader> map, const SkMatrix& ctm);
 
     SK_DEFINE_FLATTENABLE_TYPE(SkNormalSource)
     SK_DECLARE_FLATTENABLE_REGISTRAR_GROUP()
diff --git a/tests/SerializationTest.cpp b/tests/SerializationTest.cpp
index e4e1a7a..1e96eef 100644
--- a/tests/SerializationTest.cpp
+++ b/tests/SerializationTest.cpp
@@ -7,6 +7,7 @@
 
 #include "Resources.h"
 #include "SkAnnotationKeys.h"
+#include "SkBitmapProcShader.h"
 #include "SkCanvas.h"
 #include "SkFixed.h"
 #include "SkFontDescriptor.h"
@@ -14,6 +15,7 @@
 #include "SkImageSource.h"
 #include "SkLightingShader.h"
 #include "SkMallocPixelRef.h"
+#include "SkNormalSource.h"
 #include "SkOSFile.h"
 #include "SkPictureRecorder.h"
 #include "SkTableColorFilter.h"
@@ -573,13 +575,18 @@
         SkRect r = SkRect::MakeWH(SkIntToScalar(kTexSize), SkIntToScalar(kTexSize));
         matrix.setRectToRect(bitmapBounds, r, SkMatrix::kFill_ScaleToFit);
 
-        SkVector invNormRotation = { SkScalarSqrt(0.3f), SkScalarSqrt(0.7f) };
+        SkMatrix ctm;
+        ctm.setRotate(45);
         SkBitmap normals;
         normals.allocN32Pixels(kTexSize, kTexSize);
 
         sk_tool_utils::create_frustum_normal_map(&normals, SkIRect::MakeWH(kTexSize, kTexSize));
-        sk_sp<SkShader> lightingShader = SkLightingShader::Make(diffuse, normals, fLights,
-                invNormRotation, &matrix, &matrix);
+        sk_sp<SkShader> normalMap = SkMakeBitmapShader(normals, SkShader::kClamp_TileMode,
+                                                       SkShader::kClamp_TileMode, &matrix, nullptr);
+        sk_sp<SkNormalSource> normalSource = SkNormalSource::MakeFromNormalMap(std::move(normalMap),
+                                                                               ctm);
+        sk_sp<SkShader> lightingShader = SkLightingShader::Make(diffuse, fLights, &matrix,
+                                                                std::move(normalSource));
 
         SkAutoTUnref<SkShader>(TestFlattenableSerialization(lightingShader.get(), true, reporter));
         // TODO test equality?