SkLS now accepting nullptr for diffuse shader and normal source, now accurately handling alpha

This CL's base is the CL for taking in a diffuse shader into SkLS on the API side: https://codereview.chromium.org/2064153002

BUG=skia:5502,skia:5517
GOLD_TRYBOT_URL= https://gold.skia.org/search?issue=2132113002

Review-Url: https://codereview.chromium.org/2132113002
diff --git a/gm/lightingshader2.cpp b/gm/lightingshader2.cpp
new file mode 100644
index 0000000..c2f4d6c
--- /dev/null
+++ b/gm/lightingshader2.cpp
@@ -0,0 +1,232 @@
+/*
+ * Copyright 2016 Google Inc.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "gm.h"
+
+#include "SkBitmapProcShader.h"
+#include "SkLightingShader.h"
+#include "SkNormalSource.h"
+#include "SkPoint3.h"
+#include "SkShader.h"
+
+// Create a truncated pyramid normal map
+static SkBitmap make_frustum_normalmap(int texSize) {
+    SkBitmap frustum;
+    frustum.allocN32Pixels(texSize, texSize);
+
+    sk_tool_utils::create_frustum_normal_map(&frustum, SkIRect::MakeWH(texSize, texSize));
+    return frustum;
+}
+
+namespace skiagm {
+
+// This GM exercises lighting shaders. Specifically, nullptr arguments, scaling when using
+// normal maps, and paint transparency.
+class LightingShader2GM : public GM {
+public:
+    LightingShader2GM() {
+        this->setBGColor(sk_tool_utils::color_to_565(0xFFCCCCCC));
+    }
+
+protected:
+    SkString onShortName() override {
+        return SkString("lightingshader2");
+    }
+
+    SkISize onISize() override {
+        return SkISize::Make(kGMSize, kGMSize);
+    }
+
+    void onOnceBeforeDraw() override {
+        SkLights::Builder builder;
+        const SkVector3 kLightFromUpperRight = SkVector3::Make(0.788f, 0.394f, 0.473f);
+
+        builder.add(SkLights::Light(SkColor3f::Make(1.0f, 1.0f, 1.0f),
+                                    kLightFromUpperRight));
+        builder.add(SkLights::Light(SkColor3f::Make(0.2f, 0.2f, 0.2f)));
+        fLights = builder.finish();
+
+        fRect = SkRect::MakeIWH(kTexSize, kTexSize);
+        SkMatrix matrix;
+        SkRect bitmapBounds = SkRect::MakeIWH(kTexSize, kTexSize);
+        matrix.setRectToRect(bitmapBounds, fRect, SkMatrix::kFill_ScaleToFit);
+
+        SkBitmap opaqueDiffuseMap = sk_tool_utils::create_checkerboard_bitmap(
+                kTexSize, kTexSize,
+                sk_tool_utils::color_to_565(0x0),
+                sk_tool_utils::color_to_565(0xFF804020),
+                8);
+        fOpaqueDiffuse = SkMakeBitmapShader(opaqueDiffuseMap, SkShader::kClamp_TileMode,
+                                            SkShader::kClamp_TileMode, &matrix, nullptr);
+
+        SkBitmap translucentDiffuseMap = sk_tool_utils::create_checkerboard_bitmap(
+                kTexSize, kTexSize,
+                SkColorSetARGB(0x55, 0x00, 0x00, 0x00),
+                SkColorSetARGB(0x55, 0x80, 0x40, 0x20),
+                8);
+        fTranslucentDiffuse = SkMakeBitmapShader(translucentDiffuseMap, SkShader::kClamp_TileMode,
+                                                 SkShader::kClamp_TileMode, &matrix, nullptr);
+
+        SkBitmap normalMap = make_frustum_normalmap(kTexSize);
+        fNormalMapShader = SkMakeBitmapShader(normalMap, SkShader::kClamp_TileMode,
+                                              SkShader::kClamp_TileMode, &matrix, nullptr);
+
+    }
+
+    // Scales shape around origin, rotates shape around origin, then translates shape to origin
+    void positionCTM(SkCanvas *canvas, SkScalar scaleX, SkScalar scaleY, SkScalar rotate) const {
+        canvas->translate(kTexSize/2.0f, kTexSize/2.0f);
+        canvas->scale(scaleX, scaleY);
+        canvas->rotate(rotate);
+        canvas->translate(-kTexSize/2.0f, -kTexSize/2.0f);
+    }
+
+    static constexpr int NUM_BOOLEAN_PARAMS = 4;
+    void drawRect(SkCanvas* canvas, SkScalar scaleX, SkScalar scaleY,
+                  SkScalar rotate, bool useNormalSource, bool useDiffuseShader,
+                  bool useTranslucentPaint, bool useTranslucentShader) {
+        canvas->save();
+
+        this->positionCTM(canvas, scaleX, scaleY, rotate);
+
+        const SkMatrix& ctm = canvas->getTotalMatrix();
+
+        SkPaint paint;
+        sk_sp<SkNormalSource> normalSource = nullptr;
+        sk_sp<SkShader> diffuseShader = nullptr;
+
+        if (useNormalSource) {
+            normalSource = SkNormalSource::MakeFromNormalMap(fNormalMapShader, ctm);
+        }
+
+        if (useDiffuseShader) {
+            diffuseShader = (useTranslucentShader) ? fTranslucentDiffuse : fOpaqueDiffuse;
+        } else {
+            paint.setColor(0xFF00FF00);
+        }
+
+        if (useTranslucentPaint) {
+            paint.setAlpha(0x99);
+        }
+
+        paint.setShader(SkLightingShader::Make(std::move(diffuseShader), std::move(normalSource),
+                                               fLights));
+        canvas->drawRect(fRect, paint);
+
+        canvas->restore();
+    }
+
+    void onDraw(SkCanvas* canvas) override {
+
+        constexpr SkScalar LABEL_SIZE = 10.0f;
+        SkPaint labelPaint;
+        labelPaint.setTypeface(sk_tool_utils::create_portable_typeface("sans-serif",
+                                                                       SkFontStyle()));
+        labelPaint.setAntiAlias(true);
+        labelPaint.setTextSize(LABEL_SIZE);
+
+        constexpr int GRID_COLUMN_NUM = 4;
+        constexpr SkScalar GRID_CELL_WIDTH = kTexSize + 20.0f + NUM_BOOLEAN_PARAMS * LABEL_SIZE;
+
+        int gridNum = 0;
+
+        // Running through all possible bool parameter combinations
+        for (bool useNormalSource : {true, false}) {
+            for (bool useDiffuseShader : {true, false}) {
+                for (bool useTranslucentPaint : {true, false}) {
+                    for (bool useTranslucentShader : {true, false}) {
+
+                        // Determining position
+                        SkScalar xPos = (gridNum % GRID_COLUMN_NUM) * GRID_CELL_WIDTH;
+                        SkScalar yPos = (gridNum / GRID_COLUMN_NUM) * GRID_CELL_WIDTH;
+
+                        canvas->save();
+
+                        canvas->translate(xPos, yPos);
+                        this->drawRect(canvas, 1.0f, 1.0f, 0.f, useNormalSource, useDiffuseShader,
+                                       useTranslucentPaint, useTranslucentShader);
+                        // Drawing labels
+                        canvas->translate(0.0f, SkIntToScalar(kTexSize));
+                        {
+                            canvas->translate(0.0f, LABEL_SIZE);
+                            SkString label;
+                            label.appendf("useNormalSource: %d", useNormalSource);
+                            canvas->drawText(label.c_str(), label.size(), 0.0f, 0.0f, labelPaint);
+                        }
+                        {
+                            canvas->translate(0.0f, LABEL_SIZE);
+                            SkString label;
+                            label.appendf("useDiffuseShader: %d", useDiffuseShader);
+                            canvas->drawText(label.c_str(), label.size(), 0.0f, 0.0f, labelPaint);
+                        }
+                        {
+                            canvas->translate(0.0f, LABEL_SIZE);
+                            SkString label;
+                            label.appendf("useTranslucentPaint: %d", useTranslucentPaint);
+                            canvas->drawText(label.c_str(), label.size(), 0.0f, 0.0f, labelPaint);
+                        }
+                        {
+                            canvas->translate(0.0f, LABEL_SIZE);
+                            SkString label;
+                            label.appendf("useTranslucentShader: %d", useTranslucentShader);
+                            canvas->drawText(label.c_str(), label.size(), 0.0f, 0.0f, labelPaint);
+                        }
+
+                        canvas->restore();
+
+                        gridNum++;
+                    }
+                }
+            }
+        }
+
+
+        // Rotation/scale test
+        {
+            SkScalar xPos = (gridNum % GRID_COLUMN_NUM) * GRID_CELL_WIDTH;
+            SkScalar yPos = (gridNum / GRID_COLUMN_NUM) * GRID_CELL_WIDTH;
+
+            canvas->save();
+            canvas->translate(xPos, yPos);
+            this->drawRect(canvas, 0.6f, 0.6f, 45.0f, true, true, true, true);
+            canvas->restore();
+
+            gridNum++;
+        }
+
+        // Anisotropic scale test
+        {
+            SkScalar xPos = (gridNum % GRID_COLUMN_NUM) * GRID_CELL_WIDTH;
+            SkScalar yPos = (gridNum / GRID_COLUMN_NUM) * GRID_CELL_WIDTH;
+
+            canvas->save();
+            canvas->translate(xPos, yPos);
+            this->drawRect(canvas, 0.6f, 0.4f, 30.0f, true, true, true, true);
+            canvas->restore();
+
+            gridNum++;
+        }
+    }
+
+private:
+    static const int kTexSize = 96;
+    static const int kGMSize  = 512;
+
+    sk_sp<SkShader> fOpaqueDiffuse;
+    sk_sp<SkShader> fTranslucentDiffuse;
+    sk_sp<SkShader> fNormalMapShader;
+
+    SkRect fRect;
+    sk_sp<SkLights> fLights;
+
+    typedef GM INHERITED;
+};
+
+//////////////////////////////////////////////////////////////////////////////
+
+DEF_GM(return new LightingShader2GM;)
+}
diff --git a/include/gpu/GrFragmentProcessor.h b/include/gpu/GrFragmentProcessor.h
index b8ebeca..4b56fbf 100644
--- a/include/gpu/GrFragmentProcessor.h
+++ b/include/gpu/GrFragmentProcessor.h
@@ -48,6 +48,12 @@
     static sk_sp<GrFragmentProcessor> OverrideInput(sk_sp<GrFragmentProcessor>, GrColor);
 
     /**
+     *  Returns a fragment processor that premuls the input before calling the passed in fragment
+     *  processor.
+     */
+    static sk_sp<GrFragmentProcessor> PremulInput(sk_sp<GrFragmentProcessor>);
+
+    /**
      * Returns a fragment processor that runs the passed in array of fragment processors in a
      * series. The original input is passed to the first, the first's output is passed to the
      * second, etc. The output of the returned processor is the output of the last processor of the
diff --git a/src/core/SkLightingShader.cpp b/src/core/SkLightingShader.cpp
index 02f14b3..92f41ad 100644
--- a/src/core/SkLightingShader.cpp
+++ b/src/core/SkLightingShader.cpp
@@ -21,15 +21,11 @@
 
 /*
    SkLightingShader TODOs:
-        support other than clamp mode
-        allow 'diffuse' & 'normal' to be of different dimensions?
         support different light types
         support multiple lights
-        enforce normal map is 4 channel
-        use SkImages instead if SkBitmaps
+        fix non-opaque diffuse textures
 
     To Test:
-        non-opaque diffuse textures
         A8 diffuse textures
         down & upsampled draws
 */
@@ -81,6 +77,7 @@
     private:
         SkShader::Context*        fDiffuseContext;
         SkNormalSource::Provider* fNormalProvider;
+        SkColor                   fPaintColor;
         uint32_t                  fFlags;
 
         void* fHeapAllocated;
@@ -121,6 +118,9 @@
 #include "SkGr.h"
 #include "SkGrPriv.h"
 
+// This FP expects a premul'd color input for its diffuse color. Premul'ing of the paint's color is
+// handled by the asFragmentProcessor() factory, but shaders providing diffuse color must output it
+// premul'd.
 class LightingFP : public GrFragmentProcessor {
 public:
     LightingFP(sk_sp<GrFragmentProcessor> normalFP, sk_sp<SkLights> lights) {
@@ -180,9 +180,13 @@
                                      lightDirUniName);
             // diffuse light
             fragBuilder->codeAppendf("vec3 result = %s*diffuseColor.rgb*NdotL;", lightColorUniName);
-            // ambient light
-            fragBuilder->codeAppendf("result += %s;", ambientColorUniName);
-            fragBuilder->codeAppendf("%s = vec4(result.rgb, diffuseColor.a);", args.fOutputColor);
+            // ambient light (multiplied by input color's alpha because we're working in premul'd
+            // space)
+            fragBuilder->codeAppendf("result += diffuseColor.a * %s;", ambientColorUniName);
+
+            // Clamping to alpha (equivalent to an unpremul'd clamp to 1.0)
+            fragBuilder->codeAppendf("%s = vec4(clamp(result.rgb, 0.0, diffuseColor.a), "
+                                               "diffuseColor.a);", args.fOutputColor);
         }
 
         static void GenKey(const GrProcessor& proc, const GrGLSLCaps&,
@@ -270,18 +274,25 @@
         return nullptr;
     }
 
-    sk_sp<GrFragmentProcessor> fpPipeline[] = {
+    if (fDiffuseShader) {
+        sk_sp<GrFragmentProcessor> fpPipeline[] = {
             fDiffuseShader->asFragmentProcessor(context, viewM, localMatrix, filterQuality,
                                                 gammaTreatment),
             sk_make_sp<LightingFP>(std::move(normalFP), fLights)
-    };
-    if(!fpPipeline[0]) {
-        return nullptr;
+        };
+        if(!fpPipeline[0]) {
+            return nullptr;
+        }
+
+        sk_sp<GrFragmentProcessor> innerLightFP = GrFragmentProcessor::RunInSeries(fpPipeline, 2);
+        // FP is wrapped because paint's alpha needs to be applied to output
+        return GrFragmentProcessor::MulOutputByInputAlpha(std::move(innerLightFP));
+    } else {
+        // FP is wrapped because paint comes in unpremul'd to fragment shader, but LightingFP
+        // expects premul'd color.
+        return GrFragmentProcessor::PremulInput(sk_make_sp<LightingFP>(std::move(normalFP),
+                                                                       fLights));
     }
-
-    sk_sp<GrFragmentProcessor> inner(GrFragmentProcessor::RunInSeries(fpPipeline, 2));
-
-    return GrFragmentProcessor::MulOutputByInputAlpha(std::move(inner));
 }
 
 #endif
@@ -289,7 +300,7 @@
 ////////////////////////////////////////////////////////////////////////////
 
 bool SkLightingShaderImpl::isOpaque() const {
-    return fDiffuseShader->isOpaque();
+    return (fDiffuseShader ? fDiffuseShader->isOpaque() : false);
 }
 
 SkLightingShaderImpl::LightingShaderContext::LightingShaderContext(
@@ -308,13 +319,16 @@
         flags |= kOpaqueAlpha_Flag;
     }
 
+    fPaintColor = rec.fPaint->getColor();
     fFlags = flags;
 }
 
 SkLightingShaderImpl::LightingShaderContext::~LightingShaderContext() {
     // The dependencies have been created outside of the context on memory that was allocated by
     // the onCreateContext() method. Call the destructors and free the memory.
-    fDiffuseContext->~Context();
+    if (fDiffuseContext) {
+        fDiffuseContext->~Context();
+    }
     fNormalProvider->~Provider();
 
     sk_free(fHeapAllocated);
@@ -352,15 +366,21 @@
     SkPMColor diffuse[BUFFER_MAX];
     SkPoint3 normals[BUFFER_MAX];
 
+    SkColor diffColor = fPaintColor;
+
     do {
         int n = SkTMin(count, BUFFER_MAX);
 
-        fDiffuseContext->shadeSpan(x, y, diffuse, n);
         fNormalProvider->fillScanLine(x, y, normals, n);
 
-        for (int i = 0; i < n; ++i) {
+        if (fDiffuseContext) {
+            fDiffuseContext->shadeSpan(x, y, diffuse, n);
+        }
 
-            SkColor diffColor = SkUnPreMultiply::PMColorToColor(diffuse[i]);
+        for (int i = 0; i < n; ++i) {
+            if (fDiffuseContext) {
+                diffColor = SkUnPreMultiply::PMColorToColor(diffuse[i]);
+            }
 
             SkColor3f accum = SkColor3f::Make(0.0f, 0.0f, 0.0f);
             // This is all done in linear unpremul color space (each component 0..255.0f though)
@@ -381,6 +401,7 @@
                 }
             }
 
+            // convert() premultiplies the accumulate color with alpha
             result[i] = convert(accum, SkColorGetA(diffColor));
         }
 
@@ -430,7 +451,12 @@
     sk_sp<SkLights> lights(builder.finish());
 
     sk_sp<SkNormalSource> normalSource(buf.readFlattenable<SkNormalSource>());
-    sk_sp<SkShader> diffuseShader(buf.readFlattenable<SkShader>());
+
+    bool hasDiffuse = buf.readBool();
+    sk_sp<SkShader> diffuseShader = nullptr;
+    if (hasDiffuse) {
+        diffuseShader = buf.readFlattenable<SkShader>();
+    }
 
     return sk_make_sp<SkLightingShaderImpl>(std::move(diffuseShader), std::move(normalSource),
                                             std::move(lights));
@@ -453,7 +479,10 @@
     }
 
     buf.writeFlattenable(fNormalSource.get());
-    buf.writeFlattenable(fDiffuseShader.get());
+    buf.writeBool(fDiffuseShader);
+    if (fDiffuseShader) {
+        buf.writeFlattenable(fDiffuseShader.get());
+    }
 }
 
 size_t SkLightingShaderImpl::onContextSize(const ContextRec& rec) const {
@@ -462,18 +491,23 @@
 
 SkShader::Context* SkLightingShaderImpl::onCreateContext(const ContextRec& rec,
                                                          void* storage) const {
-    size_t heapRequired = fDiffuseShader->contextSize(rec) +
+    size_t heapRequired = (fDiffuseShader ? fDiffuseShader->contextSize(rec) : 0) +
                           fNormalSource->providerSize(rec);
     void* heapAllocated = sk_malloc_throw(heapRequired);
 
     void* diffuseContextStorage = heapAllocated;
-    SkShader::Context* diffuseContext = fDiffuseShader->createContext(rec, diffuseContextStorage);
-    if (!diffuseContext) {
-        sk_free(heapAllocated);
-        return nullptr;
+    void* normalProviderStorage = (char*) diffuseContextStorage +
+                                  (fDiffuseShader ? fDiffuseShader->contextSize(rec) : 0);
+
+    SkShader::Context *diffuseContext = nullptr;
+    if (fDiffuseShader) {
+        diffuseContext = fDiffuseShader->createContext(rec, diffuseContextStorage);
+        if (!diffuseContext) {
+            sk_free(heapAllocated);
+            return nullptr;
+        }
     }
 
-    void* normalProviderStorage = (char*)heapAllocated + fDiffuseShader->contextSize(rec);
     SkNormalSource::Provider* normalProvider = fNormalSource->asProvider(rec,
                                                                          normalProviderStorage);
     if (!normalProvider) {
@@ -491,10 +525,8 @@
 sk_sp<SkShader> SkLightingShader::Make(sk_sp<SkShader> diffuseShader,
                                        sk_sp<SkNormalSource> normalSource,
                                        sk_sp<SkLights> lights) {
-    if (!diffuseShader || !normalSource) {
-        // TODO: Use paint's color in absence of a diffuseShader
-        // TODO: Use a default implementation of normalSource instead
-        return nullptr;
+    if (!normalSource) {
+        normalSource = SkNormalSource::MakeFlat();
     }
 
     return sk_make_sp<SkLightingShaderImpl>(std::move(diffuseShader), std::move(normalSource),
diff --git a/src/core/SkLightingShader.h b/src/core/SkLightingShader.h
index bb64261..41e3ca2 100644
--- a/src/core/SkLightingShader.h
+++ b/src/core/SkLightingShader.h
@@ -23,8 +23,10 @@
         It returns a shader with a reference count of 1.
         The caller should decrement the shader's reference count when done with the shader.
         It is an error for count to be < 2.
-        @param  diffuseShader     the shader that provides the colors
-        @param  normalSource      the source for the shape's normals
+        @param  diffuseShader     the shader that provides the colors. If nullptr, uses the paint's
+                                  color.
+        @param  normalSource      the source for the shape's normals. If nullptr, assumes straight
+                                  up normals (<0,0,1>).
         @param  lights            the lights applied to the normals
 
         The lighting equation is currently:
diff --git a/src/core/SkNormalSource.cpp b/src/core/SkNormalSource.cpp
index ce72532..c082d84 100644
--- a/src/core/SkNormalSource.cpp
+++ b/src/core/SkNormalSource.cpp
@@ -10,6 +10,7 @@
 #include "SkLightingShader.h"
 #include "SkMatrix.h"
 #include "SkNormalSource.h"
+#include "SkPM4f.h"
 #include "SkReadBuffer.h"
 #include "SkWriteBuffer.h"
 
@@ -46,15 +47,19 @@
 private:
     class Provider : public SkNormalSource::Provider {
     public:
-        Provider(const NormalMapSourceImpl& source, SkShader::Context* fMapContext);
+        Provider(const NormalMapSourceImpl& source, SkShader::Context* mapContext,
+                 SkPaint* overridePaint);
 
         virtual ~Provider() override;
 
         void fillScanLine(int x, int y, SkPoint3 output[], int count) const override;
+
     private:
         const NormalMapSourceImpl& fSource;
         SkShader::Context* fMapContext;
 
+        SkPaint* fOverridePaint;
+
         typedef SkNormalSource::Provider INHERITED;
     };
 
@@ -105,17 +110,14 @@
             fragBuilder->codeAppendf("vec3 normal = normalize(%s.rgb - vec3(0.5));",
                                      dstNormalColorName.c_str());
 
-            // TODO: inverse map the light direction vectors in the vertex shader rather than
-            // transforming all the normals here!
-
             // If there's no x & y components, return (0, 0, +/- 1) instead to avoid division by 0
             fragBuilder->codeAppend( "if (abs(normal.z) > 0.999) {");
             fragBuilder->codeAppendf("    %s = normalize(vec4(0.0, 0.0, normal.z, 0.0));",
                     args.fOutputColor);
             // Else, Normalizing the transformed X and Y, while keeping constant both Z and the
             // vector's angle in the XY plane. This maintains the "slope" for the surface while
-            // appropriately rotating the normal for any anisotropic scaling that occurs.
-            // Here, we call scaling factor the number that must divide the transformed X and Y so
+            // appropriately rotating the normal regardless of any anisotropic scaling that occurs.
+            // Here, we call 'scaling factor' the number that must divide the transformed X and Y so
             // that the normal's length remains equal to 1.
             fragBuilder->codeAppend( "} else {");
             fragBuilder->codeAppendf("    vec2 transformed = %s * normal.xy;",
@@ -195,13 +197,15 @@
 ////////////////////////////////////////////////////////////////////////////
 
 NormalMapSourceImpl::Provider::Provider(const NormalMapSourceImpl& source,
-                                        SkShader::Context* mapContext)
+                                        SkShader::Context* mapContext,
+                                        SkPaint* overridePaint)
     : fSource(source)
-    , fMapContext(mapContext) {
-}
+    , fMapContext(mapContext)
+    , fOverridePaint(overridePaint) {}
 
 NormalMapSourceImpl::Provider::~Provider() {
     fMapContext->~Context();
+    fOverridePaint->~SkPaint();
 }
 
 SkNormalSource::Provider* NormalMapSourceImpl::asProvider(
@@ -211,17 +215,24 @@
         return nullptr;
     }
 
-    void* mapContextStorage = (char*)storage + sizeof(Provider);
-    SkShader::Context* context = fMapShader->createContext(rec, mapContextStorage);
+    // Overriding paint's alpha because we need the normal map's RGB channels to be unpremul'd
+    void* paintStorage = (char*)storage + sizeof(Provider);
+    SkPaint* overridePaint = new (paintStorage) SkPaint(*(rec.fPaint));
+    overridePaint->setAlpha(0xFF);
+    SkShader::ContextRec overrideRec(*overridePaint, *(rec.fMatrix), rec.fLocalMatrix,
+                                     rec.fPreferredDstType);
+
+    void* mapContextStorage = (char*) paintStorage + sizeof(SkPaint);
+    SkShader::Context* context = fMapShader->createContext(overrideRec, mapContextStorage);
     if (!context) {
         return nullptr;
     }
 
-    return new (storage) Provider(*this, context);
+    return new (storage) Provider(*this, context, overridePaint);
 }
 
 size_t NormalMapSourceImpl::providerSize(const SkShader::ContextRec& rec) const {
-    return sizeof(Provider) + fMapShader->contextSize(rec);
+    return sizeof(Provider) + sizeof(SkPaint) + fMapShader->contextSize(rec);
 }
 
 bool NormalMapSourceImpl::computeNormTotalInverse(const SkShader::ContextRec& rec,
@@ -253,8 +264,10 @@
             tempNorm.set(SkIntToScalar(SkGetPackedR32(tmpNormalColors[i])) - 127.0f,
                          SkIntToScalar(SkGetPackedG32(tmpNormalColors[i])) - 127.0f,
                          SkIntToScalar(SkGetPackedB32(tmpNormalColors[i])) - 127.0f);
+
             tempNorm.normalize();
 
+
             if (!SkScalarNearlyEqual(SkScalarAbs(tempNorm.fZ), 1.0f)) {
                 SkVector transformed = fSource.fInvCTM.mapVector(tempNorm.fX, tempNorm.fY);
 
@@ -316,10 +329,151 @@
     return sk_make_sp<NormalMapSourceImpl>(std::move(map), invCTM);
 }
 
+///////////////////////////////////////////////////////////////////////////////
+
+class SK_API NormalFlatSourceImpl : public SkNormalSource {
+public:
+    NormalFlatSourceImpl(){}
+
+#if SK_SUPPORT_GPU
+    sk_sp<GrFragmentProcessor> asFragmentProcessor(GrContext*,
+                                                   const SkMatrix& viewM,
+                                                   const SkMatrix* localMatrix,
+                                                   SkFilterQuality,
+                                                   SkSourceGammaTreatment) const override;
+#endif
+
+    SkNormalSource::Provider* asProvider(const SkShader::ContextRec& rec,
+                                         void* storage) const override;
+    size_t providerSize(const SkShader::ContextRec& rec) const override;
+
+    SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(NormalFlatSourceImpl)
+
+protected:
+    void flatten(SkWriteBuffer& buf) const override;
+
+private:
+    class Provider : public SkNormalSource::Provider {
+    public:
+        Provider();
+
+        virtual ~Provider();
+
+        void fillScanLine(int x, int y, SkPoint3 output[], int count) const override;
+
+    private:
+        typedef SkNormalSource::Provider INHERITED;
+    };
+
+    friend class SkNormalSource;
+
+    typedef SkNormalSource INHERITED;
+};
+
+////////////////////////////////////////////////////////////////////////////
+
+#if SK_SUPPORT_GPU
+
+class NormalFlatFP : public GrFragmentProcessor {
+public:
+    NormalFlatFP() {
+        this->initClassID<NormalFlatFP>();
+    }
+
+    class GLSLNormalFlatFP : public GrGLSLFragmentProcessor {
+    public:
+        GLSLNormalFlatFP() {}
+
+        void emitCode(EmitArgs& args) override {
+            GrGLSLFragmentBuilder* fragBuilder = args.fFragBuilder;
+
+            fragBuilder->codeAppendf("%s = vec4(0, 0, 1, 0);", args.fOutputColor);
+        }
+
+        static void GenKey(const GrProcessor& proc, const GrGLSLCaps&,
+                           GrProcessorKeyBuilder* b) {
+            b->add32(0x0);
+        }
+
+    protected:
+        void onSetData(const GrGLSLProgramDataManager& pdman, const GrProcessor& proc) override {}
+    };
+
+    void onGetGLSLProcessorKey(const GrGLSLCaps& caps, GrProcessorKeyBuilder* b) const override {
+        GLSLNormalFlatFP::GenKey(*this, caps, b);
+    }
+
+    const char* name() const override { return "NormalFlatFP"; }
+
+    void onComputeInvariantOutput(GrInvariantOutput* inout) const override {
+        inout->setToUnknown(GrInvariantOutput::ReadInput::kWillNot_ReadInput);
+    }
+
+private:
+    GrGLSLFragmentProcessor* onCreateGLSLInstance() const override { return new GLSLNormalFlatFP; }
+
+    bool onIsEqual(const GrFragmentProcessor& proc) const override {
+        return true;
+    }
+};
+
+sk_sp<GrFragmentProcessor> NormalFlatSourceImpl::asFragmentProcessor(
+                                                     GrContext *context,
+                                                     const SkMatrix &viewM,
+                                                     const SkMatrix *localMatrix,
+                                                     SkFilterQuality filterQuality,
+                                                     SkSourceGammaTreatment gammaTreatment) const {
+
+    return sk_make_sp<NormalFlatFP>();
+}
+
+#endif // SK_SUPPORT_GPU
+
+////////////////////////////////////////////////////////////////////////////
+
+NormalFlatSourceImpl::Provider::Provider() {}
+
+NormalFlatSourceImpl::Provider::~Provider() {}
+
+SkNormalSource::Provider* NormalFlatSourceImpl::asProvider(const SkShader::ContextRec &rec,
+                                                           void *storage) const {
+    return new (storage) Provider();
+}
+
+size_t NormalFlatSourceImpl::providerSize(const SkShader::ContextRec&) const {
+    return sizeof(Provider);
+}
+
+void NormalFlatSourceImpl::Provider::fillScanLine(int x, int y, SkPoint3 output[],
+                                                  int count) const {
+    for (int i = 0; i < count; i++) {
+        output[i] = {0.0f, 0.0f, 1.0f};
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+sk_sp<SkFlattenable> NormalFlatSourceImpl::CreateProc(SkReadBuffer& buf) {
+    return sk_make_sp<NormalFlatSourceImpl>();
+}
+
+void NormalFlatSourceImpl::flatten(SkWriteBuffer& buf) const {
+    this->INHERITED::flatten(buf);
+}
+
+////////////////////////////////////////////////////////////////////////////
+
+sk_sp<SkNormalSource> SkNormalSource::MakeFlat() {
+    return sk_make_sp<NormalFlatSourceImpl>();
+}
+
+////////////////////////////////////////////////////////////////////////////
+
 ////////////////////////////////////////////////////////////////////////////
 
 SK_DEFINE_FLATTENABLE_REGISTRAR_GROUP_START(SkNormalSource)
     SK_DEFINE_FLATTENABLE_REGISTRAR_ENTRY(NormalMapSourceImpl)
+    SK_DEFINE_FLATTENABLE_REGISTRAR_ENTRY(NormalFlatSourceImpl)
 SK_DEFINE_FLATTENABLE_REGISTRAR_GROUP_END
 
 ////////////////////////////////////////////////////////////////////////////
diff --git a/src/core/SkNormalSource.h b/src/core/SkNormalSource.h
index 7517770..e46e2da 100644
--- a/src/core/SkNormalSource.h
+++ b/src/core/SkNormalSource.h
@@ -65,6 +65,10 @@
     */
     static sk_sp<SkNormalSource> MakeFromNormalMap(sk_sp<SkShader> map, const SkMatrix& ctm);
 
+    /** Returns a normal source that provides straight-up normals only (0, 0, 1).
+    */
+    static sk_sp<SkNormalSource> MakeFlat();
+
     SK_DEFINE_FLATTENABLE_TYPE(SkNormalSource)
     SK_DECLARE_FLATTENABLE_REGISTRAR_GROUP()
 };
diff --git a/src/gpu/GrFragmentProcessor.cpp b/src/gpu/GrFragmentProcessor.cpp
index bad7ebe..fdba610 100644
--- a/src/gpu/GrFragmentProcessor.cpp
+++ b/src/gpu/GrFragmentProcessor.cpp
@@ -143,6 +143,46 @@
                                                              SkXfermode::kDstIn_Mode);
 }
 
+sk_sp<GrFragmentProcessor> GrFragmentProcessor::PremulInput(sk_sp<GrFragmentProcessor> fp) {
+
+    class PremulInputFragmentProcessor : public GrFragmentProcessor {
+    public:
+        PremulInputFragmentProcessor() {
+            this->initClassID<PremulInputFragmentProcessor>();
+        }
+
+        const char* name() const override { return "PremultiplyInput"; }
+
+    private:
+        GrGLSLFragmentProcessor* onCreateGLSLInstance() const override {
+            class GLFP : public GrGLSLFragmentProcessor {
+            public:
+                void emitCode(EmitArgs& args) override {
+                    GrGLSLFPFragmentBuilder* fragBuilder = args.fFragBuilder;
+
+                    fragBuilder->codeAppendf("%s = %s;", args.fOutputColor, args.fInputColor);
+                    fragBuilder->codeAppendf("%s.rgb *= %s.a;",
+                                             args.fOutputColor, args.fInputColor);
+                }
+            };
+            return new GLFP;
+        }
+
+        void onGetGLSLProcessorKey(const GrGLSLCaps&, GrProcessorKeyBuilder*) const override {}
+
+        bool onIsEqual(const GrFragmentProcessor&) const override { return true; }
+
+        void onComputeInvariantOutput(GrInvariantOutput* inout) const override {
+            inout->premulFourChannelColor();
+        }
+    };
+    if (!fp) {
+        return nullptr;
+    }
+    sk_sp<GrFragmentProcessor> fpPipeline[] = { sk_make_sp<PremulInputFragmentProcessor>(), fp};
+    return GrFragmentProcessor::RunInSeries(fpPipeline, 2);
+}
+
 sk_sp<GrFragmentProcessor> GrFragmentProcessor::MulOutputByInputUnpremulColor(
     sk_sp<GrFragmentProcessor> fp) {
 
diff --git a/tests/SerializationTest.cpp b/tests/SerializationTest.cpp
index 88f88fc..f806c4a 100644
--- a/tests/SerializationTest.cpp
+++ b/tests/SerializationTest.cpp
@@ -587,13 +587,26 @@
                                                                                ctm);
         sk_sp<SkShader> diffuseShader = SkMakeBitmapShader(diffuse, SkShader::kClamp_TileMode,
                 SkShader::kClamp_TileMode, &matrix, nullptr);
-        sk_sp<SkShader> lightingShader = SkLightingShader::Make(std::move(diffuseShader),
-                                                                std::move(normalSource),
+
+        sk_sp<SkShader> lightingShader = SkLightingShader::Make(diffuseShader,
+                                                                normalSource,
                                                                 fLights);
-
         SkAutoTUnref<SkShader>(TestFlattenableSerialization(lightingShader.get(), true, reporter));
-        // TODO test equality?
 
+        lightingShader = SkLightingShader::Make(std::move(diffuseShader),
+                                                nullptr,
+                                                fLights);
+        SkAutoTUnref<SkShader>(TestFlattenableSerialization(lightingShader.get(), true, reporter));
+
+        lightingShader = SkLightingShader::Make(nullptr,
+                                                std::move(normalSource),
+                                                fLights);
+        SkAutoTUnref<SkShader>(TestFlattenableSerialization(lightingShader.get(), true, reporter));
+
+        lightingShader = SkLightingShader::Make(nullptr,
+                                                nullptr,
+                                                fLights);
+        SkAutoTUnref<SkShader>(TestFlattenableSerialization(lightingShader.get(), true, reporter));
     }
 }