Revert of Revert of Extract most of the mutable state of SkShader into a separate Context object. (https://codereview.chromium.org/249643002/)

Reason for revert:
The Chromium-side change has landed alongside the DEPS roll that includes r14323.

Original issue's description:
> Revert of Extract most of the mutable state of SkShader into a separate Context object. (https://codereview.chromium.org/207683004/)
>
> Reason for revert:
> This is blocking the DEPS roll into Chromium. Failures can be seen here:
>
> http://build.chromium.org/p/tryserver.chromium/builders/android_dbg/builds/174333
>
> Original issue's description:
> > Extract most of the mutable state of SkShader into a separate Context object.
> >
> > SkShader currently stores some state during draw calls via setContext(...).
> > Move that mutable state into a separate SkShader::Context class that is
> > constructed on demand for the duration of the draw.
> >
> > Calls to setContext() are replaced with createContext() which returns a context
> > corresponding to the shader object or NULL if the parameters to createContext
> > are invalid.
> >
> > TEST=out/Debug/dm
> > BUG=skia:1976
> >
> > Committed: http://code.google.com/p/skia/source/detail?r=14216
> >
> > Committed: http://code.google.com/p/skia/source/detail?r=14323
>
> TBR=scroggo@google.com,skyostil@chromium.org,tomhudson@chromium.org,senorblanco@chromium.org,reed@google.com,bungeman@google.com,dominikg@chromium.org
> NOTREECHECKS=true
> NOTRY=true
> BUG=skia:1976
>
> Committed: http://code.google.com/p/skia/source/detail?r=14326

R=scroggo@google.com, skyostil@chromium.org, tomhudson@chromium.org, senorblanco@chromium.org, reed@google.com, bungeman@google.com, dominikg@chromium.org
TBR=bungeman@google.com, dominikg@chromium.org, reed@google.com, scroggo@google.com, senorblanco@chromium.org, skyostil@chromium.org, tomhudson@chromium.org
NOTREECHECKS=true
NOTRY=true
BUG=skia:1976

Author: bsalomon@google.com

Review URL: https://codereview.chromium.org/246403013

git-svn-id: http://skia.googlecode.com/svn/trunk@14328 2bbb7eff-a529-9590-31e7-b0007b416f81
diff --git a/include/core/SkColorShader.h b/include/core/SkColorShader.h
index 975156c..56e5add 100644
--- a/include/core/SkColorShader.h
+++ b/include/core/SkColorShader.h
@@ -30,16 +30,35 @@
     */
     SkColorShader(SkColor c);
 
-    virtual ~SkColorShader();
-
-    virtual uint32_t getFlags() SK_OVERRIDE;
-    virtual uint8_t getSpan16Alpha() const SK_OVERRIDE;
     virtual bool isOpaque() const SK_OVERRIDE;
-    virtual bool setContext(const SkBitmap& device, const SkPaint& paint,
-                            const SkMatrix& matrix) SK_OVERRIDE;
-    virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
-    virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) SK_OVERRIDE;
+
+    virtual SkShader::Context* createContext(const SkBitmap& device,
+                                             const SkPaint& paint,
+                                             const SkMatrix& matrix,
+                                             void* storage) const SK_OVERRIDE;
+
+    virtual size_t contextSize() const SK_OVERRIDE {
+        return sizeof(ColorShaderContext);
+    }
+
+    class ColorShaderContext : public SkShader::Context {
+    public:
+        ColorShaderContext(const SkColorShader& shader, const SkBitmap& device,
+                           const SkPaint& paint, const SkMatrix& matrix);
+
+        virtual uint32_t getFlags() const SK_OVERRIDE;
+        virtual uint8_t getSpan16Alpha() const SK_OVERRIDE;
+        virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
+        virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) SK_OVERRIDE;
+
+    private:
+        SkPMColor   fPMColor;
+        uint32_t    fFlags;
+        uint16_t    fColor16;
+
+        typedef SkShader::Context INHERITED;
+    };
 
     // we return false for this, use asAGradient
     virtual BitmapType asABitmap(SkBitmap* outTexture,
@@ -56,11 +75,7 @@
     virtual void flatten(SkWriteBuffer&) const SK_OVERRIDE;
 
 private:
-
     SkColor     fColor;         // ignored if fInheritColor is true
-    SkPMColor   fPMColor;       // cached after setContext()
-    uint32_t    fFlags;         // cached after setContext()
-    uint16_t    fColor16;       // cached after setContext()
     SkBool8     fInheritColor;
 
     typedef SkShader INHERITED;
diff --git a/include/core/SkComposeShader.h b/include/core/SkComposeShader.h
index b54e5ef..d42da0c 100644
--- a/include/core/SkComposeShader.h
+++ b/include/core/SkComposeShader.h
@@ -34,10 +34,38 @@
     SkComposeShader(SkShader* sA, SkShader* sB, SkXfermode* mode = NULL);
     virtual ~SkComposeShader();
 
-    virtual bool setContext(const SkBitmap&, const SkPaint&,
-                            const SkMatrix&) SK_OVERRIDE;
-    virtual void endContext() SK_OVERRIDE;
-    virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
+    virtual bool validContext(const SkBitmap&, const SkPaint&,
+                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
+                                             const SkMatrix&, void*) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
+    class ComposeShaderContext : public SkShader::Context {
+    public:
+        // When this object gets destroyed, it will call the destructors of contextA and
+        // contextB, but it will NOT free the memory.
+        ComposeShaderContext(const SkComposeShader&, const SkBitmap&,
+                             const SkPaint&, const SkMatrix&,
+                             SkShader::Context* contextA, SkShader::Context* contextB);
+
+        SkShader::Context* getShaderContextA() const { return fShaderContextA; }
+        SkShader::Context* getShaderContextB() const { return fShaderContextB; }
+
+        virtual ~ComposeShaderContext();
+
+        virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
+
+    private:
+        SkShader::Context* fShaderContextA;
+        SkShader::Context* fShaderContextB;
+
+        typedef SkShader::Context INHERITED;
+    };
+
+#ifdef SK_DEBUG
+    SkShader* getShaderA() { return fShaderA; }
+    SkShader* getShaderB() { return fShaderB; }
+#endif
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkComposeShader)
@@ -47,7 +75,6 @@
     virtual void flatten(SkWriteBuffer&) const SK_OVERRIDE;
 
 private:
-
     SkShader*   fShaderA;
     SkShader*   fShaderB;
     SkXfermode* fMode;
diff --git a/include/core/SkEmptyShader.h b/include/core/SkEmptyShader.h
index d2ebb61..7494eff 100644
--- a/include/core/SkEmptyShader.h
+++ b/include/core/SkEmptyShader.h
@@ -15,20 +15,28 @@
 
 /**
  *  \class SkEmptyShader
- *  A Shader that always draws nothing. Its setContext always returns false,
- *  so it never expects that its shadeSpan() methods will get called.
+ *  A Shader that always draws nothing. Its createContext always returns NULL.
  */
 class SK_API SkEmptyShader : public SkShader {
 public:
     SkEmptyShader() {}
 
-    virtual uint32_t getFlags() SK_OVERRIDE;
-    virtual uint8_t getSpan16Alpha() const SK_OVERRIDE;
-    virtual bool setContext(const SkBitmap&, const SkPaint&,
-                            const SkMatrix&) SK_OVERRIDE;
-    virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
-    virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE {
+        // Even though createContext returns NULL we have to return a value of at least
+        // sizeof(SkShader::Context) to satisfy SkSmallAllocator.
+        return sizeof(SkShader::Context);
+    }
+
+    virtual bool validContext(const SkBitmap&, const SkPaint&,
+                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE {
+        return false;
+    }
+
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
+                                             const SkMatrix&, void*) const SK_OVERRIDE {
+        // validContext returns false.
+        return NULL;
+    }
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkEmptyShader)
diff --git a/include/core/SkShader.h b/include/core/SkShader.h
index 6566e69..cc2cc75 100644
--- a/include/core/SkShader.h
+++ b/include/core/SkShader.h
@@ -38,7 +38,7 @@
     virtual ~SkShader();
 
     /**
-     * Returns true if the local matrix is not an identity matrix.
+     *  Returns true if the local matrix is not an identity matrix.
      */
     bool hasLocalMatrix() const { return !fLocalMatrix.isIdentity(); }
 
@@ -96,7 +96,7 @@
         */
         kIntrinsicly16_Flag = 0x04,
 
-        /** set (after setContext) if the spans only vary in X (const in Y).
+        /** set if the spans only vary in X (const in Y).
             e.g. an Nx1 bitmap that is being tiled in Y, or a linear-gradient
             that varies from left-to-right. This flag specifies this for
             shadeSpan().
@@ -112,84 +112,111 @@
     };
 
     /**
-     *  Called sometimes before drawing with this shader. Return the type of
-     *  alpha your shader will return. The default implementation returns 0.
-     *  Your subclass should override if it can (even sometimes) report a
-     *  non-zero value, since that will enable various blitters to perform
-     *  faster.
-     */
-    virtual uint32_t getFlags() { return 0; }
-
-    /**
      *  Returns true if the shader is guaranteed to produce only opaque
      *  colors, subject to the SkPaint using the shader to apply an opaque
      *  alpha value. Subclasses should override this to allow some
-     *  optimizations.  isOpaque() can be called at any time, unlike getFlags,
-     *  which only works properly when the context is set.
+     *  optimizations.
      */
     virtual bool isOpaque() const { return false; }
 
-    /**
-     *  Return the alpha associated with the data returned by shadeSpan16(). If
-     *  kHasSpan16_Flag is not set, this value is meaningless.
-     */
-    virtual uint8_t getSpan16Alpha() const { return fPaintAlpha; }
+    class Context : public ::SkNoncopyable {
+    public:
+        Context(const SkShader& shader, const SkBitmap& device,
+                const SkPaint& paint, const SkMatrix& matrix);
+
+        virtual ~Context();
+
+        /**
+         *  Called sometimes before drawing with this shader. Return the type of
+         *  alpha your shader will return. The default implementation returns 0.
+         *  Your subclass should override if it can (even sometimes) report a
+         *  non-zero value, since that will enable various blitters to perform
+         *  faster.
+         */
+        virtual uint32_t getFlags() const { return 0; }
+
+        /**
+         *  Return the alpha associated with the data returned by shadeSpan16(). If
+         *  kHasSpan16_Flag is not set, this value is meaningless.
+         */
+        virtual uint8_t getSpan16Alpha() const { return fPaintAlpha; }
+
+        /**
+         *  Called for each span of the object being drawn. Your subclass should
+         *  set the appropriate colors (with premultiplied alpha) that correspond
+         *  to the specified device coordinates.
+         */
+        virtual void shadeSpan(int x, int y, SkPMColor[], int count) = 0;
+
+        typedef void (*ShadeProc)(void* ctx, int x, int y, SkPMColor[], int count);
+        virtual ShadeProc asAShadeProc(void** ctx);
+
+        /**
+         *  Called only for 16bit devices when getFlags() returns
+         *  kOpaqueAlphaFlag | kHasSpan16_Flag
+         */
+        virtual void shadeSpan16(int x, int y, uint16_t[], int count);
+
+        /**
+         *  Similar to shadeSpan, but only returns the alpha-channel for a span.
+         *  The default implementation calls shadeSpan() and then extracts the alpha
+         *  values from the returned colors.
+         */
+        virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count);
+
+        /**
+         *  Helper function that returns true if this shader's shadeSpan16() method
+         *  can be called.
+         */
+        bool canCallShadeSpan16() {
+            return SkShader::CanCallShadeSpan16(this->getFlags());
+        }
+
+    protected:
+        // Reference to shader, so we don't have to dupe information.
+        const SkShader& fShader;
+
+        enum MatrixClass {
+            kLinear_MatrixClass,            // no perspective
+            kFixedStepInX_MatrixClass,      // fast perspective, need to call fixedStepInX() each
+                                            // scanline
+            kPerspective_MatrixClass        // slow perspective, need to mappoints each pixel
+        };
+        static MatrixClass ComputeMatrixClass(const SkMatrix&);
+
+        uint8_t             getPaintAlpha() const { return fPaintAlpha; }
+        const SkMatrix&     getTotalInverse() const { return fTotalInverse; }
+        MatrixClass         getInverseClass() const { return (MatrixClass)fTotalInverseClass; }
+
+    private:
+        SkMatrix            fTotalInverse;
+        uint8_t             fPaintAlpha;
+        uint8_t             fTotalInverseClass;
+
+        typedef SkNoncopyable INHERITED;
+    };
 
     /**
-     *  Called once before drawing, with the current paint and device matrix.
-     *  Return true if your shader supports these parameters, or false if not.
-     *  If false is returned, nothing will be drawn. If true is returned, then
-     *  a balancing call to endContext() will be made before the next call to
-     *  setContext.
-     *
-     *  Subclasses should be sure to call their INHERITED::setContext() if they
-     *  override this method.
+     *  Subclasses should be sure to call their INHERITED::validContext() if
+     *  they override this method.
      */
-    virtual bool setContext(const SkBitmap& device, const SkPaint& paint,
-                            const SkMatrix& matrix);
+    virtual bool validContext(const SkBitmap& device, const SkPaint& paint,
+                              const SkMatrix& matrix, SkMatrix* totalInverse = NULL) const;
 
     /**
-     *  Assuming setContext returned true, endContext() will be called when
-     *  the draw using the shader has completed. It is an error for setContext
-     *  to be called twice w/o an intervening call to endContext().
-     *
-     *  Subclasses should be sure to call their INHERITED::endContext() if they
-     *  override this method.
+     *  Create the actual object that does the shading.
+     *  Returns NULL if validContext() returns false.
+     *  Size of storage must be >= contextSize.
      */
-    virtual void endContext();
-
-    SkDEBUGCODE(bool setContextHasBeenCalled() const { return SkToBool(fInSetContext); })
+    virtual Context* createContext(const SkBitmap& device,
+                                   const SkPaint& paint,
+                                   const SkMatrix& matrix,
+                                   void* storage) const = 0;
 
     /**
-     *  Called for each span of the object being drawn. Your subclass should
-     *  set the appropriate colors (with premultiplied alpha) that correspond
-     *  to the specified device coordinates.
+     *  Return the size of a Context returned by createContext.
      */
-    virtual void shadeSpan(int x, int y, SkPMColor[], int count) = 0;
-
-    typedef void (*ShadeProc)(void* ctx, int x, int y, SkPMColor[], int count);
-    virtual ShadeProc asAShadeProc(void** ctx);
-
-    /**
-     *  Called only for 16bit devices when getFlags() returns
-     *  kOpaqueAlphaFlag | kHasSpan16_Flag
-     */
-    virtual void shadeSpan16(int x, int y, uint16_t[], int count);
-
-    /**
-     *  Similar to shadeSpan, but only returns the alpha-channel for a span.
-     *  The default implementation calls shadeSpan() and then extracts the alpha
-     *  values from the returned colors.
-     */
-    virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count);
-
-    /**
-     *  Helper function that returns true if this shader's shadeSpan16() method
-     *  can be called.
-     */
-    bool canCallShadeSpan16() {
-        return SkShader::CanCallShadeSpan16(this->getFlags());
-    }
+    virtual size_t contextSize() const = 0;
 
     /**
      *  Helper to check the flags to know if it is legal to call shadeSpan16()
@@ -322,7 +349,7 @@
      *  The incoming color to the effect has r=g=b=a all extracted from the SkPaint's alpha.
      *  The output color should be the computed SkShader premul color modulated by the incoming
      *  color. The GrContext may be used by the effect to create textures. The GPU device does not
-     *  call setContext. Instead we pass the SkPaint here in case the shader needs paint info.
+     *  call createContext. Instead we pass the SkPaint here in case the shader needs paint info.
      */
     virtual GrEffectRef* asNewEffect(GrContext* context, const SkPaint& paint) const;
 
@@ -362,26 +389,14 @@
     SK_DEFINE_FLATTENABLE_TYPE(SkShader)
 
 protected:
-    enum MatrixClass {
-        kLinear_MatrixClass,            // no perspective
-        kFixedStepInX_MatrixClass,      // fast perspective, need to call fixedStepInX() each scanline
-        kPerspective_MatrixClass        // slow perspective, need to mappoints each pixel
-    };
-    static MatrixClass ComputeMatrixClass(const SkMatrix&);
-
-    // These can be called by your subclass after setContext() has been called
-    uint8_t             getPaintAlpha() const { return fPaintAlpha; }
-    const SkMatrix&     getTotalInverse() const { return fTotalInverse; }
-    MatrixClass         getInverseClass() const { return (MatrixClass)fTotalInverseClass; }
 
     SkShader(SkReadBuffer& );
     virtual void flatten(SkWriteBuffer&) const SK_OVERRIDE;
+
 private:
     SkMatrix            fLocalMatrix;
-    SkMatrix            fTotalInverse;
-    uint8_t             fPaintAlpha;
-    uint8_t             fTotalInverseClass;
-    SkDEBUGCODE(SkBool8 fInSetContext;)
+
+    bool computeTotalInverse(const SkMatrix& matrix, SkMatrix* totalInverse) const;
 
     typedef SkFlattenable INHERITED;
 };
diff --git a/include/effects/SkPerlinNoiseShader.h b/include/effects/SkPerlinNoiseShader.h
index dfd5a8c..5b27029 100644
--- a/include/effects/SkPerlinNoiseShader.h
+++ b/include/effects/SkPerlinNoiseShader.h
@@ -72,10 +72,32 @@
     }
 
 
-    virtual bool setContext(const SkBitmap& device, const SkPaint& paint,
-                            const SkMatrix& matrix);
-    virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
+    virtual SkShader::Context* createContext(
+        const SkBitmap& device, const SkPaint& paint,
+        const SkMatrix& matrix, void* storage) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
+    class PerlinNoiseShaderContext : public SkShader::Context {
+    public:
+        PerlinNoiseShaderContext(const SkPerlinNoiseShader& shader, const SkBitmap& device,
+                                 const SkPaint& paint, const SkMatrix& matrix);
+        virtual ~PerlinNoiseShaderContext() {}
+
+        virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
+
+    private:
+        SkPMColor shade(const SkPoint& point, StitchData& stitchData) const;
+        SkScalar calculateTurbulenceValueForPoint(
+            int channel, const PaintingData& paintingData,
+            StitchData& stitchData, const SkPoint& point) const;
+        SkScalar noise2D(int channel, const PaintingData& paintingData,
+                         const StitchData& stitchData, const SkPoint& noiseVector) const;
+
+        SkMatrix fMatrix;
+
+        typedef SkShader::Context INHERITED;
+    };
 
     virtual GrEffectRef* asNewEffect(GrContext* context, const SkPaint&) const SK_OVERRIDE;
 
@@ -92,14 +114,6 @@
                         const SkISize* tileSize);
     virtual ~SkPerlinNoiseShader();
 
-    SkScalar noise2D(int channel, const PaintingData& paintingData,
-                     const StitchData& stitchData, const SkPoint& noiseVector) const;
-
-    SkScalar calculateTurbulenceValueForPoint(int channel, const PaintingData& paintingData,
-                                              StitchData& stitchData, const SkPoint& point) const;
-
-    SkPMColor shade(const SkPoint& point, StitchData& stitchData) const;
-
     // TODO (scroggo): Once all SkShaders are created from a factory, and we have removed the
     // constructor that creates SkPerlinNoiseShader from an SkReadBuffer, several fields can
     // be made constant.
@@ -110,8 +124,6 @@
     /*const*/ SkScalar                  fSeed;
     /*const*/ SkISize                   fTileSize;
     /*const*/ bool                      fStitchTiles;
-    // TODO (scroggo): Once setContext creates a new object, place this on that object.
-    SkMatrix fMatrix;
 
     PaintingData* fPaintingData;
 
diff --git a/include/effects/SkTransparentShader.h b/include/effects/SkTransparentShader.h
index 7428d44..790e5ae 100644
--- a/include/effects/SkTransparentShader.h
+++ b/include/effects/SkTransparentShader.h
@@ -14,21 +14,31 @@
 public:
     SkTransparentShader() {}
 
-    virtual uint32_t getFlags() SK_OVERRIDE;
-    virtual bool    setContext(const SkBitmap& device,
-                               const SkPaint& paint,
-                               const SkMatrix& matrix) SK_OVERRIDE;
-    virtual void    shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
-    virtual void    shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap& device, const SkPaint& paint,
+                                             const SkMatrix& matrix, void* storage) const
+            SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
+    class TransparentShaderContext : public SkShader::Context {
+    public:
+        TransparentShaderContext(const SkTransparentShader& shader, const SkBitmap& device,
+                                 const SkPaint& paint, const SkMatrix& matrix);
+        virtual ~TransparentShaderContext();
+
+        virtual uint32_t getFlags() const SK_OVERRIDE;
+        virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
+
+    private:
+        const SkBitmap* fDevice;
+
+        typedef SkShader::Context INHERITED;
+    };
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkTransparentShader)
 
 private:
-    // these are a cache from the call to setContext()
-    const SkBitmap* fDevice;
-    uint8_t         fAlpha;
-
     SkTransparentShader(SkReadBuffer& buffer) : INHERITED(buffer) {}
 
     typedef SkShader INHERITED;
diff --git a/src/core/SkBitmapProcShader.cpp b/src/core/SkBitmapProcShader.cpp
index a397b78..5f5eb18 100644
--- a/src/core/SkBitmapProcShader.cpp
+++ b/src/core/SkBitmapProcShader.cpp
@@ -34,18 +34,16 @@
 SkBitmapProcShader::SkBitmapProcShader(const SkBitmap& src,
                                        TileMode tmx, TileMode tmy) {
     fRawBitmap = src;
-    fState.fTileModeX = (uint8_t)tmx;
-    fState.fTileModeY = (uint8_t)tmy;
-    fFlags = 0; // computed in setContext
+    fTileModeX = (uint8_t)tmx;
+    fTileModeY = (uint8_t)tmy;
 }
 
 SkBitmapProcShader::SkBitmapProcShader(SkReadBuffer& buffer)
         : INHERITED(buffer) {
     buffer.readBitmap(&fRawBitmap);
     fRawBitmap.setImmutable();
-    fState.fTileModeX = buffer.readUInt();
-    fState.fTileModeY = buffer.readUInt();
-    fFlags = 0; // computed in setContext
+    fTileModeX = buffer.readUInt();
+    fTileModeY = buffer.readUInt();
 }
 
 SkShader::BitmapType SkBitmapProcShader::asABitmap(SkBitmap* texture,
@@ -58,8 +56,8 @@
         texM->reset();
     }
     if (xy) {
-        xy[0] = (TileMode)fState.fTileModeX;
-        xy[1] = (TileMode)fState.fTileModeY;
+        xy[0] = (TileMode)fTileModeX;
+        xy[1] = (TileMode)fTileModeY;
     }
     return kDefault_BitmapType;
 }
@@ -68,8 +66,8 @@
     this->INHERITED::flatten(buffer);
 
     buffer.writeBitmap(fRawBitmap);
-    buffer.writeUInt(fState.fTileModeX);
-    buffer.writeUInt(fState.fTileModeY);
+    buffer.writeUInt(fTileModeX);
+    buffer.writeUInt(fTileModeY);
 }
 
 static bool only_scale_and_translate(const SkMatrix& matrix) {
@@ -98,25 +96,67 @@
     return true;
 }
 
-bool SkBitmapProcShader::setContext(const SkBitmap& device,
-                                    const SkPaint& paint,
-                                    const SkMatrix& matrix) {
+bool SkBitmapProcShader::validInternal(const SkBitmap& device,
+                                       const SkPaint& paint,
+                                       const SkMatrix& matrix,
+                                       SkMatrix* totalInverse,
+                                       SkBitmapProcState* state) const {
     if (!fRawBitmap.getTexture() && !valid_for_drawing(fRawBitmap)) {
         return false;
     }
 
-    // do this first, so we have a correct inverse matrix
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
+    // Make sure we can use totalInverse as a cache.
+    SkMatrix totalInverseLocal;
+    if (NULL == totalInverse) {
+        totalInverse = &totalInverseLocal;
+    }
+
+    // Do this first, so we know the matrix can be inverted.
+    if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
         return false;
     }
 
-    fState.fOrigBitmap = fRawBitmap;
-    if (!fState.chooseProcs(this->getTotalInverse(), paint)) {
-        this->INHERITED::endContext();
-        return false;
+    SkASSERT(state);
+    state->fTileModeX = fTileModeX;
+    state->fTileModeY = fTileModeY;
+    state->fOrigBitmap = fRawBitmap;
+    return state->chooseProcs(*totalInverse, paint);
+}
+
+bool SkBitmapProcShader::validContext(const SkBitmap& device,
+                                      const SkPaint& paint,
+                                      const SkMatrix& matrix,
+                                      SkMatrix* totalInverse) const {
+    SkBitmapProcState state;
+    return this->validInternal(device, paint, matrix, totalInverse, &state);
+}
+
+SkShader::Context* SkBitmapProcShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                     const SkMatrix& matrix, void* storage) const {
+    void* stateStorage = (char*)storage + sizeof(BitmapProcShaderContext);
+    SkBitmapProcState* state = SkNEW_PLACEMENT(stateStorage, SkBitmapProcState);
+    if (!this->validInternal(device, paint, matrix, NULL, state)) {
+        state->~SkBitmapProcState();
+        return NULL;
     }
 
-    const SkBitmap& bitmap = *fState.fBitmap;
+    return SkNEW_PLACEMENT_ARGS(storage, BitmapProcShaderContext,
+                                (*this, device, paint, matrix, state));
+}
+
+size_t SkBitmapProcShader::contextSize() const {
+    // The SkBitmapProcState is stored outside of the context object, with the context holding
+    // a pointer to it.
+    return sizeof(BitmapProcShaderContext) + sizeof(SkBitmapProcState);
+}
+
+SkBitmapProcShader::BitmapProcShaderContext::BitmapProcShaderContext(
+        const SkBitmapProcShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix, SkBitmapProcState* state)
+    : INHERITED(shader, device, paint, matrix)
+    , fState(state)
+{
+    const SkBitmap& bitmap = *fState->fBitmap;
     bool bitmapIsOpaque = bitmap.isOpaque();
 
     // update fFlags
@@ -157,12 +197,12 @@
     }
 
     fFlags = flags;
-    return true;
 }
 
-void SkBitmapProcShader::endContext() {
-    fState.endContext();
-    this->INHERITED::endContext();
+SkBitmapProcShader::BitmapProcShaderContext::~BitmapProcShaderContext() {
+    // The bitmap proc state has been created outside of the context on memory that will be freed
+    // elsewhere. Only call the destructor but leave the freeing of the memory to the caller.
+    fState->~SkBitmapProcState();
 }
 
 #define BUF_MAX     128
@@ -176,8 +216,9 @@
     #define TEST_BUFFER_EXTRA   0
 #endif
 
-void SkBitmapProcShader::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
-    const SkBitmapProcState& state = fState;
+void SkBitmapProcShader::BitmapProcShaderContext::shadeSpan(int x, int y, SkPMColor dstC[],
+                                                            int count) {
+    const SkBitmapProcState& state = *fState;
     if (state.getShaderProc32()) {
         state.getShaderProc32()(state, x, y, dstC, count);
         return;
@@ -186,7 +227,7 @@
     uint32_t buffer[BUF_MAX + TEST_BUFFER_EXTRA];
     SkBitmapProcState::MatrixProc   mproc = state.getMatrixProc();
     SkBitmapProcState::SampleProc32 sproc = state.getSampleProc32();
-    int max = fState.maxCountForBufferSize(sizeof(buffer[0]) * BUF_MAX);
+    int max = state.maxCountForBufferSize(sizeof(buffer[0]) * BUF_MAX);
 
     SkASSERT(state.fBitmap->getPixels());
     SkASSERT(state.fBitmap->pixelRef() == NULL ||
@@ -220,16 +261,17 @@
     }
 }
 
-SkShader::ShadeProc SkBitmapProcShader::asAShadeProc(void** ctx) {
-    if (fState.getShaderProc32()) {
-        *ctx = &fState;
-        return (ShadeProc)fState.getShaderProc32();
+SkShader::Context::ShadeProc SkBitmapProcShader::BitmapProcShaderContext::asAShadeProc(void** ctx) {
+    if (fState->getShaderProc32()) {
+        *ctx = fState;
+        return (ShadeProc)fState->getShaderProc32();
     }
     return NULL;
 }
 
-void SkBitmapProcShader::shadeSpan16(int x, int y, uint16_t dstC[], int count) {
-    const SkBitmapProcState& state = fState;
+void SkBitmapProcShader::BitmapProcShaderContext::shadeSpan16(int x, int y, uint16_t dstC[],
+                                                              int count) {
+    const SkBitmapProcState& state = *fState;
     if (state.getShaderProc16()) {
         state.getShaderProc16()(state, x, y, dstC, count);
         return;
@@ -238,7 +280,7 @@
     uint32_t buffer[BUF_MAX];
     SkBitmapProcState::MatrixProc   mproc = state.getMatrixProc();
     SkBitmapProcState::SampleProc16 sproc = state.getSampleProc16();
-    int max = fState.maxCountForBufferSize(sizeof(buffer));
+    int max = state.maxCountForBufferSize(sizeof(buffer));
 
     SkASSERT(state.fBitmap->getPixels());
     SkASSERT(state.fBitmap->pixelRef() == NULL ||
@@ -342,8 +384,8 @@
     str->append("BitmapShader: (");
 
     str->appendf("(%s, %s)",
-                 gTileModeName[fState.fTileModeX],
-                 gTileModeName[fState.fTileModeY]);
+                 gTileModeName[fTileModeX],
+                 gTileModeName[fTileModeY]);
 
     str->append(" ");
     fRawBitmap.toString(str);
@@ -384,8 +426,8 @@
     matrix.preConcat(lmInverse);
 
     SkShader::TileMode tm[] = {
-        (TileMode)fState.fTileModeX,
-        (TileMode)fState.fTileModeY,
+        (TileMode)fTileModeX,
+        (TileMode)fTileModeY,
     };
 
     // Must set wrap and filter on the sampler before requesting a texture. In two places below
diff --git a/src/core/SkBitmapProcShader.h b/src/core/SkBitmapProcShader.h
index 8e225a5..e0c78b8 100644
--- a/src/core/SkBitmapProcShader.h
+++ b/src/core/SkBitmapProcShader.h
@@ -20,14 +20,16 @@
 
     // overrides from SkShader
     virtual bool isOpaque() const SK_OVERRIDE;
-    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
-    virtual void endContext() SK_OVERRIDE;
-    virtual uint32_t getFlags() SK_OVERRIDE { return fFlags; }
-    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-    virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
     virtual BitmapType asABitmap(SkBitmap*, SkMatrix*, TileMode*) const SK_OVERRIDE;
 
+    virtual bool validContext(const SkBitmap& device,
+                              const SkPaint& paint,
+                              const SkMatrix& matrix,
+                              SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
+                                             const SkMatrix&, void* storage) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
     static bool CanDo(const SkBitmap&, TileMode tx, TileMode ty);
 
     SK_TO_STRING_OVERRIDE()
@@ -37,22 +39,54 @@
     GrEffectRef* asNewEffect(GrContext*, const SkPaint&) const SK_OVERRIDE;
 #endif
 
+    class BitmapProcShaderContext : public SkShader::Context {
+    public:
+        // The context takes ownership of the state. It will call its destructor
+        // but will NOT free the memory.
+        BitmapProcShaderContext(const SkBitmapProcShader& shader,
+                                const SkBitmap& device,
+                                const SkPaint& paint,
+                                const SkMatrix& matrix,
+                                SkBitmapProcState* state);
+        virtual ~BitmapProcShaderContext();
+
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+        virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+
+        virtual uint32_t getFlags() const SK_OVERRIDE { return fFlags; }
+
+    private:
+        SkBitmapProcState*  fState;
+        uint32_t            fFlags;
+
+        typedef SkShader::Context INHERITED;
+    };
+
 protected:
     SkBitmapProcShader(SkReadBuffer& );
     virtual void flatten(SkWriteBuffer&) const SK_OVERRIDE;
 
-    SkBitmap          fRawBitmap;   // experimental for RLE encoding
-    SkBitmapProcState fState;
-    uint32_t          fFlags;
+    SkBitmap    fRawBitmap;   // experimental for RLE encoding
+    uint8_t     fTileModeX, fTileModeY;
 
 private:
+    bool validInternal(const SkBitmap& device, const SkPaint& paint,
+                       const SkMatrix& matrix, SkMatrix* totalInverse,
+                       SkBitmapProcState* state) const;
+
     typedef SkShader INHERITED;
 };
 
-// Commonly used allocator. It currently is only used to allocate up to 2 objects. The total
-// bytes requested is calculated using one of our large shaders plus the size of an Sk3DBlitter
-// in SkDraw.cpp
-typedef SkSmallAllocator<2, sizeof(SkBitmapProcShader) + sizeof(void*) * 2> SkTBlitterAllocator;
+// Commonly used allocator. It is currently only used to allocate up to 3 objects. The total
+// number of bytes requested is calculated from the size of one of our large shaders, the size
+// of its context (and bitmap proc state), plus the size of an Sk3DBlitter in SkBlitter.cpp.
+// Note that some contexts may contain other contexts (e.g. for compose shaders), but we've not
+// yet found a situation where the size below isn't big enough.
+typedef SkSmallAllocator<3, sizeof(SkBitmapProcShader) +
+                            sizeof(SkBitmapProcShader::BitmapProcShaderContext) +
+                            sizeof(SkBitmapProcState) +
+                            sizeof(void*) * 2> SkTBlitterAllocator;
 
 // If alloc is non-NULL, it will be used to allocate the returned SkShader, and MUST outlive
 // the SkShader.
diff --git a/src/core/SkBitmapProcState.cpp b/src/core/SkBitmapProcState.cpp
index be87d83..eecfbbc 100644
--- a/src/core/SkBitmapProcState.cpp
+++ b/src/core/SkBitmapProcState.cpp
@@ -360,17 +360,6 @@
     return true;
 }
 
-void SkBitmapProcState::endContext() {
-    SkDELETE(fBitmapFilter);
-    fBitmapFilter = NULL;
-    fScaledBitmap.reset();
-
-    if (fScaledCacheID) {
-        SkScaledImageCache::Unlock(fScaledCacheID);
-        fScaledCacheID = NULL;
-    }
-}
-
 SkBitmapProcState::~SkBitmapProcState() {
     if (fScaledCacheID) {
         SkScaledImageCache::Unlock(fScaledCacheID);
@@ -399,6 +388,7 @@
     }
     // The above logic should have always assigned fBitmap, but in case it
     // didn't, we check for that now...
+    // TODO(dominikg): Ask humper@ if we can just use an SkASSERT(fBitmap)?
     if (NULL == fBitmap) {
         return false;
     }
@@ -487,6 +477,7 @@
     // shader will perform.
 
     fMatrixProc = this->chooseMatrixProc(trivialMatrix);
+    // TODO(dominikg): SkASSERT(fMatrixProc) instead? chooseMatrixProc never returns NULL.
     if (NULL == fMatrixProc) {
         return false;
     }
@@ -528,6 +519,7 @@
                 fPaintPMColor = SkPreMultiplyColor(paint.getColor());
                 break;
             default:
+                // TODO(dominikg): Should we ever get here? SkASSERT(false) instead?
                 return false;
         }
 
diff --git a/src/core/SkBitmapProcState.h b/src/core/SkBitmapProcState.h
index d5a354e..663bcb8 100644
--- a/src/core/SkBitmapProcState.h
+++ b/src/core/SkBitmapProcState.h
@@ -89,12 +89,6 @@
     uint8_t             fTileModeY;         // CONSTRUCTOR
     uint8_t             fFilterLevel;       // chooseProcs
 
-    /** The shader will let us know when we can release some of our resources
-      * like scaled bitmaps.
-      */
-
-    void endContext();
-
     /** Platforms implement this, and can optionally overwrite only the
         following fields:
 
diff --git a/src/core/SkBlitter.cpp b/src/core/SkBlitter.cpp
index 52a05ed..7243f52 100644
--- a/src/core/SkBlitter.cpp
+++ b/src/core/SkBlitter.cpp
@@ -26,6 +26,15 @@
 
 bool SkBlitter::isNullBlitter() const { return false; }
 
+bool SkBlitter::resetShaderContext(const SkBitmap& device, const SkPaint& paint,
+                                   const SkMatrix& matrix) {
+    return true;
+}
+
+SkShader::Context* SkBlitter::getShaderContext() const {
+    return NULL;
+}
+
 const SkBitmap* SkBlitter::justAnOpaqueColor(uint32_t* value) {
     return NULL;
 }
@@ -568,102 +577,149 @@
 public:
     Sk3DShader(SkShader* proxy) : fProxy(proxy) {
         SkSafeRef(proxy);
-        fMask = NULL;
     }
 
     virtual ~Sk3DShader() {
         SkSafeUnref(fProxy);
     }
 
-    void setMask(const SkMask* mask) { fMask = mask; }
+    virtual size_t contextSize() const SK_OVERRIDE {
+        size_t size = sizeof(Sk3DShaderContext);
+        if (fProxy) {
+            size += fProxy->contextSize();
+        }
+        return size;
+    }
 
-    virtual bool setContext(const SkBitmap& device, const SkPaint& paint,
-                            const SkMatrix& matrix) SK_OVERRIDE {
-        if (!this->INHERITED::setContext(device, paint, matrix)) {
+    virtual bool validContext(const SkBitmap& device, const SkPaint& paint,
+                              const SkMatrix& matrix, SkMatrix* totalInverse = NULL) const
+            SK_OVERRIDE
+    {
+        if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
             return false;
         }
         if (fProxy) {
-            if (!fProxy->setContext(device, paint, matrix)) {
-                // must keep our set/end context calls balanced
-                this->INHERITED::endContext();
-                return false;
-            }
-        } else {
-            fPMColor = SkPreMultiplyColor(paint.getColor());
+            return fProxy->validContext(device, paint, matrix);
         }
         return true;
     }
 
-    virtual void endContext() SK_OVERRIDE {
-        if (fProxy) {
-            fProxy->endContext();
+    virtual SkShader::Context* createContext(const SkBitmap& device,
+                                             const SkPaint& paint,
+                                             const SkMatrix& matrix,
+                                             void* storage) const SK_OVERRIDE
+    {
+        if (!this->validContext(device, paint, matrix)) {
+            return NULL;
         }
-        this->INHERITED::endContext();
+
+        SkShader::Context* proxyContext;
+        if (fProxy) {
+            char* proxyContextStorage = (char*) storage + sizeof(Sk3DShaderContext);
+            proxyContext = fProxy->createContext(device, paint, matrix, proxyContextStorage);
+            SkASSERT(proxyContext);
+        } else {
+            proxyContext = NULL;
+        }
+        return SkNEW_PLACEMENT_ARGS(storage, Sk3DShaderContext, (*this, device, paint, matrix,
+                                                                 proxyContext));
     }
 
-    virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE {
-        if (fProxy) {
-            fProxy->shadeSpan(x, y, span, count);
-        }
-
-        if (fMask == NULL) {
-            if (fProxy == NULL) {
-                sk_memset32(span, fPMColor, count);
+    class Sk3DShaderContext : public SkShader::Context {
+    public:
+        // Calls proxyContext's destructor but will NOT free its memory.
+        Sk3DShaderContext(const Sk3DShader& shader, const SkBitmap& device, const SkPaint& paint,
+                          const SkMatrix& matrix, SkShader::Context* proxyContext)
+            : INHERITED(shader, device, paint, matrix)
+            , fMask(NULL)
+            , fProxyContext(proxyContext)
+        {
+            if (!fProxyContext) {
+                fPMColor = SkPreMultiplyColor(paint.getColor());
             }
-            return;
         }
 
-        SkASSERT(fMask->fBounds.contains(x, y));
-        SkASSERT(fMask->fBounds.contains(x + count - 1, y));
+        virtual ~Sk3DShaderContext() {
+            if (fProxyContext) {
+                fProxyContext->~Context();
+            }
+        }
 
-        size_t          size = fMask->computeImageSize();
-        const uint8_t*  alpha = fMask->getAddr8(x, y);
-        const uint8_t*  mulp = alpha + size;
-        const uint8_t*  addp = mulp + size;
+        void setMask(const SkMask* mask) { fMask = mask; }
 
-        if (fProxy) {
-            for (int i = 0; i < count; i++) {
-                if (alpha[i]) {
-                    SkPMColor c = span[i];
-                    if (c) {
-                        unsigned a = SkGetPackedA32(c);
-                        unsigned r = SkGetPackedR32(c);
-                        unsigned g = SkGetPackedG32(c);
-                        unsigned b = SkGetPackedB32(c);
+        virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE {
+            if (fProxyContext) {
+                fProxyContext->shadeSpan(x, y, span, count);
+            }
 
+            if (fMask == NULL) {
+                if (fProxyContext == NULL) {
+                    sk_memset32(span, fPMColor, count);
+                }
+                return;
+            }
+
+            SkASSERT(fMask->fBounds.contains(x, y));
+            SkASSERT(fMask->fBounds.contains(x + count - 1, y));
+
+            size_t          size = fMask->computeImageSize();
+            const uint8_t*  alpha = fMask->getAddr8(x, y);
+            const uint8_t*  mulp = alpha + size;
+            const uint8_t*  addp = mulp + size;
+
+            if (fProxyContext) {
+                for (int i = 0; i < count; i++) {
+                    if (alpha[i]) {
+                        SkPMColor c = span[i];
+                        if (c) {
+                            unsigned a = SkGetPackedA32(c);
+                            unsigned r = SkGetPackedR32(c);
+                            unsigned g = SkGetPackedG32(c);
+                            unsigned b = SkGetPackedB32(c);
+
+                            unsigned mul = SkAlpha255To256(mulp[i]);
+                            unsigned add = addp[i];
+
+                            r = SkFastMin32(SkAlphaMul(r, mul) + add, a);
+                            g = SkFastMin32(SkAlphaMul(g, mul) + add, a);
+                            b = SkFastMin32(SkAlphaMul(b, mul) + add, a);
+
+                            span[i] = SkPackARGB32(a, r, g, b);
+                        }
+                    } else {
+                        span[i] = 0;
+                    }
+                }
+            } else {    // color
+                unsigned a = SkGetPackedA32(fPMColor);
+                unsigned r = SkGetPackedR32(fPMColor);
+                unsigned g = SkGetPackedG32(fPMColor);
+                unsigned b = SkGetPackedB32(fPMColor);
+                for (int i = 0; i < count; i++) {
+                    if (alpha[i]) {
                         unsigned mul = SkAlpha255To256(mulp[i]);
                         unsigned add = addp[i];
 
-                        r = SkFastMin32(SkAlphaMul(r, mul) + add, a);
-                        g = SkFastMin32(SkAlphaMul(g, mul) + add, a);
-                        b = SkFastMin32(SkAlphaMul(b, mul) + add, a);
-
-                        span[i] = SkPackARGB32(a, r, g, b);
+                        span[i] = SkPackARGB32( a,
+                                        SkFastMin32(SkAlphaMul(r, mul) + add, a),
+                                        SkFastMin32(SkAlphaMul(g, mul) + add, a),
+                                        SkFastMin32(SkAlphaMul(b, mul) + add, a));
+                    } else {
+                        span[i] = 0;
                     }
-                } else {
-                    span[i] = 0;
-                }
-            }
-        } else {    // color
-            unsigned a = SkGetPackedA32(fPMColor);
-            unsigned r = SkGetPackedR32(fPMColor);
-            unsigned g = SkGetPackedG32(fPMColor);
-            unsigned b = SkGetPackedB32(fPMColor);
-            for (int i = 0; i < count; i++) {
-                if (alpha[i]) {
-                    unsigned mul = SkAlpha255To256(mulp[i]);
-                    unsigned add = addp[i];
-
-                    span[i] = SkPackARGB32( a,
-                                    SkFastMin32(SkAlphaMul(r, mul) + add, a),
-                                    SkFastMin32(SkAlphaMul(g, mul) + add, a),
-                                    SkFastMin32(SkAlphaMul(b, mul) + add, a));
-                } else {
-                    span[i] = 0;
                 }
             }
         }
-    }
+
+    private:
+        // Unowned.
+        const SkMask*       fMask;
+        // Memory is unowned, but we need to call the destructor.
+        SkShader::Context*  fProxyContext;
+        SkPMColor           fPMColor;
+
+        typedef SkShader::Context INHERITED;
+    };
 
 #ifndef SK_IGNORE_TO_STRING
     virtual void toString(SkString* str) const SK_OVERRIDE {
@@ -685,29 +741,30 @@
 protected:
     Sk3DShader(SkReadBuffer& buffer) : INHERITED(buffer) {
         fProxy = buffer.readShader();
-        fPMColor = buffer.readColor();
-        fMask = NULL;
+        // Leaving this here until we bump the picture version, though this
+        // shader should never be recorded.
+        buffer.readColor();
     }
 
     virtual void flatten(SkWriteBuffer& buffer) const SK_OVERRIDE {
         this->INHERITED::flatten(buffer);
         buffer.writeFlattenable(fProxy);
-        buffer.writeColor(fPMColor);
+        // Leaving this here until we bump the picture version, though this
+        // shader should never be recorded.
+        buffer.writeColor(SkColor());
     }
 
 private:
     SkShader*       fProxy;
-    SkPMColor       fPMColor;
-    const SkMask*   fMask;
 
     typedef SkShader INHERITED;
 };
 
 class Sk3DBlitter : public SkBlitter {
 public:
-    Sk3DBlitter(SkBlitter* proxy, Sk3DShader* shader)
+    Sk3DBlitter(SkBlitter* proxy, Sk3DShader::Sk3DShaderContext* shaderContext)
         : fProxy(proxy)
-        , f3DShader(SkRef(shader))
+        , f3DShaderContext(shaderContext)
     {}
 
     virtual void blitH(int x, int y, int width) {
@@ -729,22 +786,22 @@
 
     virtual void blitMask(const SkMask& mask, const SkIRect& clip) {
         if (mask.fFormat == SkMask::k3D_Format) {
-            f3DShader->setMask(&mask);
+            f3DShaderContext->setMask(&mask);
 
             ((SkMask*)&mask)->fFormat = SkMask::kA8_Format;
             fProxy->blitMask(mask, clip);
             ((SkMask*)&mask)->fFormat = SkMask::k3D_Format;
 
-            f3DShader->setMask(NULL);
+            f3DShaderContext->setMask(NULL);
         } else {
             fProxy->blitMask(mask, clip);
         }
     }
 
 private:
-    // fProxy is unowned. It will be deleted by SkSmallAllocator.
-    SkBlitter*               fProxy;
-    SkAutoTUnref<Sk3DShader> f3DShader;
+    // Both pointers are unowned. They will be deleted by SkSmallAllocator.
+    SkBlitter*                     fProxy;
+    Sk3DShader::Sk3DShaderContext* f3DShaderContext;
 };
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -754,8 +811,7 @@
 static bool just_solid_color(const SkPaint& paint) {
     if (paint.getAlpha() == 0xFF && paint.getColorFilter() == NULL) {
         SkShader* shader = paint.getShader();
-        if (NULL == shader ||
-            (shader->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
+        if (NULL == shader) {
             return true;
         }
     }
@@ -893,16 +949,22 @@
     }
 
     /*
-     *  We need to have balanced calls to the shader:
-     *      setContext
-     *      endContext
-     *  We make the first call here, in case it fails we can abort the draw.
-     *  The endContext() call is made by the blitter (assuming setContext did
-     *  not fail) in its destructor.
+     *  We create an SkShader::Context object and store it on the blitter.
      */
-    if (shader && !shader->setContext(device, *paint, matrix)) {
-        blitter = allocator->createT<SkNullBlitter>();
-        return blitter;
+    SkShader::Context* shaderContext;
+    if (shader) {
+        // Try to create the ShaderContext
+        void* storage = allocator->reserveT<SkShader::Context>(shader->contextSize());
+        shaderContext = shader->createContext(device, *paint, matrix, storage);
+        if (!shaderContext) {
+            allocator->freeLast();
+            blitter = allocator->createT<SkNullBlitter>();
+            return blitter;
+        }
+        SkASSERT(shaderContext);
+        SkASSERT((void*) shaderContext == storage);
+    } else {
+        shaderContext = NULL;
     }
 
 
@@ -913,19 +975,20 @@
                 SkASSERT(NULL == paint->getXfermode());
                 blitter = allocator->createT<SkA8_Coverage_Blitter>(device, *paint);
             } else if (shader) {
-                blitter = allocator->createT<SkA8_Shader_Blitter>(device, *paint);
+                blitter = allocator->createT<SkA8_Shader_Blitter>(device, *paint, shaderContext);
             } else {
                 blitter = allocator->createT<SkA8_Blitter>(device, *paint);
             }
             break;
 
         case kRGB_565_SkColorType:
-            blitter = SkBlitter_ChooseD565(device, *paint, allocator);
+            blitter = SkBlitter_ChooseD565(device, *paint, shaderContext, allocator);
             break;
 
         case kN32_SkColorType:
             if (shader) {
-                blitter = allocator->createT<SkARGB32_Shader_Blitter>(device, *paint);
+                blitter = allocator->createT<SkARGB32_Shader_Blitter>(
+                        device, *paint, shaderContext);
             } else if (paint->getColor() == SK_ColorBLACK) {
                 blitter = allocator->createT<SkARGB32_Black_Blitter>(device, *paint);
             } else if (paint->getAlpha() == 0xFF) {
@@ -944,7 +1007,9 @@
     if (shader3D) {
         SkBlitter* innerBlitter = blitter;
         // innerBlitter was allocated by allocator, which will delete it.
-        blitter = allocator->createT<Sk3DBlitter>(innerBlitter, shader3D);
+        // We know shaderContext is of type Sk3DShaderContext because it belongs to shader3D.
+        blitter = allocator->createT<Sk3DBlitter>(innerBlitter,
+                static_cast<Sk3DShader::Sk3DShaderContext*>(shaderContext));
     }
     return blitter;
 }
@@ -956,18 +1021,36 @@
 
 ///////////////////////////////////////////////////////////////////////////////
 
-SkShaderBlitter::SkShaderBlitter(const SkBitmap& device, const SkPaint& paint)
-        : INHERITED(device) {
-    fShader = paint.getShader();
+SkShaderBlitter::SkShaderBlitter(const SkBitmap& device, const SkPaint& paint,
+                                 SkShader::Context* shaderContext)
+        : INHERITED(device)
+        , fShader(paint.getShader())
+        , fShaderContext(shaderContext) {
     SkASSERT(fShader);
-    SkASSERT(fShader->setContextHasBeenCalled());
+    SkASSERT(fShaderContext);
 
     fShader->ref();
-    fShaderFlags = fShader->getFlags();
+    fShaderFlags = fShaderContext->getFlags();
 }
 
 SkShaderBlitter::~SkShaderBlitter() {
-    SkASSERT(fShader->setContextHasBeenCalled());
-    fShader->endContext();
     fShader->unref();
 }
+
+bool SkShaderBlitter::resetShaderContext(const SkBitmap& device, const SkPaint& paint,
+                                         const SkMatrix& matrix) {
+    if (!fShader->validContext(device, paint, matrix)) {
+        return false;
+    }
+
+    // Only destroy the old context once validContext() has confirmed that a new one can be
+    // created. fShaderContext must always point to a live context because its storage is
+    // owned by an SkSmallAllocator outside of this class.
+    // The new context will be of the same size as the old one because we use the same
+    // shader to create it. It is therefore safe to re-use the storage.
+    fShaderContext->~Context();
+    fShaderContext = fShader->createContext(device, paint, matrix, (void*)fShaderContext);
+    SkASSERT(fShaderContext);
+
+    return true;
+}
diff --git a/src/core/SkBlitter.h b/src/core/SkBlitter.h
index d19a34b..f76839e 100644
--- a/src/core/SkBlitter.h
+++ b/src/core/SkBlitter.h
@@ -61,6 +61,13 @@
      */
     virtual bool isNullBlitter() const;
 
+    /**
+     *  Special methods for SkShaderBlitter. On all other classes these are no-ops.
+     */
+    virtual bool resetShaderContext(const SkBitmap& device, const SkPaint& paint,
+                                    const SkMatrix& matrix);
+    virtual SkShader::Context* getShaderContext() const;
+
     ///@name non-virtual helpers
     void blitMaskRegion(const SkMask& mask, const SkRegion& clip);
     void blitRectRegion(const SkIRect& rect, const SkRegion& clip);
diff --git a/src/core/SkBlitter_A8.cpp b/src/core/SkBlitter_A8.cpp
index 983a226..11f4259 100644
--- a/src/core/SkBlitter_A8.cpp
+++ b/src/core/SkBlitter_A8.cpp
@@ -228,11 +228,12 @@
 
 ///////////////////////////////////////////////////////////////////////
 
-SkA8_Shader_Blitter::SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint)
-    : INHERITED(device, paint) {
+SkA8_Shader_Blitter::SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
+                                         SkShader::Context* shaderContext)
+    : INHERITED(device, paint, shaderContext) {
     if ((fXfermode = paint.getXfermode()) != NULL) {
         fXfermode->ref();
-        SkASSERT(fShader);
+        SkASSERT(fShaderContext);
     }
 
     int width = device.width();
@@ -250,13 +251,14 @@
              (unsigned)(x + width) <= (unsigned)fDevice.width());
 
     uint8_t* device = fDevice.getAddr8(x, y);
+    SkShader::Context* shaderContext = fShaderContext;
 
-    if ((fShader->getFlags() & SkShader::kOpaqueAlpha_Flag) && !fXfermode) {
+    if ((shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag) && !fXfermode) {
         memset(device, 0xFF, width);
     } else {
         SkPMColor*  span = fBuffer;
 
-        fShader->shadeSpan(x, y, span, width);
+        shaderContext->shadeSpan(x, y, span, width);
         if (fXfermode) {
             fXfermode->xferA8(device, span, width, NULL);
         } else {
@@ -282,12 +284,12 @@
 
 void SkA8_Shader_Blitter::blitAntiH(int x, int y, const SkAlpha antialias[],
                                     const int16_t runs[]) {
-    SkShader*   shader = fShader;
-    SkXfermode* mode = fXfermode;
-    uint8_t*    aaExpand = fAAExpand;
-    SkPMColor*  span = fBuffer;
-    uint8_t*    device = fDevice.getAddr8(x, y);
-    int         opaque = fShader->getFlags() & SkShader::kOpaqueAlpha_Flag;
+    SkShader::Context* shaderContext = fShaderContext;
+    SkXfermode*        mode = fXfermode;
+    uint8_t*           aaExpand = fAAExpand;
+    SkPMColor*         span = fBuffer;
+    uint8_t*           device = fDevice.getAddr8(x, y);
+    int                opaque = shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag;
 
     for (;;) {
         int count = *runs;
@@ -299,7 +301,7 @@
             if (opaque && aa == 255 && mode == NULL) {
                 memset(device, 0xFF, count);
             } else {
-                shader->shadeSpan(x, y, span, count);
+                shaderContext->shadeSpan(x, y, span, count);
                 if (mode) {
                     memset(aaExpand, aa, count);
                     mode->xferA8(device, span, count, aaExpand);
@@ -329,11 +331,12 @@
     int height = clip.height();
     uint8_t* device = fDevice.getAddr8(x, y);
     const uint8_t* alpha = mask.getAddr8(x, y);
+    SkShader::Context* shaderContext = fShaderContext;
 
     SkPMColor*  span = fBuffer;
 
     while (--height >= 0) {
-        fShader->shadeSpan(x, y, span, width);
+        shaderContext->shadeSpan(x, y, span, width);
         if (fXfermode) {
             fXfermode->xferA8(device, span, width, alpha);
         } else {
diff --git a/src/core/SkBlitter_ARGB32.cpp b/src/core/SkBlitter_ARGB32.cpp
index d4bec1b..118a1d1 100644
--- a/src/core/SkBlitter_ARGB32.cpp
+++ b/src/core/SkBlitter_ARGB32.cpp
@@ -275,14 +275,16 @@
 }
 
 SkARGB32_Shader_Blitter::SkARGB32_Shader_Blitter(const SkBitmap& device,
-                            const SkPaint& paint) : INHERITED(device, paint) {
+        const SkPaint& paint, SkShader::Context* shaderContext)
+    : INHERITED(device, paint, shaderContext)
+{
     fBuffer = (SkPMColor*)sk_malloc_throw(device.width() * (sizeof(SkPMColor)));
 
     fXfermode = paint.getXfermode();
     SkSafeRef(fXfermode);
 
     int flags = 0;
-    if (!(fShader->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
+    if (!(shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
         flags |= SkBlitRow::kSrcPixelAlpha_Flag32;
     }
     // we call this on the output from the shader
@@ -292,7 +294,7 @@
 
     fShadeDirectlyIntoDevice = false;
     if (fXfermode == NULL) {
-        if (fShader->getFlags() & SkShader::kOpaqueAlpha_Flag) {
+        if (shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag) {
             fShadeDirectlyIntoDevice = true;
         }
     } else {
@@ -305,7 +307,7 @@
         }
     }
 
-    fConstInY = SkToBool(fShader->getFlags() & SkShader::kConstInY32_Flag);
+    fConstInY = SkToBool(shaderContext->getFlags() & SkShader::kConstInY32_Flag);
 }
 
 SkARGB32_Shader_Blitter::~SkARGB32_Shader_Blitter() {
@@ -319,10 +321,10 @@
     uint32_t*   device = fDevice.getAddr32(x, y);
 
     if (fShadeDirectlyIntoDevice) {
-        fShader->shadeSpan(x, y, device, width);
+        fShaderContext->shadeSpan(x, y, device, width);
     } else {
         SkPMColor*  span = fBuffer;
-        fShader->shadeSpan(x, y, span, width);
+        fShaderContext->shadeSpan(x, y, span, width);
         if (fXfermode) {
             fXfermode->xfer32(device, span, width, NULL);
         } else {
@@ -335,22 +337,22 @@
     SkASSERT(x >= 0 && y >= 0 &&
              x + width <= fDevice.width() && y + height <= fDevice.height());
 
-    uint32_t*   device = fDevice.getAddr32(x, y);
-    size_t      deviceRB = fDevice.rowBytes();
-    SkShader*   shader = fShader;
-    SkPMColor*  span = fBuffer;
+    uint32_t*          device = fDevice.getAddr32(x, y);
+    size_t             deviceRB = fDevice.rowBytes();
+    SkShader::Context* shaderContext = fShaderContext;
+    SkPMColor*         span = fBuffer;
 
     if (fConstInY) {
         if (fShadeDirectlyIntoDevice) {
             // shade the first row directly into the device
-            fShader->shadeSpan(x, y, device, width);
+            shaderContext->shadeSpan(x, y, device, width);
             span = device;
             while (--height > 0) {
                 device = (uint32_t*)((char*)device + deviceRB);
                 memcpy(device, span, width << 2);
             }
         } else {
-            fShader->shadeSpan(x, y, span, width);
+            shaderContext->shadeSpan(x, y, span, width);
             SkXfermode* xfer = fXfermode;
             if (xfer) {
                 do {
@@ -372,7 +374,7 @@
 
     if (fShadeDirectlyIntoDevice) {
         void* ctx;
-        SkShader::ShadeProc shadeProc = fShader->asAShadeProc(&ctx);
+        SkShader::Context::ShadeProc shadeProc = shaderContext->asAShadeProc(&ctx);
         if (shadeProc) {
             do {
                 shadeProc(ctx, x, y, device, width);
@@ -381,7 +383,7 @@
             } while (--height > 0);
         } else {
             do {
-                shader->shadeSpan(x, y, device, width);
+                shaderContext->shadeSpan(x, y, device, width);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
             } while (--height > 0);
@@ -390,7 +392,7 @@
         SkXfermode* xfer = fXfermode;
         if (xfer) {
             do {
-                shader->shadeSpan(x, y, span, width);
+                shaderContext->shadeSpan(x, y, span, width);
                 xfer->xfer32(device, span, width, NULL);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
@@ -398,7 +400,7 @@
         } else {
             SkBlitRow::Proc32 proc = fProc32;
             do {
-                shader->shadeSpan(x, y, span, width);
+                shaderContext->shadeSpan(x, y, span, width);
                 proc(device, span, width, 255);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
@@ -409,9 +411,9 @@
 
 void SkARGB32_Shader_Blitter::blitAntiH(int x, int y, const SkAlpha antialias[],
                                         const int16_t runs[]) {
-    SkPMColor*  span = fBuffer;
-    uint32_t*   device = fDevice.getAddr32(x, y);
-    SkShader*   shader = fShader;
+    SkPMColor*         span = fBuffer;
+    uint32_t*          device = fDevice.getAddr32(x, y);
+    SkShader::Context* shaderContext = fShaderContext;
 
     if (fXfermode && !fShadeDirectlyIntoDevice) {
         for (;;) {
@@ -422,7 +424,7 @@
                 break;
             int aa = *antialias;
             if (aa) {
-                shader->shadeSpan(x, y, span, count);
+                shaderContext->shadeSpan(x, y, span, count);
                 if (aa == 255) {
                     xfer->xfer32(device, span, count, NULL);
                 } else {
@@ -438,7 +440,7 @@
             x += count;
         }
     } else if (fShadeDirectlyIntoDevice ||
-               (fShader->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
+               (shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
         for (;;) {
             int count = *runs;
             if (count <= 0) {
@@ -448,9 +450,9 @@
             if (aa) {
                 if (aa == 255) {
                     // cool, have the shader draw right into the device
-                    shader->shadeSpan(x, y, device, count);
+                    shaderContext->shadeSpan(x, y, device, count);
                 } else {
-                    shader->shadeSpan(x, y, span, count);
+                    shaderContext->shadeSpan(x, y, span, count);
                     fProc32Blend(device, span, count, aa);
                 }
             }
@@ -467,7 +469,7 @@
             }
             int aa = *antialias;
             if (aa) {
-                fShader->shadeSpan(x, y, span, count);
+                shaderContext->shadeSpan(x, y, span, count);
                 if (aa == 255) {
                     fProc32(device, span, count, 255);
                 } else {
@@ -491,10 +493,11 @@
 
     SkASSERT(mask.fBounds.contains(clip));
 
+    SkShader::Context*  shaderContext = fShaderContext;
     SkBlitMask::RowProc proc = NULL;
     if (!fXfermode) {
         unsigned flags = 0;
-        if (fShader->getFlags() & SkShader::kOpaqueAlpha_Flag) {
+        if (shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag) {
             flags |= SkBlitMask::kSrcIsOpaque_RowFlag;
         }
         proc = SkBlitMask::RowFactory(SkBitmap::kARGB_8888_Config, mask.fFormat,
@@ -515,14 +518,13 @@
     const uint8_t* maskRow = (const uint8_t*)mask.getAddr(x, y);
     const size_t maskRB = mask.fRowBytes;
 
-    SkShader* shader = fShader;
     SkPMColor* span = fBuffer;
 
     if (fXfermode) {
         SkASSERT(SkMask::kA8_Format == mask.fFormat);
         SkXfermode* xfer = fXfermode;
         do {
-            shader->shadeSpan(x, y, span, width);
+            shaderContext->shadeSpan(x, y, span, width);
             xfer->xfer32((SkPMColor*)dstRow, span, width, maskRow);
             dstRow += dstRB;
             maskRow += maskRB;
@@ -530,7 +532,7 @@
         } while (--height > 0);
     } else {
         do {
-            shader->shadeSpan(x, y, span, width);
+            shaderContext->shadeSpan(x, y, span, width);
             proc(dstRow, maskRow, span, width);
             dstRow += dstRB;
             maskRow += maskRB;
@@ -542,13 +544,13 @@
 void SkARGB32_Shader_Blitter::blitV(int x, int y, int height, SkAlpha alpha) {
     SkASSERT(x >= 0 && y >= 0 && y + height <= fDevice.height());
 
-    uint32_t*   device = fDevice.getAddr32(x, y);
-    size_t      deviceRB = fDevice.rowBytes();
-    SkShader*   shader = fShader;
+    uint32_t*          device = fDevice.getAddr32(x, y);
+    size_t             deviceRB = fDevice.rowBytes();
+    SkShader::Context* shaderContext = fShaderContext;
 
     if (fConstInY) {
         SkPMColor c;
-        fShader->shadeSpan(x, y, &c, 1);
+        shaderContext->shadeSpan(x, y, &c, 1);
 
         if (fShadeDirectlyIntoDevice) {
             if (255 == alpha) {
@@ -582,7 +584,7 @@
 
     if (fShadeDirectlyIntoDevice) {
         void* ctx;
-        SkShader::ShadeProc shadeProc = fShader->asAShadeProc(&ctx);
+        SkShader::Context::ShadeProc shadeProc = shaderContext->asAShadeProc(&ctx);
         if (255 == alpha) {
             if (shadeProc) {
                 do {
@@ -592,7 +594,7 @@
                 } while (--height > 0);
             } else {
                 do {
-                    shader->shadeSpan(x, y, device, 1);
+                    shaderContext->shadeSpan(x, y, device, 1);
                     y += 1;
                     device = (uint32_t*)((char*)device + deviceRB);
                 } while (--height > 0);
@@ -608,7 +610,7 @@
                 } while (--height > 0);
             } else {
                 do {
-                    shader->shadeSpan(x, y, &c, 1);
+                    shaderContext->shadeSpan(x, y, &c, 1);
                     *device = SkFourByteInterp(c, *device, alpha);
                     y += 1;
                     device = (uint32_t*)((char*)device + deviceRB);
@@ -620,7 +622,7 @@
         SkXfermode* xfer = fXfermode;
         if (xfer) {
             do {
-                shader->shadeSpan(x, y, span, 1);
+                shaderContext->shadeSpan(x, y, span, 1);
                 xfer->xfer32(device, span, 1, &alpha);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
@@ -628,7 +630,7 @@
         } else {
             SkBlitRow::Proc32 proc = (255 == alpha) ? fProc32 : fProc32Blend;
             do {
-                shader->shadeSpan(x, y, span, 1);
+                shaderContext->shadeSpan(x, y, span, 1);
                 proc(device, span, 1, alpha);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
diff --git a/src/core/SkBlitter_RGB16.cpp b/src/core/SkBlitter_RGB16.cpp
index 21b5a16..e22aac4 100644
--- a/src/core/SkBlitter_RGB16.cpp
+++ b/src/core/SkBlitter_RGB16.cpp
@@ -107,7 +107,8 @@
 
 class SkRGB16_Shader_Blitter : public SkShaderBlitter {
 public:
-    SkRGB16_Shader_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkRGB16_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
+                           SkShader::Context* shaderContext);
     virtual ~SkRGB16_Shader_Blitter();
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha* antialias,
@@ -129,7 +130,8 @@
 // used only if the shader can perform shadSpan16
 class SkRGB16_Shader16_Blitter : public SkRGB16_Shader_Blitter {
 public:
-    SkRGB16_Shader16_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkRGB16_Shader16_Blitter(const SkBitmap& device, const SkPaint& paint,
+                             SkShader::Context* shaderContext);
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha* antialias,
                            const int16_t* runs);
@@ -141,7 +143,8 @@
 
 class SkRGB16_Shader_Xfermode_Blitter : public SkShaderBlitter {
 public:
-    SkRGB16_Shader_Xfermode_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkRGB16_Shader_Xfermode_Blitter(const SkBitmap& device, const SkPaint& paint,
+                                    SkShader::Context* shaderContext);
     virtual ~SkRGB16_Shader_Xfermode_Blitter();
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha* antialias,
@@ -679,8 +682,9 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 SkRGB16_Shader16_Blitter::SkRGB16_Shader16_Blitter(const SkBitmap& device,
-                                                   const SkPaint& paint)
-    : SkRGB16_Shader_Blitter(device, paint) {
+                                                   const SkPaint& paint,
+                                                   SkShader::Context* shaderContext)
+    : SkRGB16_Shader_Blitter(device, paint, shaderContext) {
     SkASSERT(SkShader::CanCallShadeSpan16(fShaderFlags));
 }
 
@@ -688,28 +692,28 @@
     SkASSERT(x + width <= fDevice.width());
 
     uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
-    SkShader*   shader = fShader;
+    SkShader::Context*    shaderContext = fShaderContext;
 
-    int alpha = shader->getSpan16Alpha();
+    int alpha = shaderContext->getSpan16Alpha();
     if (0xFF == alpha) {
-        shader->shadeSpan16(x, y, device, width);
+        shaderContext->shadeSpan16(x, y, device, width);
     } else {
         uint16_t* span16 = (uint16_t*)fBuffer;
-        shader->shadeSpan16(x, y, span16, width);
+        shaderContext->shadeSpan16(x, y, span16, width);
         SkBlendRGB16(span16, device, SkAlpha255To256(alpha), width);
     }
 }
 
 void SkRGB16_Shader16_Blitter::blitRect(int x, int y, int width, int height) {
-    SkShader*   shader = fShader;
-    uint16_t*   dst = fDevice.getAddr16(x, y);
-    size_t      dstRB = fDevice.rowBytes();
-    int         alpha = shader->getSpan16Alpha();
+    SkShader::Context* shaderContext = fShaderContext;
+    uint16_t*          dst = fDevice.getAddr16(x, y);
+    size_t             dstRB = fDevice.rowBytes();
+    int                alpha = shaderContext->getSpan16Alpha();
 
     if (0xFF == alpha) {
         if (fShaderFlags & SkShader::kConstInY16_Flag) {
             // have the shader blit directly into the device the first time
-            shader->shadeSpan16(x, y, dst, width);
+            shaderContext->shadeSpan16(x, y, dst, width);
             // and now just memcpy that line on the subsequent lines
             if (--height > 0) {
                 const uint16_t* orig = dst;
@@ -720,7 +724,7 @@
             }
         } else {    // need to call shadeSpan16 for every line
             do {
-                shader->shadeSpan16(x, y, dst, width);
+                shaderContext->shadeSpan16(x, y, dst, width);
                 y += 1;
                 dst = (uint16_t*)((char*)dst + dstRB);
             } while (--height);
@@ -729,14 +733,14 @@
         int scale = SkAlpha255To256(alpha);
         uint16_t* span16 = (uint16_t*)fBuffer;
         if (fShaderFlags & SkShader::kConstInY16_Flag) {
-            shader->shadeSpan16(x, y, span16, width);
+            shaderContext->shadeSpan16(x, y, span16, width);
             do {
                 SkBlendRGB16(span16, dst, scale, width);
                 dst = (uint16_t*)((char*)dst + dstRB);
             } while (--height);
         } else {
             do {
-                shader->shadeSpan16(x, y, span16, width);
+                shaderContext->shadeSpan16(x, y, span16, width);
                 SkBlendRGB16(span16, dst, scale, width);
                 y += 1;
                 dst = (uint16_t*)((char*)dst + dstRB);
@@ -748,11 +752,11 @@
 void SkRGB16_Shader16_Blitter::blitAntiH(int x, int y,
                                          const SkAlpha* SK_RESTRICT antialias,
                                          const int16_t* SK_RESTRICT runs) {
-    SkShader*   shader = fShader;
+    SkShader::Context*     shaderContext = fShaderContext;
     SkPMColor* SK_RESTRICT span = fBuffer;
-    uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
+    uint16_t* SK_RESTRICT  device = fDevice.getAddr16(x, y);
 
-    int alpha = shader->getSpan16Alpha();
+    int alpha = shaderContext->getSpan16Alpha();
     uint16_t* span16 = (uint16_t*)span;
 
     if (0xFF == alpha) {
@@ -766,9 +770,9 @@
             int aa = *antialias;
             if (aa == 255) {
                 // go direct to the device!
-                shader->shadeSpan16(x, y, device, count);
+                shaderContext->shadeSpan16(x, y, device, count);
             } else if (aa) {
-                shader->shadeSpan16(x, y, span16, count);
+                shaderContext->shadeSpan16(x, y, span16, count);
                 SkBlendRGB16(span16, device, SkAlpha255To256(aa), count);
             }
             device += count;
@@ -787,7 +791,7 @@
 
             int aa = SkAlphaMul(*antialias, alpha);
             if (aa) {
-                shader->shadeSpan16(x, y, span16, count);
+                shaderContext->shadeSpan16(x, y, span16, count);
                 SkBlendRGB16(span16, device, SkAlpha255To256(aa), count);
             }
 
@@ -802,8 +806,9 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 SkRGB16_Shader_Blitter::SkRGB16_Shader_Blitter(const SkBitmap& device,
-                                               const SkPaint& paint)
-: INHERITED(device, paint) {
+                                               const SkPaint& paint,
+                                               SkShader::Context* shaderContext)
+: INHERITED(device, paint, shaderContext) {
     SkASSERT(paint.getXfermode() == NULL);
 
     fBuffer = (SkPMColor*)sk_malloc_throw(device.width() * sizeof(SkPMColor));
@@ -834,20 +839,20 @@
 void SkRGB16_Shader_Blitter::blitH(int x, int y, int width) {
     SkASSERT(x + width <= fDevice.width());
 
-    fShader->shadeSpan(x, y, fBuffer, width);
+    fShaderContext->shadeSpan(x, y, fBuffer, width);
     // shaders take care of global alpha, so we pass 0xFF (should be ignored)
     fOpaqueProc(fDevice.getAddr16(x, y), fBuffer, width, 0xFF, x, y);
 }
 
 void SkRGB16_Shader_Blitter::blitRect(int x, int y, int width, int height) {
-    SkShader*       shader = fShader;
-    SkBlitRow::Proc proc = fOpaqueProc;
-    SkPMColor*      buffer = fBuffer;
-    uint16_t*       dst = fDevice.getAddr16(x, y);
-    size_t          dstRB = fDevice.rowBytes();
+    SkShader::Context* shaderContext = fShaderContext;
+    SkBlitRow::Proc    proc = fOpaqueProc;
+    SkPMColor*         buffer = fBuffer;
+    uint16_t*          dst = fDevice.getAddr16(x, y);
+    size_t             dstRB = fDevice.rowBytes();
 
     if (fShaderFlags & SkShader::kConstInY32_Flag) {
-        shader->shadeSpan(x, y, buffer, width);
+        shaderContext->shadeSpan(x, y, buffer, width);
         do {
             proc(dst, buffer, width, 0xFF, x, y);
             y += 1;
@@ -855,7 +860,7 @@
         } while (--height);
     } else {
         do {
-            shader->shadeSpan(x, y, buffer, width);
+            shaderContext->shadeSpan(x, y, buffer, width);
             proc(dst, buffer, width, 0xFF, x, y);
             y += 1;
             dst = (uint16_t*)((char*)dst + dstRB);
@@ -880,9 +885,9 @@
 void SkRGB16_Shader_Blitter::blitAntiH(int x, int y,
                                        const SkAlpha* SK_RESTRICT antialias,
                                        const int16_t* SK_RESTRICT runs) {
-    SkShader*   shader = fShader;
+    SkShader::Context*     shaderContext = fShaderContext;
     SkPMColor* SK_RESTRICT span = fBuffer;
-    uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
+    uint16_t* SK_RESTRICT  device = fDevice.getAddr16(x, y);
 
     for (;;) {
         int count = *runs;
@@ -901,7 +906,7 @@
         int nonZeroCount = count + count_nonzero_span(runs + count, antialias + count);
 
         SkASSERT(nonZeroCount <= fDevice.width()); // don't overrun fBuffer
-        shader->shadeSpan(x, y, span, nonZeroCount);
+        shaderContext->shadeSpan(x, y, span, nonZeroCount);
 
         SkPMColor* localSpan = span;
         for (;;) {
@@ -928,8 +933,9 @@
 ///////////////////////////////////////////////////////////////////////
 
 SkRGB16_Shader_Xfermode_Blitter::SkRGB16_Shader_Xfermode_Blitter(
-                                const SkBitmap& device, const SkPaint& paint)
-: INHERITED(device, paint) {
+                                const SkBitmap& device, const SkPaint& paint,
+                                SkShader::Context* shaderContext)
+: INHERITED(device, paint, shaderContext) {
     fXfermode = paint.getXfermode();
     SkASSERT(fXfermode);
     fXfermode->ref();
@@ -950,18 +956,18 @@
     uint16_t*   device = fDevice.getAddr16(x, y);
     SkPMColor*  span = fBuffer;
 
-    fShader->shadeSpan(x, y, span, width);
+    fShaderContext->shadeSpan(x, y, span, width);
     fXfermode->xfer16(device, span, width, NULL);
 }
 
 void SkRGB16_Shader_Xfermode_Blitter::blitAntiH(int x, int y,
                                 const SkAlpha* SK_RESTRICT antialias,
                                 const int16_t* SK_RESTRICT runs) {
-    SkShader*   shader = fShader;
-    SkXfermode* mode = fXfermode;
+    SkShader::Context*     shaderContext = fShaderContext;
+    SkXfermode*            mode = fXfermode;
     SkPMColor* SK_RESTRICT span = fBuffer;
-    uint8_t* SK_RESTRICT aaExpand = fAAExpand;
-    uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
+    uint8_t* SK_RESTRICT   aaExpand = fAAExpand;
+    uint16_t* SK_RESTRICT  device = fDevice.getAddr16(x, y);
 
     for (;;) {
         int count = *runs;
@@ -981,7 +987,7 @@
                                                       antialias + count);
 
         SkASSERT(nonZeroCount <= fDevice.width()); // don't overrun fBuffer
-        shader->shadeSpan(x, y, span, nonZeroCount);
+        shaderContext->shadeSpan(x, y, span, nonZeroCount);
 
         x += nonZeroCount;
         SkPMColor* localSpan = span;
@@ -1012,6 +1018,7 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 SkBlitter* SkBlitter_ChooseD565(const SkBitmap& device, const SkPaint& paint,
+        SkShader::Context* shaderContext,
         SkTBlitterAllocator* allocator) {
     SkASSERT(allocator != NULL);
 
@@ -1023,12 +1030,14 @@
     SkASSERT(NULL == mode || NULL != shader);
 
     if (shader) {
+        SkASSERT(shaderContext != NULL);
         if (mode) {
-            blitter = allocator->createT<SkRGB16_Shader_Xfermode_Blitter>(device, paint);
-        } else if (shader->canCallShadeSpan16()) {
-            blitter = allocator->createT<SkRGB16_Shader16_Blitter>(device, paint);
+            blitter = allocator->createT<SkRGB16_Shader_Xfermode_Blitter>(device, paint,
+                                                                          shaderContext);
+        } else if (shaderContext->canCallShadeSpan16()) {
+            blitter = allocator->createT<SkRGB16_Shader16_Blitter>(device, paint, shaderContext);
         } else {
-            blitter = allocator->createT<SkRGB16_Shader_Blitter>(device, paint);
+            blitter = allocator->createT<SkRGB16_Shader_Blitter>(device, paint, shaderContext);
         }
     } else {
         // no shader, no xfermode, (and we always ignore colorfilter)
diff --git a/src/core/SkCanvas.cpp b/src/core/SkCanvas.cpp
index d839971..e3451cd 100644
--- a/src/core/SkCanvas.cpp
+++ b/src/core/SkCanvas.cpp
@@ -91,32 +91,10 @@
 };
 #endif
 
-class AutoCheckNoSetContext {
-public:
-    AutoCheckNoSetContext(const SkPaint& paint) : fPaint(paint) {
-        this->assertNoSetContext(fPaint);
-    }
-    ~AutoCheckNoSetContext() {
-        this->assertNoSetContext(fPaint);
-    }
-
-private:
-    const SkPaint& fPaint;
-
-    void assertNoSetContext(const SkPaint& paint) {
-        SkShader* s = paint.getShader();
-        if (s) {
-            SkASSERT(!s->setContextHasBeenCalled());
-        }
-    }
-};
-
 #define CHECK_LOCKCOUNT_BALANCE(bitmap)  AutoCheckLockCountBalance clcb(bitmap)
-#define CHECK_SHADER_NOSETCONTEXT(paint) AutoCheckNoSetContext     cshsc(paint)
 
 #else
     #define CHECK_LOCKCOUNT_BALANCE(bitmap)
-    #define CHECK_SHADER_NOSETCONTEXT(paint)
 #endif
 
 typedef SkTLazy<SkPaint> SkLazyPaint;
@@ -1940,8 +1918,6 @@
 }
 
 void SkCanvas::internalDrawPaint(const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     LOOPER_BEGIN(paint, SkDrawFilter::kPaint_Type, NULL)
 
     while (iter.next()) {
@@ -1957,8 +1933,6 @@
         return;
     }
 
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect r, storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -1986,8 +1960,6 @@
 }
 
 void SkCanvas::drawRect(const SkRect& r, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2007,8 +1979,6 @@
 }
 
 void SkCanvas::drawOval(const SkRect& oval, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2028,8 +1998,6 @@
 }
 
 void SkCanvas::drawRRect(const SkRRect& rrect, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2060,8 +2028,6 @@
 
 void SkCanvas::onDrawDRRect(const SkRRect& outer, const SkRRect& inner,
                             const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2081,8 +2047,6 @@
 }
 
 void SkCanvas::drawPath(const SkPath& path, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     if (!path.isFinite()) {
         return;
     }
@@ -2358,8 +2322,6 @@
 
 void SkCanvas::onDrawText(const void* text, size_t byteLength, SkScalar x, SkScalar y,
                           const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
 
     while (iter.next()) {
@@ -2374,10 +2336,8 @@
 
 void SkCanvas::onDrawPosText(const void* text, size_t byteLength, const SkPoint pos[],
                              const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-    
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
-    
+
     while (iter.next()) {
         SkDeviceFilteredPaint dfp(iter.fDevice, looper.paint());
         iter.fDevice->drawPosText(iter, text, byteLength, &pos->fX, 0, 2,
@@ -2389,10 +2349,8 @@
 
 void SkCanvas::onDrawPosTextH(const void* text, size_t byteLength, const SkScalar xpos[],
                               SkScalar constY, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-    
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
-    
+
     while (iter.next()) {
         SkDeviceFilteredPaint dfp(iter.fDevice, looper.paint());
         iter.fDevice->drawPosText(iter, text, byteLength, xpos, constY, 1,
@@ -2404,10 +2362,8 @@
 
 void SkCanvas::onDrawTextOnPath(const void* text, size_t byteLength, const SkPath& path,
                                 const SkMatrix* matrix, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-    
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
-    
+
     while (iter.next()) {
         iter.fDevice->drawTextOnPath(iter, text, byteLength, path,
                                      matrix, looper.paint());
@@ -2439,8 +2395,6 @@
                             const SkColor colors[], SkXfermode* xmode,
                             const uint16_t indices[], int indexCount,
                             const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     LOOPER_BEGIN(paint, SkDrawFilter::kPath_Type, NULL)
 
     while (iter.next()) {
diff --git a/src/core/SkComposeShader.cpp b/src/core/SkComposeShader.cpp
index f53eedf..77bc46f 100644
--- a/src/core/SkComposeShader.cpp
+++ b/src/core/SkComposeShader.cpp
@@ -45,6 +45,10 @@
     fShaderA->unref();
 }
 
+size_t SkComposeShader::contextSize() const {
+    return sizeof(ComposeShaderContext) + fShaderA->contextSize() + fShaderB->contextSize();
+}
+
 class SkAutoAlphaRestore {
 public:
     SkAutoAlphaRestore(SkPaint* paint, uint8_t newAlpha) {
@@ -69,17 +73,16 @@
     buffer.writeFlattenable(fMode);
 }
 
-/*  We call setContext on our two worker shaders. However, we
-    always let them see opaque alpha, and if the paint really
-    is translucent, then we apply that after the fact.
+/*  We call validContext/createContext on our two worker shaders.
+    However, we always let them see opaque alpha, and if the paint
+    really is translucent, then we apply that after the fact.
 
-    We need to keep the calls to setContext/endContext balanced, since if we
-    return false, our endContext() will not be called.
  */
-bool SkComposeShader::setContext(const SkBitmap& device,
-                                 const SkPaint& paint,
-                                 const SkMatrix& matrix) {
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
+bool SkComposeShader::validContext(const SkBitmap& device,
+                                   const SkPaint& paint,
+                                   const SkMatrix& matrix,
+                                   SkMatrix* totalInverse) const {
+    if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
         return false;
     }
 
@@ -90,38 +93,62 @@
 
     tmpM.setConcat(matrix, this->getLocalMatrix());
 
+    return fShaderA->validContext(device, paint, tmpM) &&
+           fShaderB->validContext(device, paint, tmpM);
+}
+
+SkShader::Context* SkComposeShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                  const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
+    }
+
+    // we preconcat our localMatrix (if any) with the device matrix
+    // before calling our sub-shaders
+
+    SkMatrix tmpM;
+
+    tmpM.setConcat(matrix, this->getLocalMatrix());
+
     SkAutoAlphaRestore  restore(const_cast<SkPaint*>(&paint), 0xFF);
 
-    bool setContextA = fShaderA->setContext(device, paint, tmpM);
-    bool setContextB = fShaderB->setContext(device, paint, tmpM);
-    if (!setContextA || !setContextB) {
-        if (setContextB) {
-            fShaderB->endContext();
-        }
-        else if (setContextA) {
-            fShaderA->endContext();
-        }
-        this->INHERITED::endContext();
-        return false;
-    }
-    return true;
+    char* aStorage = (char*) storage + sizeof(ComposeShaderContext);
+    char* bStorage = aStorage + fShaderA->contextSize();
+
+    SkShader::Context* contextA = fShaderA->createContext(device, paint, tmpM, aStorage);
+    SkShader::Context* contextB = fShaderB->createContext(device, paint, tmpM, bStorage);
+
+    // Both functions must succeed; otherwise validContext should have returned
+    // false.
+    SkASSERT(contextA);
+    SkASSERT(contextB);
+
+    return SkNEW_PLACEMENT_ARGS(storage, ComposeShaderContext,
+                                (*this, device, paint, matrix, contextA, contextB));
 }
 
-void SkComposeShader::endContext() {
-    fShaderB->endContext();
-    fShaderA->endContext();
-    this->INHERITED::endContext();
+SkComposeShader::ComposeShaderContext::ComposeShaderContext(
+        const SkComposeShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix,
+        SkShader::Context* contextA, SkShader::Context* contextB)
+    : INHERITED(shader, device, paint, matrix)
+    , fShaderContextA(contextA)
+    , fShaderContextB(contextB) {}
+
+SkComposeShader::ComposeShaderContext::~ComposeShaderContext() {
+    fShaderContextA->~Context();
+    fShaderContextB->~Context();
 }
 
 // larger is better (fewer times we have to loop), but we shouldn't
 // take up too much stack-space (each element is 4 bytes)
 #define TMP_COLOR_COUNT     64
 
-void SkComposeShader::shadeSpan(int x, int y, SkPMColor result[], int count) {
-    SkShader*   shaderA = fShaderA;
-    SkShader*   shaderB = fShaderB;
-    SkXfermode* mode = fMode;
-    unsigned    scale = SkAlpha255To256(this->getPaintAlpha());
+void SkComposeShader::ComposeShaderContext::shadeSpan(int x, int y, SkPMColor result[], int count) {
+    SkShader::Context* shaderContextA = fShaderContextA;
+    SkShader::Context* shaderContextB = fShaderContextB;
+    SkXfermode*        mode = static_cast<const SkComposeShader&>(fShader).fMode;
+    unsigned           scale = SkAlpha255To256(this->getPaintAlpha());
 
     SkPMColor   tmp[TMP_COLOR_COUNT];
 
@@ -134,8 +161,8 @@
                 n = TMP_COLOR_COUNT;
             }
 
-            shaderA->shadeSpan(x, y, result, n);
-            shaderB->shadeSpan(x, y, tmp, n);
+            shaderContextA->shadeSpan(x, y, result, n);
+            shaderContextB->shadeSpan(x, y, tmp, n);
 
             if (256 == scale) {
                 for (int i = 0; i < n; i++) {
@@ -159,8 +186,8 @@
                 n = TMP_COLOR_COUNT;
             }
 
-            shaderA->shadeSpan(x, y, result, n);
-            shaderB->shadeSpan(x, y, tmp, n);
+            shaderContextA->shadeSpan(x, y, result, n);
+            shaderContextB->shadeSpan(x, y, tmp, n);
             mode->xfer32(result, tmp, n, NULL);
 
             if (256 == scale) {
diff --git a/src/core/SkCoreBlitters.h b/src/core/SkCoreBlitters.h
index 2851840..2d22d38 100644
--- a/src/core/SkCoreBlitters.h
+++ b/src/core/SkCoreBlitters.h
@@ -27,12 +27,29 @@
 
 class SkShaderBlitter : public SkRasterBlitter {
 public:
-    SkShaderBlitter(const SkBitmap& device, const SkPaint& paint);
+    /**
+      *  The storage for shaderContext is owned by the caller, but the object itself is not.
+      *  The blitter only ensures that the storage always holds a live object, but it may
+      *  exchange that object.
+      */
+    SkShaderBlitter(const SkBitmap& device, const SkPaint& paint,
+                    SkShader::Context* shaderContext);
     virtual ~SkShaderBlitter();
 
+    /**
+      *  Creates a new shader context and uses it instead of the old one if successful.
+      *  Will create the context at the same location as the old one (this is safe
+      *  because the shader itself is unchanged).
+      */
+    virtual bool resetShaderContext(const SkBitmap& device, const SkPaint& paint,
+                                    const SkMatrix& matrix) SK_OVERRIDE;
+
+    virtual SkShader::Context* getShaderContext() const SK_OVERRIDE { return fShaderContext; }
+
 protected:
-    uint32_t    fShaderFlags;
-    SkShader*   fShader;
+    uint32_t            fShaderFlags;
+    const SkShader*     fShader;
+    SkShader::Context*  fShaderContext;
 
 private:
     // illegal
@@ -75,7 +92,8 @@
 
 class SkA8_Shader_Blitter : public SkShaderBlitter {
 public:
-    SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
+                        SkShader::Context* shaderContext);
     virtual ~SkA8_Shader_Blitter();
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha antialias[], const int16_t runs[]);
@@ -141,7 +159,8 @@
 
 class SkARGB32_Shader_Blitter : public SkShaderBlitter {
 public:
-    SkARGB32_Shader_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkARGB32_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
+                            SkShader::Context* shaderContext);
     virtual ~SkARGB32_Shader_Blitter();
     virtual void blitH(int x, int y, int width) SK_OVERRIDE;
     virtual void blitV(int x, int y, int height, SkAlpha alpha) SK_OVERRIDE;
@@ -179,6 +198,7 @@
  */
 
 SkBlitter* SkBlitter_ChooseD565(const SkBitmap& device, const SkPaint& paint,
+                                SkShader::Context* shaderContext,
                                 SkTBlitterAllocator* allocator);
 
 #endif
diff --git a/src/core/SkDraw.cpp b/src/core/SkDraw.cpp
index 7eb0be6..6ddd0d2 100644
--- a/src/core/SkDraw.cpp
+++ b/src/core/SkDraw.cpp
@@ -2354,9 +2354,26 @@
 public:
     SkTriColorShader() {}
 
-    bool setup(const SkPoint pts[], const SkColor colors[], int, int, int);
+    virtual SkShader::Context* createContext(
+            const SkBitmap&, const SkPaint&, const SkMatrix&, void*) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
 
-    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+    class TriColorShaderContext : public SkShader::Context {
+    public:
+        TriColorShaderContext(const SkTriColorShader& shader, const SkBitmap& device,
+                              const SkPaint& paint, const SkMatrix& matrix);
+        virtual ~TriColorShaderContext();
+
+        bool setup(const SkPoint pts[], const SkColor colors[], int, int, int);
+
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+
+    private:
+        SkMatrix    fDstToUnit;
+        SkPMColor   fColors[3];
+
+        typedef SkShader::Context INHERITED;
+    };
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkTriColorShader)
@@ -2365,14 +2382,20 @@
     SkTriColorShader(SkReadBuffer& buffer) : SkShader(buffer) {}
 
 private:
-    SkMatrix    fDstToUnit;
-    SkPMColor   fColors[3];
-
     typedef SkShader INHERITED;
 };
 
-bool SkTriColorShader::setup(const SkPoint pts[], const SkColor colors[],
-                             int index0, int index1, int index2) {
+SkShader::Context* SkTriColorShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                   const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
+    }
+
+    return SkNEW_PLACEMENT_ARGS(storage, TriColorShaderContext, (*this, device, paint, matrix));
+}
+
+bool SkTriColorShader::TriColorShaderContext::setup(const SkPoint pts[], const SkColor colors[],
+                                                    int index0, int index1, int index2) {
 
     fColors[0] = SkPreMultiplyColor(colors[index0]);
     fColors[1] = SkPreMultiplyColor(colors[index1]);
@@ -2407,7 +2430,18 @@
     return SkAlpha255To256(scale);
 }
 
-void SkTriColorShader::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
+
+SkTriColorShader::TriColorShaderContext::TriColorShaderContext(
+        const SkTriColorShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix) {}  // fColors is assigned later, in setup()
+
+SkTriColorShader::TriColorShaderContext::~TriColorShaderContext() {}  // nothing to release
+
+size_t SkTriColorShader::contextSize() const {
+    return sizeof(TriColorShaderContext);  // the type createContext() placement-news
+}
+void SkTriColorShader::TriColorShaderContext::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
     SkPoint src;
 
     for (int i = 0; i < count; i++) {
@@ -2492,6 +2526,7 @@
     }
 
     // setup the custom shader (if needed)
+    SkAutoTUnref<SkComposeShader> composeShader;
     if (NULL != colors) {
         if (NULL == textures) {
             // just colors (no texture)
@@ -2504,9 +2539,8 @@
                 xmode = SkXfermode::Create(SkXfermode::kModulate_Mode);
                 releaseMode = true;
             }
-            SkShader* compose = SkNEW_ARGS(SkComposeShader,
-                                           (&triShader, shader, xmode));
-            p.setShader(compose)->unref();
+            composeShader.reset(SkNEW_ARGS(SkComposeShader, (&triShader, shader, xmode)));
+            p.setShader(composeShader);
             if (releaseMode) {
                 xmode->unref();
             }
@@ -2514,9 +2548,7 @@
     }
 
     SkAutoBlitterChoose blitter(*fBitmap, *fMatrix, p);
-    // important that we abort early, as below we may manipulate the shader
-    // and that is only valid if the shader returned true from setContext.
-    // If it returned false, then our blitter will be the NullBlitter.
+    // Abort early if we failed to create a shader context.
     if (blitter->isNullBlitter()) {
         return;
     }
@@ -2532,30 +2564,38 @@
             savedLocalM = shader->getLocalMatrix();
         }
 
-        // setContext has already been called and verified to return true
-        // by the constructor of SkAutoBlitterChoose
-        bool prevContextSuccess = true;
         while (vertProc(&state)) {
             if (NULL != textures) {
                 if (texture_to_matrix(state, vertices, textures, &tempM)) {
                     tempM.postConcat(savedLocalM);
                     shader->setLocalMatrix(tempM);
-                    // Need to recall setContext since we changed the local matrix.
-                    // However, we also need to balance the calls this with a
-                    // call to endContext which requires tracking the result of
-                    // the previous call to setContext.
-                    if (prevContextSuccess) {
-                        shader->endContext();
-                    }
-                    prevContextSuccess = shader->setContext(*fBitmap, p, *fMatrix);
-                    if (!prevContextSuccess) {
+                    if (!blitter->resetShaderContext(*fBitmap, p, *fMatrix)) {
                         continue;
                     }
                 }
             }
             if (NULL != colors) {
-                if (!triShader.setup(vertices, colors,
-                                     state.f0, state.f1, state.f2)) {
+                // Find the context for triShader.
+                SkTriColorShader::TriColorShaderContext* triColorShaderContext;
+
+                SkShader::Context* shaderContext = blitter->getShaderContext();
+                SkASSERT(shaderContext);
+                if (p.getShader() == &triShader) {
+                    triColorShaderContext =
+                            static_cast<SkTriColorShader::TriColorShaderContext*>(shaderContext);
+                } else {
+                    // The shader is a compose shader and triShader is its first shader.
+                    SkASSERT(p.getShader() == composeShader);
+                    SkASSERT(composeShader->getShaderA() == &triShader);
+                    SkComposeShader::ComposeShaderContext* composeShaderContext =
+                            static_cast<SkComposeShader::ComposeShaderContext*>(shaderContext);
+                    SkShader::Context* shaderContextA = composeShaderContext->getShaderContextA();
+                    triColorShaderContext =
+                            static_cast<SkTriColorShader::TriColorShaderContext*>(shaderContextA);
+                }
+
+                if (!triColorShaderContext->setup(vertices, colors,
+                                                  state.f0, state.f1, state.f2)) {
                     continue;
                 }
             }
@@ -2570,13 +2610,6 @@
         if (NULL != shader) {
             shader->setLocalMatrix(savedLocalM);
         }
-
-        // If the final call to setContext fails we must make it suceed so that the
-        // call to endContext in the destructor for SkAutoBlitterChoose is balanced.
-        if (!prevContextSuccess) {
-            prevContextSuccess = shader->setContext(*fBitmap, paint, SkMatrix::I());
-            SkASSERT(prevContextSuccess);
-        }
     } else {
         // no colors[] and no texture
         HairProc hairProc = ChooseHairProc(paint.isAntiAlias());
diff --git a/src/core/SkFilterShader.cpp b/src/core/SkFilterShader.cpp
index 5896191..5c5e8f3 100644
--- a/src/core/SkFilterShader.cpp
+++ b/src/core/SkFilterShader.cpp
@@ -38,9 +38,11 @@
     buffer.writeFlattenable(fFilter);
 }
 
-uint32_t SkFilterShader::getFlags() {
-    uint32_t shaderF = fShader->getFlags();
-    uint32_t filterF = fFilter->getFlags();
+uint32_t SkFilterShader::FilterShaderContext::getFlags() const {
+    const SkFilterShader& filterShader = static_cast<const SkFilterShader&>(fShader);
+
+    uint32_t shaderF = fShaderContext->getFlags();
+    uint32_t filterF = filterShader.fFilter->getFlags();
 
     // if the filter doesn't support 16bit, clear the matching bit in the shader
     if (!(filterF & SkColorFilter::kHasFilter16_Flag)) {
@@ -53,38 +55,62 @@
     return shaderF;
 }
 
-bool SkFilterShader::setContext(const SkBitmap& device,
-                                const SkPaint& paint,
-                                const SkMatrix& matrix) {
-    // we need to keep the setContext/endContext calls balanced. If we return
-    // false, our endContext() will not be called.
-
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
+SkShader::Context* SkFilterShader::createContext(const SkBitmap& device,
+                                                 const SkPaint& paint,
+                                                 const SkMatrix& matrix,
+                                                 void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
     }
-    if (!fShader->setContext(device, paint, matrix)) {
-        this->INHERITED::endContext();
-        return false;
-    }
-    return true;
+    // The wrapped shader's context is built in the same buffer, right after our own context.
+    char* shaderContextStorage = (char*)storage + sizeof(FilterShaderContext);
+    SkShader::Context* shaderContext = fShader->createContext(device, paint, matrix,
+                                                              shaderContextStorage);
+    SkASSERT(shaderContext);  // validContext() above already checked fShader
+
+    return SkNEW_PLACEMENT_ARGS(storage, FilterShaderContext,
+                                (*this, shaderContext, device, paint, matrix));
 }
 
-void SkFilterShader::endContext() {
-    fShader->endContext();
-    this->INHERITED::endContext();
+size_t SkFilterShader::contextSize() const {
+    return sizeof(FilterShaderContext) + fShader->contextSize();  // layout used by createContext()
 }
 
-void SkFilterShader::shadeSpan(int x, int y, SkPMColor result[], int count) {
-    fShader->shadeSpan(x, y, result, count);
-    fFilter->filterSpan(result, count, result);
+bool SkFilterShader::validContext(const SkBitmap& device,
+                                  const SkPaint& paint,
+                                  const SkMatrix& matrix,
+                                  SkMatrix* totalInverse) const {
+    return this->INHERITED::validContext(device, paint, matrix, totalInverse) &&
+           fShader->validContext(device, paint, matrix);  // wrapped shader must be valid too
 }
 
-void SkFilterShader::shadeSpan16(int x, int y, uint16_t result[], int count) {
-    SkASSERT(fShader->getFlags() & SkShader::kHasSpan16_Flag);
-    SkASSERT(fFilter->getFlags() & SkColorFilter::kHasFilter16_Flag);
+SkFilterShader::FilterShaderContext::FilterShaderContext(const SkFilterShader& filterShader,
+                                                         SkShader::Context* shaderContext,
+                                                         const SkBitmap& device,
+                                                         const SkPaint& paint,
+                                                         const SkMatrix& matrix)
+    : INHERITED(filterShader, device, paint, matrix)
+    , fShaderContext(shaderContext) {}  // destroyed in ~FilterShaderContext()
 
-    fShader->shadeSpan16(x, y, result, count);
-    fFilter->filterSpan16(result, count, result);
+SkFilterShader::FilterShaderContext::~FilterShaderContext() {
+    fShaderContext->~Context();  // destructor only; no memory is freed here
+}
+
+void SkFilterShader::FilterShaderContext::shadeSpan(int x, int y, SkPMColor result[], int count) {
+    const SkFilterShader& filterShader = static_cast<const SkFilterShader&>(fShader);
+    // Shade with the wrapped context, then run the color filter over the span in place.
+    fShaderContext->shadeSpan(x, y, result, count);
+    filterShader.fFilter->filterSpan(result, count, result);
+}
+
+void SkFilterShader::FilterShaderContext::shadeSpan16(int x, int y, uint16_t result[], int count) {
+    const SkFilterShader& filterShader = static_cast<const SkFilterShader&>(fShader);
+    // Both the wrapped context and the filter are expected to advertise 16-bit support.
+    SkASSERT(fShaderContext->getFlags() & SkShader::kHasSpan16_Flag);
+    SkASSERT(filterShader.fFilter->getFlags() & SkColorFilter::kHasFilter16_Flag);
+
+    fShaderContext->shadeSpan16(x, y, result, count);
+    filterShader.fFilter->filterSpan16(result, count, result);  // filters the span in place
 }
 
 #ifndef SK_IGNORE_TO_STRING
diff --git a/src/core/SkFilterShader.h b/src/core/SkFilterShader.h
index 11add0c..4ef4577 100644
--- a/src/core/SkFilterShader.h
+++ b/src/core/SkFilterShader.h
@@ -17,12 +17,29 @@
     SkFilterShader(SkShader* shader, SkColorFilter* filter);
     virtual ~SkFilterShader();
 
-    virtual uint32_t getFlags() SK_OVERRIDE;
-    virtual bool setContext(const SkBitmap&, const SkPaint&,
-                            const SkMatrix&) SK_OVERRIDE;
-    virtual void endContext() SK_OVERRIDE;
-    virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
+    virtual bool validContext(const SkBitmap&, const SkPaint&,
+                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
+                                             const SkMatrix&, void* storage) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
+    class FilterShaderContext : public SkShader::Context {
+    public:
+        // Takes ownership of shaderContext: the destructor calls ~Context() on it, nothing more.
+        FilterShaderContext(const SkFilterShader& filterShader, SkShader::Context* shaderContext,
+                            const SkBitmap& device, const SkPaint& paint, const SkMatrix& matrix);
+        virtual ~FilterShaderContext();
+
+        virtual uint32_t getFlags() const SK_OVERRIDE;
+
+        virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
+
+    private:
+        SkShader::Context* fShaderContext;  // context created by the wrapped shader
+
+        typedef SkShader::Context INHERITED;
+    };
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkFilterShader)
diff --git a/src/core/SkPictureShader.cpp b/src/core/SkPictureShader.cpp
index bf31285..dc5c90b 100644
--- a/src/core/SkPictureShader.cpp
+++ b/src/core/SkPictureShader.cpp
@@ -49,7 +49,7 @@
     fPicture->flatten(buffer);
 }
 
-bool SkPictureShader::buildBitmapShader(const SkMatrix& matrix) const {
+SkShader* SkPictureShader::refBitmapShader(const SkMatrix& matrix) const {
     SkASSERT(fPicture && fPicture->width() > 0 && fPicture->height() > 0);
 
     SkMatrix m;
@@ -70,17 +70,20 @@
 
     SkISize tileSize = scaledSize.toRound();
     if (tileSize.isEmpty()) {
-        return false;
+        return NULL;
     }
 
     // The actual scale, compensating for rounding.
     SkSize tileScale = SkSize::Make(SkIntToScalar(tileSize.width()) / fPicture->width(),
                                     SkIntToScalar(tileSize.height()) / fPicture->height());
 
-    if (!fCachedShader || tileScale != fCachedTileScale) {
+    SkAutoMutexAcquire ama(fCachedBitmapShaderMutex);
+
+    if (!fCachedBitmapShader || tileScale != fCachedTileScale ||
+        this->getLocalMatrix() != fCachedLocalMatrix) {
         SkBitmap bm;
         if (!bm.allocN32Pixels(tileSize.width(), tileSize.height())) {
-            return false;
+            return NULL;
         }
         bm.eraseColor(SK_ColorTRANSPARENT);
 
@@ -88,66 +91,91 @@
         canvas.scale(tileScale.width(), tileScale.height());
         canvas.drawPicture(*fPicture);
 
-        fCachedShader.reset(CreateBitmapShader(bm, fTmx, fTmy));
+        fCachedBitmapShader.reset(CreateBitmapShader(bm, fTmx, fTmy));
         fCachedTileScale = tileScale;
+        fCachedLocalMatrix = this->getLocalMatrix();
+
+        SkMatrix shaderMatrix = this->getLocalMatrix();
+        shaderMatrix.preScale(1 / tileScale.width(), 1 / tileScale.height());
+        fCachedBitmapShader->setLocalMatrix(shaderMatrix);
     }
 
-    SkMatrix shaderMatrix = this->getLocalMatrix();
-    shaderMatrix.preScale(1 / tileScale.width(), 1 / tileScale.height());
-    fCachedShader->setLocalMatrix(shaderMatrix);
-
-    return true;
+    // Increment the ref counter inside the mutex to ensure the returned pointer is still valid.
+    // Otherwise, the pointer may have been overwritten on a different thread before the object's
+    // ref count was incremented.
+    fCachedBitmapShader.get()->ref();
+    return fCachedBitmapShader;
 }
 
-bool SkPictureShader::setContext(const SkBitmap& device,
-                                 const SkPaint& paint,
-                                 const SkMatrix& matrix) {
-    if (!this->buildBitmapShader(matrix)) {
-        return false;
+SkShader* SkPictureShader::validInternal(const SkBitmap& device, const SkPaint& paint,
+                                         const SkMatrix& matrix, SkMatrix* totalInverse) const {
+    if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
+        return NULL;  // base-class check failed
     }
 
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
+    SkAutoTUnref<SkShader> bitmapShader(this->refBitmapShader(matrix));
+    if (!bitmapShader || !bitmapShader->validContext(device, paint, matrix)) {
+        return NULL;  // no bitmap shader, or it rejects these parameters
     }
 
-    SkASSERT(fCachedShader);
-    if (!fCachedShader->setContext(device, paint, matrix)) {
-        this->INHERITED::endContext();
-        return false;
+    return bitmapShader.detach();  // the ref taken in refBitmapShader() passes to the caller
+}
+
+bool SkPictureShader::validContext(const SkBitmap& device, const SkPaint& paint,
+                                   const SkMatrix& matrix, SkMatrix* totalInverse) const {
+    SkAutoTUnref<SkShader> shader(this->validInternal(device, paint, matrix, totalInverse));
+    return shader != NULL;  // only the NULL check matters; the shader itself is not kept
+}
+
+SkShader::Context* SkPictureShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                  const SkMatrix& matrix, void* storage) const {
+    SkAutoTUnref<SkShader> bitmapShader(this->validInternal(device, paint, matrix, NULL));
+    if (!bitmapShader) {
+        return NULL;  // parameters rejected by validInternal()
     }
 
-    return true;
+    return SkNEW_PLACEMENT_ARGS(storage, PictureShaderContext,  // context adopts the ref
+                                (*this, device, paint, matrix, bitmapShader.detach()));
 }
 
-void SkPictureShader::endContext() {
-    SkASSERT(fCachedShader);
-    fCachedShader->endContext();
-
-    this->INHERITED::endContext();
+size_t SkPictureShader::contextSize() const {
+    return sizeof(PictureShaderContext);  // the bitmap shader's context is malloc'ed separately
 }
 
-uint32_t SkPictureShader::getFlags() {
-    if (NULL != fCachedShader) {
-        return fCachedShader->getFlags();
-    }
-    return 0;
+SkPictureShader::PictureShaderContext::PictureShaderContext(
+        const SkPictureShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix, SkShader* bitmapShader)
+    : INHERITED(shader, device, paint, matrix)
+    , fBitmapShader(bitmapShader)  // SkAutoTUnref member adopts the caller's ref
+{
+    SkASSERT(fBitmapShader);
+    fBitmapShaderContextStorage = sk_malloc_throw(fBitmapShader->contextSize());
+    fBitmapShaderContext = fBitmapShader->createContext(
+            device, paint, matrix, fBitmapShaderContextStorage);
+    SkASSERT(fBitmapShaderContext);  // validInternal() checked these parameters beforehand
 }
 
-SkShader::ShadeProc SkPictureShader::asAShadeProc(void** ctx) {
-    if (fCachedShader) {
-        return fCachedShader->asAShadeProc(ctx);
-    }
-    return NULL;
+SkPictureShader::PictureShaderContext::~PictureShaderContext() {
+    fBitmapShaderContext->~Context();  // destroy the nested context in place first
+    sk_free(fBitmapShaderContextStorage);  // then release the buffer malloc'ed by the ctor
 }
 
-void SkPictureShader::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
-    SkASSERT(fCachedShader);
-    fCachedShader->shadeSpan(x, y, dstC, count);
+uint32_t SkPictureShader::PictureShaderContext::getFlags() const {
+    return fBitmapShaderContext->getFlags();  // flags come from the bitmap shader's context
 }
 
-void SkPictureShader::shadeSpan16(int x, int y, uint16_t dstC[], int count) {
-    SkASSERT(fCachedShader);
-    fCachedShader->shadeSpan16(x, y, dstC, count);
+SkShader::Context::ShadeProc SkPictureShader::PictureShaderContext::asAShadeProc(void** ctx) {
+    return fBitmapShaderContext->asAShadeProc(ctx);  // defer to the bitmap shader's context
+}
+
+void SkPictureShader::PictureShaderContext::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
+    SkASSERT(fBitmapShaderContext);
+    fBitmapShaderContext->shadeSpan(x, y, dstC, count);
+}
+
+void SkPictureShader::PictureShaderContext::shadeSpan16(int x, int y, uint16_t dstC[], int count) {
+    SkASSERT(fBitmapShaderContext);
+    fBitmapShaderContext->shadeSpan16(x, y, dstC, count);
 }
 
 #ifndef SK_IGNORE_TO_STRING
@@ -168,10 +196,10 @@
 
 #if SK_SUPPORT_GPU
 GrEffectRef* SkPictureShader::asNewEffect(GrContext* context, const SkPaint& paint) const {
-    if (!this->buildBitmapShader(context->getMatrix())) {
+    SkAutoTUnref<SkShader> bitmapShader(this->refBitmapShader(context->getMatrix()));
+    if (!bitmapShader) {  // empty tile size or failed pixel allocation
         return NULL;
     }
-    SkASSERT(fCachedShader);
-    return fCachedShader->asNewEffect(context, paint);
+    return bitmapShader->asNewEffect(context, paint);  // uses the local ref, not the cache field
 }
 #endif
diff --git a/src/core/SkPictureShader.h b/src/core/SkPictureShader.h
index ea74b56..d1be059 100644
--- a/src/core/SkPictureShader.h
+++ b/src/core/SkPictureShader.h
@@ -24,13 +24,33 @@
     static SkPictureShader* Create(SkPicture*, TileMode, TileMode);
     virtual ~SkPictureShader();
 
-    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
-    virtual void endContext() SK_OVERRIDE;
-    virtual uint32_t getFlags() SK_OVERRIDE;
+    virtual bool validContext(const SkBitmap&, const SkPaint&,
+                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap& device, const SkPaint& paint,
+                                             const SkMatrix& matrix, void* storage) const
+            SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
 
-    virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
-    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+    class PictureShaderContext : public SkShader::Context {
+    public:
+        PictureShaderContext(const SkPictureShader& shader, const SkBitmap& device,
+                             const SkPaint& paint, const SkMatrix& matrix,
+                             SkShader* bitmapShader);  // ref is adopted by fBitmapShader
+        virtual ~PictureShaderContext();
+
+        virtual uint32_t getFlags() const SK_OVERRIDE;
+
+        virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+
+    private:
+        SkAutoTUnref<SkShader>  fBitmapShader;                // shader the nested context is for
+        SkShader::Context*      fBitmapShaderContext;         // created by fBitmapShader
+        void*                   fBitmapShaderContextStorage;  // sk_malloc'ed in ctor, freed in dtor
+
+        typedef SkShader::Context INHERITED;
+    };
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkPictureShader)
@@ -46,13 +66,18 @@
 private:
     SkPictureShader(SkPicture*, TileMode, TileMode);
 
-    bool buildBitmapShader(const SkMatrix&) const;
+    SkShader* validInternal(const SkBitmap& device, const SkPaint& paint,
+                            const SkMatrix& matrix, SkMatrix* totalInverse) const;
+
+    SkShader* refBitmapShader(const SkMatrix&) const;
 
     SkPicture*  fPicture;
     TileMode    fTmx, fTmy;
 
-    mutable SkAutoTUnref<SkShader>  fCachedShader;
+    mutable SkMutex                 fCachedBitmapShaderMutex;
+    mutable SkAutoTUnref<SkShader>  fCachedBitmapShader;
     mutable SkSize                  fCachedTileScale;
+    mutable SkMatrix                fCachedLocalMatrix;
 
     typedef SkShader INHERITED;
 };
diff --git a/src/core/SkShader.cpp b/src/core/SkShader.cpp
index e337b7d..40e52a0 100644
--- a/src/core/SkShader.cpp
+++ b/src/core/SkShader.cpp
@@ -17,7 +17,6 @@
 
 SkShader::SkShader() {
     fLocalMatrix.reset();
-    SkDEBUGCODE(fInSetContext = false;)
 }
 
 SkShader::SkShader(SkReadBuffer& buffer)
@@ -27,12 +26,9 @@
     } else {
         fLocalMatrix.reset();
     }
-
-    SkDEBUGCODE(fInSetContext = false;)
 }
 
 SkShader::~SkShader() {
-    SkASSERT(!fInSetContext);
 }
 
 void SkShader::flatten(SkWriteBuffer& buffer) const {
@@ -44,39 +40,48 @@
     }
 }
 
-bool SkShader::setContext(const SkBitmap& device,
-                          const SkPaint& paint,
-                          const SkMatrix& matrix) {
-    SkASSERT(!this->setContextHasBeenCalled());
-
+bool SkShader::computeTotalInverse(const SkMatrix& matrix, SkMatrix* totalInverse) const {
     const SkMatrix* m = &matrix;
     SkMatrix        total;
 
-    fPaintAlpha = paint.getAlpha();
     if (this->hasLocalMatrix()) {
         total.setConcat(matrix, this->getLocalMatrix());
         m = &total;
     }
-    if (m->invert(&fTotalInverse)) {
-        fTotalInverseClass = (uint8_t)ComputeMatrixClass(fTotalInverse);
-        SkDEBUGCODE(fInSetContext = true;)
-        return true;
-    }
-    return false;
+    // Invert matrix (concatenated with the local matrix, if any) into totalInverse.
+    return m->invert(totalInverse);
 }
 
-void SkShader::endContext() {
-    SkASSERT(fInSetContext);
-    SkDEBUGCODE(fInSetContext = false;)
+bool SkShader::validContext(const SkBitmap& device,
+                            const SkPaint& paint,
+                            const SkMatrix& matrix,
+                            SkMatrix* totalInverse) const {
+    return this->computeTotalInverse(matrix, totalInverse);  // device and paint are unused here
 }
 
-SkShader::ShadeProc SkShader::asAShadeProc(void** ctx) {
+SkShader::Context::Context(const SkShader& shader, const SkBitmap& device,
+                           const SkPaint& paint, const SkMatrix& matrix)
+    : fShader(shader)  // subclasses static_cast this back to their own shader type
+{
+    SkASSERT(fShader.validContext(device, paint, matrix));
+
+    // Because the context parameters must be valid at this point, we know that the matrix is
+    // invertible.
+    SkAssertResult(fShader.computeTotalInverse(matrix, &fTotalInverse));
+    fTotalInverseClass = (uint8_t)ComputeMatrixClass(fTotalInverse);
+
+    fPaintAlpha = paint.getAlpha();  // paint alpha as it was at construction
+}
+
+SkShader::Context::~Context() {}
+
+SkShader::Context::ShadeProc SkShader::Context::asAShadeProc(void** ctx) {
     return NULL;
 }
 
 #include "SkColorPriv.h"
 
-void SkShader::shadeSpan16(int x, int y, uint16_t span16[], int count) {
+void SkShader::Context::shadeSpan16(int x, int y, uint16_t span16[], int count) {
     SkASSERT(span16);
     SkASSERT(count > 0);
     SkASSERT(this->canCallShadeSpan16());
@@ -94,7 +99,7 @@
     #define SkU32BitShiftToByteOffset(shift)    ((shift) >> 3)
 #endif
 
-void SkShader::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
+void SkShader::Context::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
     SkASSERT(count > 0);
 
     SkPMColor   colors[kTempColorCount];
@@ -148,7 +153,7 @@
 #endif
 }
 
-SkShader::MatrixClass SkShader::ComputeMatrixClass(const SkMatrix& mat) {
+SkShader::Context::MatrixClass SkShader::Context::ComputeMatrixClass(const SkMatrix& mat) {
     MatrixClass mc = kLinear_MatrixClass;
 
     if (mat.hasPerspective()) {
@@ -163,8 +168,7 @@
 
 //////////////////////////////////////////////////////////////////////////////
 
-SkShader::BitmapType SkShader::asABitmap(SkBitmap*, SkMatrix*,
-                                         TileMode*) const {
+SkShader::BitmapType SkShader::asABitmap(SkBitmap*, SkMatrix*, TileMode*) const {
     return kNone_BitmapType;
 }
 
@@ -199,19 +203,16 @@
 #include "SkColorShader.h"
 #include "SkUtils.h"
 
-SkColorShader::SkColorShader() {
-    fFlags = 0;
-    fInheritColor = true;
+SkColorShader::SkColorShader()
+    : fColor()
+    , fInheritColor(true) {  // the context then takes its color from the paint
 }
 
-SkColorShader::SkColorShader(SkColor c) {
-    fFlags = 0;
-    fColor = c;
-    fInheritColor = false;
+SkColorShader::SkColorShader(SkColor c)
+    : fColor(c)
+    , fInheritColor(false) {  // the context uses fColor, with alpha scaled by the paint's alpha
 }
 
-SkColorShader::~SkColorShader() {}
-
 bool SkColorShader::isOpaque() const {
     if (fInheritColor) {
         return true; // using paint's alpha
@@ -220,8 +221,6 @@
 }
 
 SkColorShader::SkColorShader(SkReadBuffer& b) : INHERITED(b) {
-    fFlags = 0; // computed in setContext
-
     fInheritColor = b.readBool();
     if (fInheritColor) {
         return;
@@ -238,32 +237,43 @@
     buffer.writeColor(fColor);
 }
 
-uint32_t SkColorShader::getFlags() {
+uint32_t SkColorShader::ColorShaderContext::getFlags() const {
     return fFlags;
 }
 
-uint8_t SkColorShader::getSpan16Alpha() const {
+uint8_t SkColorShader::ColorShaderContext::getSpan16Alpha() const {
     return SkGetPackedA32(fPMColor);
 }
 
-bool SkColorShader::setContext(const SkBitmap& device, const SkPaint& paint,
-                               const SkMatrix& matrix) {
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
+SkShader::Context* SkColorShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;  // nothing is constructed in storage on this path
     }
 
+    return SkNEW_PLACEMENT_ARGS(storage, ColorShaderContext, (*this, device, paint, matrix));
+}
+
+SkColorShader::ColorShaderContext::ColorShaderContext(const SkColorShader& shader,
+                                                      const SkBitmap& device,
+                                                      const SkPaint& paint,
+                                                      const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix)
+{
     unsigned a;
 
-    if (fInheritColor) {
-        fColor = paint.getColor();
-        a = SkColorGetA(fColor);
+    SkColor color;
+    if (shader.fInheritColor) {
+        color = paint.getColor();
+        a = SkColorGetA(color);
     } else {
-        a = SkAlphaMul(SkColorGetA(fColor), SkAlpha255To256(paint.getAlpha()));
+        color = shader.fColor;
+        a = SkAlphaMul(SkColorGetA(color), SkAlpha255To256(paint.getAlpha()));
     }
 
-    unsigned r = SkColorGetR(fColor);
-    unsigned g = SkColorGetG(fColor);
-    unsigned b = SkColorGetB(fColor);
+    unsigned r = SkColorGetR(color);
+    unsigned g = SkColorGetG(color);
+    unsigned b = SkColorGetB(color);
 
     // we want this before we apply any alpha
     fColor16 = SkPack888ToRGB16(r, g, b);
@@ -282,19 +292,17 @@
             fFlags |= kHasSpan16_Flag;
         }
     }
-
-    return true;
 }
 
-void SkColorShader::shadeSpan(int x, int y, SkPMColor span[], int count) {
+void SkColorShader::ColorShaderContext::shadeSpan(int x, int y, SkPMColor span[], int count) {
     sk_memset32(span, fPMColor, count);
 }
 
-void SkColorShader::shadeSpan16(int x, int y, uint16_t span[], int count) {
+void SkColorShader::ColorShaderContext::shadeSpan16(int x, int y, uint16_t span[], int count) {
     sk_memset16(span, fColor16, count);
 }
 
-void SkColorShader::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
+void SkColorShader::ColorShaderContext::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
     memset(alpha, SkGetPackedA32(fPMColor), count);
 }
 
@@ -334,27 +342,9 @@
 
 ///////////////////////////////////////////////////////////////////////////////
 
+#ifndef SK_IGNORE_TO_STRING
 #include "SkEmptyShader.h"
 
-uint32_t SkEmptyShader::getFlags() { return 0; }
-uint8_t SkEmptyShader::getSpan16Alpha() const { return 0; }
-
-bool SkEmptyShader::setContext(const SkBitmap&, const SkPaint&,
-                               const SkMatrix&) { return false; }
-
-void SkEmptyShader::shadeSpan(int x, int y, SkPMColor span[], int count) {
-    SkDEBUGFAIL("should never get called, since setContext() returned false");
-}
-
-void SkEmptyShader::shadeSpan16(int x, int y, uint16_t span[], int count) {
-    SkDEBUGFAIL("should never get called, since setContext() returned false");
-}
-
-void SkEmptyShader::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
-    SkDEBUGFAIL("should never get called, since setContext() returned false");
-}
-
-#ifndef SK_IGNORE_TO_STRING
 void SkEmptyShader::toString(SkString* str) const {
     str->append("SkEmptyShader: (");
 
diff --git a/src/core/SkSmallAllocator.h b/src/core/SkSmallAllocator.h
index 655008b..8d4b53a 100644
--- a/src/core/SkSmallAllocator.h
+++ b/src/core/SkSmallAllocator.h
@@ -117,10 +117,12 @@
             // but we're not sure we can catch all callers, so handle it but
             // assert false in debug mode.
             SkASSERT(false);
+            rec->fStorageSize = 0;
             rec->fHeapStorage = sk_malloc_throw(storageRequired);
             rec->fObj = static_cast<void*>(rec->fHeapStorage);
         } else {
             // There is space in fStorage.
+            rec->fStorageSize = storageRequired;
             rec->fHeapStorage = NULL;
             SkASSERT(SkIsAlign4(fStorageUsed));
             rec->fObj = static_cast<void*>(fStorage + (fStorageUsed / 4));
@@ -131,11 +133,26 @@
         return rec->fObj;
     }
 
+    /*
+     *  Free the memory reserved last without calling the destructor.
+     *  Can be used in a nested way, i.e. after reserving A and B, calling
+     *  freeLast once will free B and calling it again will free A.
+     */
+    void freeLast() {
+        SkASSERT(fNumObjects > 0);
+        Rec* rec = &fRecs[fNumObjects - 1];  // record of the most recent reservation
+        sk_free(rec->fHeapStorage);  // NULL when the memory came from fStorage
+        fStorageUsed -= rec->fStorageSize;  // fStorageSize is 0 for heap-backed records
+
+        fNumObjects--;
+    }
+
 private:
     struct Rec {
-        void* fObj;
-        void* fHeapStorage;
-        void  (*fKillProc)(void*);
+        size_t fStorageSize;  // 0 if allocated on heap
+        void*  fObj;  // start of the reserved memory, whether heap or fStorage
+        void*  fHeapStorage;  // sk_malloc_throw result, or NULL when fStorage was used
+        void   (*fKillProc)(void*);
     };
 
     // Number of bytes used so far.
diff --git a/src/effects/SkPerlinNoiseShader.cpp b/src/effects/SkPerlinNoiseShader.cpp
index ed63faf..5adb582 100644
--- a/src/effects/SkPerlinNoiseShader.cpp
+++ b/src/effects/SkPerlinNoiseShader.cpp
@@ -278,7 +278,6 @@
   , fStitchTiles(!fTileSize.isEmpty())
 {
     SkASSERT(numOctaves >= 0 && numOctaves < 256);
-    fMatrix.reset();
     fPaintingData = SkNEW_ARGS(PaintingData, (fTileSize, fSeed, fBaseFrequencyX, fBaseFrequencyY));
 }
 
@@ -293,7 +292,6 @@
     fStitchTiles    = buffer.readBool();
     fTileSize.fWidth  = buffer.readInt();
     fTileSize.fHeight = buffer.readInt();
-    fMatrix.reset();
     fPaintingData = SkNEW_ARGS(PaintingData, (fTileSize, fSeed, fBaseFrequencyX, fBaseFrequencyY));
     buffer.validate(perlin_noise_type_is_valid(fType) &&
                     (fNumOctaves >= 0) && (fNumOctaves <= 255) &&
@@ -317,9 +315,9 @@
     buffer.writeInt(fTileSize.fHeight);
 }
 
-SkScalar SkPerlinNoiseShader::noise2D(int channel, const PaintingData& paintingData,
-                                      const StitchData& stitchData,
-                                      const SkPoint& noiseVector) const {
+SkScalar SkPerlinNoiseShader::PerlinNoiseShaderContext::noise2D(
+        int channel, const PaintingData& paintingData,
+        const StitchData& stitchData, const SkPoint& noiseVector) const {
     struct Noise {
         int noisePositionIntegerValue;
         SkScalar noisePositionFractionValue;
@@ -333,8 +331,9 @@
     Noise noiseX(noiseVector.x());
     Noise noiseY(noiseVector.y());
     SkScalar u, v;
+    const SkPerlinNoiseShader& perlinNoiseShader = static_cast<const SkPerlinNoiseShader&>(fShader);
     // If stitching, adjust lattice points accordingly.
-    if (fStitchTiles) {
+    if (perlinNoiseShader.fStitchTiles) {
         noiseX.noisePositionIntegerValue =
             checkNoise(noiseX.noisePositionIntegerValue, stitchData.fWrapX, stitchData.fWidth);
         noiseY.noisePositionIntegerValue =
@@ -365,11 +364,11 @@
     return SkScalarInterp(a, b, sy);
 }
 
-SkScalar SkPerlinNoiseShader::calculateTurbulenceValueForPoint(int channel,
-                                                               const PaintingData& paintingData,
-                                                               StitchData& stitchData,
-                                                               const SkPoint& point) const {
-    if (fStitchTiles) {
+SkScalar SkPerlinNoiseShader::PerlinNoiseShaderContext::calculateTurbulenceValueForPoint(
+        int channel, const PaintingData& paintingData,
+        StitchData& stitchData, const SkPoint& point) const {
+    const SkPerlinNoiseShader& perlinNoiseShader = static_cast<const SkPerlinNoiseShader&>(fShader);
+    if (perlinNoiseShader.fStitchTiles) {
         // Set up TurbulenceInitial stitch values.
         stitchData = paintingData.fStitchDataInit;
     }
@@ -377,14 +376,14 @@
     SkPoint noiseVector(SkPoint::Make(SkScalarMul(point.x(), paintingData.fBaseFrequency.fX),
                                       SkScalarMul(point.y(), paintingData.fBaseFrequency.fY)));
     SkScalar ratio = SK_Scalar1;
-    for (int octave = 0; octave < fNumOctaves; ++octave) {
+    for (int octave = 0; octave < perlinNoiseShader.fNumOctaves; ++octave) {
         SkScalar noise = noise2D(channel, paintingData, stitchData, noiseVector);
         turbulenceFunctionResult += SkScalarDiv(
-            (fType == kFractalNoise_Type) ? noise : SkScalarAbs(noise), ratio);
+            (perlinNoiseShader.fType == kFractalNoise_Type) ? noise : SkScalarAbs(noise), ratio);
         noiseVector.fX *= 2;
         noiseVector.fY *= 2;
         ratio *= 2;
-        if (fStitchTiles) {
+        if (perlinNoiseShader.fStitchTiles) {
             // Update stitch values
             stitchData.fWidth  *= 2;
             stitchData.fWrapX   = stitchData.fWidth + kPerlinNoise;
@@ -395,7 +394,7 @@
 
     // The value of turbulenceFunctionResult comes from ((turbulenceFunctionResult) + 1) / 2
     // by fractalNoise and (turbulenceFunctionResult) by turbulence.
-    if (fType == kFractalNoise_Type) {
+    if (perlinNoiseShader.fType == kFractalNoise_Type) {
         turbulenceFunctionResult =
             SkScalarMul(turbulenceFunctionResult, SK_ScalarHalf) + SK_ScalarHalf;
     }
@@ -409,7 +408,9 @@
     return SkScalarPin(turbulenceFunctionResult, 0, SK_Scalar1);
 }
 
-SkPMColor SkPerlinNoiseShader::shade(const SkPoint& point, StitchData& stitchData) const {
+SkPMColor SkPerlinNoiseShader::PerlinNoiseShaderContext::shade(
+        const SkPoint& point, StitchData& stitchData) const {
+    const SkPerlinNoiseShader& perlinNoiseShader = static_cast<const SkPerlinNoiseShader&>(fShader);
     SkPoint newPoint;
     fMatrix.mapPoints(&newPoint, &point, 1);
     newPoint.fX = SkScalarRoundToScalar(newPoint.fX);
@@ -418,15 +419,32 @@
     U8CPU rgba[4];
     for (int channel = 3; channel >= 0; --channel) {
         rgba[channel] = SkScalarFloorToInt(255 *
-            calculateTurbulenceValueForPoint(channel, *fPaintingData, stitchData, newPoint));
+            calculateTurbulenceValueForPoint(channel, *perlinNoiseShader.fPaintingData,
+                                             stitchData, newPoint));
     }
     return SkPreMultiplyARGB(rgba[3], rgba[0], rgba[1], rgba[2]);
 }
 
-bool SkPerlinNoiseShader::setContext(const SkBitmap& device, const SkPaint& paint,
-                                     const SkMatrix& matrix) {
+// Builds this shader's draw-time context in the caller-supplied storage.
+// Returns NULL when validContext() rejects the device/paint/matrix.
+SkShader::Context* SkPerlinNoiseShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                      const SkMatrix& matrix, void* storage) const {
+    return this->validContext(device, paint, matrix)
+        ? SkNEW_PLACEMENT_ARGS(storage, PerlinNoiseShaderContext, (*this, device, paint, matrix))
+        : NULL;
+}
+
+size_t SkPerlinNoiseShader::contextSize() const {
+    return sizeof(PerlinNoiseShaderContext);  // type createContext() builds in 'storage'
+}
+
+SkPerlinNoiseShader::PerlinNoiseShaderContext::PerlinNoiseShaderContext(
+        const SkPerlinNoiseShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix)
+{
     SkMatrix newMatrix = matrix;
-    newMatrix.postConcat(getLocalMatrix());
+    newMatrix.postConcat(shader.getLocalMatrix());
     SkMatrix invMatrix;
     if (!newMatrix.invert(&invMatrix)) {
         invMatrix.reset();
@@ -437,10 +455,10 @@
     newMatrix.postConcat(invMatrix);
     newMatrix.postConcat(invMatrix);
     fMatrix = newMatrix;
-    return INHERITED::setContext(device, paint, matrix);
 }
 
-void SkPerlinNoiseShader::shadeSpan(int x, int y, SkPMColor result[], int count) {
+void SkPerlinNoiseShader::PerlinNoiseShaderContext::shadeSpan(
+        int x, int y, SkPMColor result[], int count) {
     SkPoint point = SkPoint::Make(SkIntToScalar(x), SkIntToScalar(y));
     StitchData stitchData;
     for (int i = 0; i < count; ++i) {
@@ -449,7 +467,8 @@
     }
 }
 
-void SkPerlinNoiseShader::shadeSpan16(int x, int y, uint16_t result[], int count) {
+void SkPerlinNoiseShader::PerlinNoiseShaderContext::shadeSpan16(
+        int x, int y, uint16_t result[], int count) {
     SkPoint point = SkPoint::Make(SkIntToScalar(x), SkIntToScalar(y));
     StitchData stitchData;
     DITHER_565_SCAN(y);
diff --git a/src/effects/SkTransparentShader.cpp b/src/effects/SkTransparentShader.cpp
index bd8b99a..0997e62 100644
--- a/src/effects/SkTransparentShader.cpp
+++ b/src/effects/SkTransparentShader.cpp
@@ -11,26 +11,40 @@
 #include "SkColorPriv.h"
 #include "SkString.h"
 
-bool SkTransparentShader::setContext(const SkBitmap& device,
-                                     const SkPaint& paint,
-                                     const SkMatrix& matrix) {
-    fDevice = &device;
-    fAlpha = paint.getAlpha();
+SkShader::Context* SkTransparentShader::createContext(const SkBitmap& device,
+                                                      const SkPaint& paint,
+                                                      const SkMatrix& matrix,
+                                                      void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;  // rejected: nothing is constructed in storage
+    }
 
-    return this->INHERITED::setContext(device, paint, matrix);
+    return SkNEW_PLACEMENT_ARGS(storage, TransparentShaderContext, (*this, device, paint, matrix));
 }
 
-uint32_t SkTransparentShader::getFlags() {
+size_t SkTransparentShader::contextSize() const {
+    return sizeof(TransparentShaderContext);  // type createContext() builds in 'storage'
+}
+
+SkTransparentShader::TransparentShaderContext::TransparentShaderContext(
+        const SkTransparentShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix)
+    , fDevice(&device) {}  // address only; dereferenced later by getFlags/shadeSpan
+
+SkTransparentShader::TransparentShaderContext::~TransparentShaderContext() {}
+
+uint32_t SkTransparentShader::TransparentShaderContext::getFlags() const {
     uint32_t flags = this->INHERITED::getFlags();
 
     switch (fDevice->colorType()) {
         case kRGB_565_SkColorType:
             flags |= kHasSpan16_Flag;
-            if (fAlpha == 255)
+            if (this->getPaintAlpha() == 255)
                 flags |= kOpaqueAlpha_Flag;
             break;
         case kN32_SkColorType:
-            if (fAlpha == 255 && fDevice->isOpaque())
+            if (this->getPaintAlpha() == 255 && fDevice->isOpaque())
                 flags |= kOpaqueAlpha_Flag;
             break;
         default:
@@ -39,8 +53,9 @@
     return flags;
 }
 
-void SkTransparentShader::shadeSpan(int x, int y, SkPMColor span[], int count) {
-    unsigned scale = SkAlpha255To256(fAlpha);
+void SkTransparentShader::TransparentShaderContext::shadeSpan(int x, int y, SkPMColor span[],
+                                                              int count) {
+    unsigned scale = SkAlpha255To256(this->getPaintAlpha());
 
     switch (fDevice->colorType()) {
         case kN32_SkColorType:
@@ -63,7 +78,7 @@
                     span[i] = SkPixel16ToPixel32(src[i]);
                 }
             } else {
-                unsigned alpha = fAlpha;
+                unsigned alpha = this->getPaintAlpha();
                 for (int i = count - 1; i >= 0; --i) {
                     uint16_t c = src[i];
                     unsigned r = SkPacked16ToR32(c);
@@ -97,7 +112,8 @@
     }
 }
 
-void SkTransparentShader::shadeSpan16(int x, int y, uint16_t span[], int count) {
+void SkTransparentShader::TransparentShaderContext::shadeSpan16(int x, int y, uint16_t span[],
+                                                                int count) {
     SkASSERT(fDevice->colorType() == kRGB_565_SkColorType);
 
     uint16_t* src = fDevice->getAddr16(x, y);
diff --git a/src/effects/gradients/SkGradientShader.cpp b/src/effects/gradients/SkGradientShader.cpp
index 2e92076..46e0c95 100644
--- a/src/effects/gradients/SkGradientShader.cpp
+++ b/src/effects/gradients/SkGradientShader.cpp
@@ -15,8 +15,6 @@
 SkGradientShaderBase::SkGradientShaderBase(const Descriptor& desc) {
     SkASSERT(desc.fCount > 1);
 
-    fCacheAlpha = 256;  // init to a value that paint.getAlpha() can't return
-
     fMapper = desc.fMapper;
     SkSafeRef(fMapper);
     fGradFlags = SkToU8(desc.fGradFlags);
@@ -26,10 +24,6 @@
     fTileMode = desc.fTileMode;
     fTileProc = gTileProcs[desc.fTileMode];
 
-    fCache16 = fCache16Storage = NULL;
-    fCache32 = NULL;
-    fCache32PixelRef = NULL;
-
     /*  Note: we let the caller skip the first and/or last position.
         i.e. pos[0] = 0.3, pos[1] = 0.7
         In these cases, we insert dummy entries to ensure that the final data
@@ -146,14 +140,8 @@
 }
 
 SkGradientShaderBase::SkGradientShaderBase(SkReadBuffer& buffer) : INHERITED(buffer) {
-    fCacheAlpha = 256;
-
     fMapper = buffer.readUnitMapper();
 
-    fCache16 = fCache16Storage = NULL;
-    fCache32 = NULL;
-    fCache32PixelRef = NULL;
-
     int colorCount = fColorCount = buffer.getArrayCount();
     if (colorCount > kColorStorageCount) {
         size_t allocSize = (sizeof(SkColor) + sizeof(SkPMColor) + sizeof(Rec)) * colorCount;
@@ -188,10 +176,6 @@
 }
 
 SkGradientShaderBase::~SkGradientShaderBase() {
-    if (fCache16Storage) {
-        sk_free(fCache16Storage);
-    }
-    SkSafeUnref(fCache32PixelRef);
     if (fOrigColors != fStorage) {
         sk_free(fOrigColors);
     }
@@ -199,7 +183,6 @@
 }
 
 void SkGradientShaderBase::initCommon() {
-    fFlags = 0;
     unsigned colorAlpha = 0xFF;
     for (int i = 0; i < fColorCount; i++) {
         colorAlpha &= SkColorGetA(fOrigColors[i]);
@@ -267,49 +250,50 @@
     return fColorsAreOpaque;
 }
 
-bool SkGradientShaderBase::setContext(const SkBitmap& device,
-                                 const SkPaint& paint,
-                                 const SkMatrix& matrix) {
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
-    }
-
+SkGradientShaderBase::GradientShaderBaseContext::GradientShaderBaseContext(
+        const SkGradientShaderBase& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix)
+    , fCache(shader.refCache(getPaintAlpha()))  // ref'ed cache for this paint alpha
+{
     const SkMatrix& inverse = this->getTotalInverse();
 
-    fDstToIndex.setConcat(fPtsToUnit, inverse);
+    fDstToIndex.setConcat(shader.fPtsToUnit, inverse);
+
     fDstToIndexProc = fDstToIndex.getMapXYProc();
-    fDstToIndexClass = (uint8_t)SkShader::ComputeMatrixClass(fDstToIndex);
+    fDstToIndexClass = (uint8_t)SkShader::Context::ComputeMatrixClass(fDstToIndex);
 
     // now convert our colors in to PMColors
     unsigned paintAlpha = this->getPaintAlpha();
 
     fFlags = this->INHERITED::getFlags();
-    if (fColorsAreOpaque && paintAlpha == 0xFF) {
+    if (shader.fColorsAreOpaque && paintAlpha == 0xFF) {
         fFlags |= kOpaqueAlpha_Flag;
     }
     // we can do span16 as long as our individual colors are opaque,
     // regardless of the paint's alpha
-    if (fColorsAreOpaque) {
+    if (shader.fColorsAreOpaque) {
         fFlags |= kHasSpan16_Flag;
     }
-
-    this->setCacheAlpha(paintAlpha);
-    return true;
 }
 
-void SkGradientShaderBase::setCacheAlpha(U8CPU alpha) const {
-    // if the new alpha differs from the previous time we were called, inval our cache
-    // this will trigger the cache to be rebuilt.
-    // we don't care about the first time, since the cache ptrs will already be NULL
-    if (fCacheAlpha != alpha) {
-        fCache16 = NULL;            // inval the cache
-        fCache32 = NULL;            // inval the cache
-        fCacheAlpha = alpha;        // record the new alpha
-        // inform our subclasses
-        if (fCache32PixelRef) {
-            fCache32PixelRef->notifyPixelsChanged();
-        }
-    }
+SkGradientShaderBase::GradientShaderCache::GradientShaderCache(
+        U8CPU alpha, const SkGradientShaderBase& shader)
+    // The tables are built lazily by getCache16/32, so every cache pointer
+    // starts out NULL. Initializers follow the member declaration order.
+    : fCache16(NULL)
+    , fCache32(NULL)
+    , fCache16Storage(NULL)
+    , fCache32PixelRef(NULL)
+    , fCacheAlpha(alpha)
+    , fShader(shader)
+    , fCache16Inited(false)
+    , fCache32Inited(false)
+{}
+
+SkGradientShaderBase::GradientShaderCache::~GradientShaderCache() {
+    sk_free(fCache16Storage);       // still NULL if initCache16 never ran
+    SkSafeUnref(fCache32PixelRef);  // still NULL if initCache32 never ran
 }
 
 #define Fixed_To_Dot8(x)        (((x) + 0x80) >> 8)
@@ -318,8 +302,8 @@
     build a 16bit table as long as the original colors are opaque, even if the
     paint specifies a non-opaque alpha.
 */
-void SkGradientShaderBase::Build16bitCache(uint16_t cache[], SkColor c0, SkColor c1,
-                                      int count) {
+void SkGradientShaderBase::GradientShaderCache::Build16bitCache(
+        uint16_t cache[], SkColor c0, SkColor c1, int count) {
     SkASSERT(count > 1);
     SkASSERT(SkColorGetA(c0) == 0xFF);
     SkASSERT(SkColorGetA(c1) == 0xFF);
@@ -367,8 +351,9 @@
  */
 typedef uint32_t SkUFixed;
 
-void SkGradientShaderBase::Build32bitCache(SkPMColor cache[], SkColor c0, SkColor c1,
-                                      int count, U8CPU paintAlpha, uint32_t gradFlags) {
+void SkGradientShaderBase::GradientShaderCache::Build32bitCache(
+        SkPMColor cache[], SkColor c0, SkColor c1,
+        int count, U8CPU paintAlpha, uint32_t gradFlags) {
     SkASSERT(count > 1);
 
     // need to apply paintAlpha to our two endpoints
@@ -511,99 +496,123 @@
     return 0;
 }
 
-const uint16_t* SkGradientShaderBase::getCache16() const {
-    if (fCache16 == NULL) {
-        // double the count for dither entries
-        const int entryCount = kCache16Count * 2;
-        const size_t allocSize = sizeof(uint16_t) * entryCount;
-
-        if (fCache16Storage == NULL) { // set the storage and our working ptr
-            fCache16Storage = (uint16_t*)sk_malloc_throw(allocSize);
-        }
-        fCache16 = fCache16Storage;
-        if (fColorCount == 2) {
-            Build16bitCache(fCache16, fOrigColors[0], fOrigColors[1],
-                            kCache16Count);
-        } else {
-            Rec* rec = fRecs;
-            int prevIndex = 0;
-            for (int i = 1; i < fColorCount; i++) {
-                int nextIndex = SkFixedToFFFF(rec[i].fPos) >> kCache16Shift;
-                SkASSERT(nextIndex < kCache16Count);
-
-                if (nextIndex > prevIndex)
-                    Build16bitCache(fCache16 + prevIndex, fOrigColors[i-1], fOrigColors[i], nextIndex - prevIndex + 1);
-                prevIndex = nextIndex;
-            }
-        }
-
-        if (fMapper) {
-            fCache16Storage = (uint16_t*)sk_malloc_throw(allocSize);
-            uint16_t* linear = fCache16;         // just computed linear data
-            uint16_t* mapped = fCache16Storage;  // storage for mapped data
-            SkUnitMapper* map = fMapper;
-            for (int i = 0; i < kCache16Count; i++) {
-                int index = map->mapUnit16(bitsTo16(i, kCache16Bits)) >> kCache16Shift;
-                mapped[i] = linear[index];
-                mapped[i + kCache16Count] = linear[index + kCache16Count];
-            }
-            sk_free(fCache16);
-            fCache16 = fCache16Storage;
-        }
-    }
+const uint16_t* SkGradientShaderBase::GradientShaderCache::getCache16() {
+    SkOnce(&fCache16Inited, &fCache16Mutex, SkGradientShaderBase::GradientShaderCache::initCache16,
+           this);
+    SkASSERT(fCache16);  // expected to be set by initCache16 once SkOnce returns
     return fCache16;
 }
 
-const SkPMColor* SkGradientShaderBase::getCache32() const {
-    if (fCache32 == NULL) {
-        SkImageInfo info;
-        info.fWidth = kCache32Count;
-        info.fHeight = 4;   // for our 4 dither rows
-        info.fAlphaType = kPremul_SkAlphaType;
-        info.fColorType = kN32_SkColorType;
+void SkGradientShaderBase::GradientShaderCache::initCache16(GradientShaderCache* cache) {
+    // double the count for dither entries
+    const int entryCount = kCache16Count * 2;
+    const size_t allocSize = sizeof(uint16_t) * entryCount;
 
-        if (NULL == fCache32PixelRef) {
-            fCache32PixelRef = SkMallocPixelRef::NewAllocate(info, 0, NULL);
-        }
-        fCache32 = (SkPMColor*)fCache32PixelRef->getAddr();
-        if (fColorCount == 2) {
-            Build32bitCache(fCache32, fOrigColors[0], fOrigColors[1],
-                            kCache32Count, fCacheAlpha, fGradFlags);
-        } else {
-            Rec* rec = fRecs;
-            int prevIndex = 0;
-            for (int i = 1; i < fColorCount; i++) {
-                int nextIndex = SkFixedToFFFF(rec[i].fPos) >> kCache32Shift;
-                SkASSERT(nextIndex < kCache32Count);
+    SkASSERT(NULL == cache->fCache16Storage);  // the table must not have been built yet
+    cache->fCache16Storage = (uint16_t*)sk_malloc_throw(allocSize);
+    cache->fCache16 = cache->fCache16Storage;
+    if (cache->fShader.fColorCount == 2) {
+        Build16bitCache(cache->fCache16, cache->fShader.fOrigColors[0],
+                        cache->fShader.fOrigColors[1], kCache16Count);
+    } else {
+        Rec* rec = cache->fShader.fRecs;
+        int prevIndex = 0;
+        for (int i = 1; i < cache->fShader.fColorCount; i++) {
+            int nextIndex = SkFixedToFFFF(rec[i].fPos) >> kCache16Shift;
+            SkASSERT(nextIndex < kCache16Count);
 
-                if (nextIndex > prevIndex)
-                    Build32bitCache(fCache32 + prevIndex, fOrigColors[i-1],
-                                    fOrigColors[i], nextIndex - prevIndex + 1,
-                                    fCacheAlpha, fGradFlags);
-                prevIndex = nextIndex;
-            }
-        }
-
-        if (fMapper) {
-            SkMallocPixelRef* newPR = SkMallocPixelRef::NewAllocate(info, 0, NULL);
-            SkPMColor* linear = fCache32;           // just computed linear data
-            SkPMColor* mapped = (SkPMColor*)newPR->getAddr();    // storage for mapped data
-            SkUnitMapper* map = fMapper;
-            for (int i = 0; i < kCache32Count; i++) {
-                int index = map->mapUnit16((i << 8) | i) >> 8;
-                mapped[i + kCache32Count*0] = linear[index + kCache32Count*0];
-                mapped[i + kCache32Count*1] = linear[index + kCache32Count*1];
-                mapped[i + kCache32Count*2] = linear[index + kCache32Count*2];
-                mapped[i + kCache32Count*3] = linear[index + kCache32Count*3];
-            }
-            fCache32PixelRef->unref();
-            fCache32PixelRef = newPR;
-            fCache32 = (SkPMColor*)newPR->getAddr();
+            if (nextIndex > prevIndex)
+                Build16bitCache(cache->fCache16 + prevIndex, cache->fShader.fOrigColors[i-1],
+                                cache->fShader.fOrigColors[i], nextIndex - prevIndex + 1);
+            prevIndex = nextIndex;
         }
     }
+
+    if (cache->fShader.fMapper) {
+        cache->fCache16Storage = (uint16_t*)sk_malloc_throw(allocSize);
+        uint16_t* linear = cache->fCache16;         // just computed linear data
+        uint16_t* mapped = cache->fCache16Storage;  // storage for mapped data
+        SkUnitMapper* map = cache->fShader.fMapper;
+        for (int i = 0; i < kCache16Count; i++) {
+            int index = map->mapUnit16(bitsTo16(i, kCache16Bits)) >> kCache16Shift;
+            mapped[i] = linear[index];
+            mapped[i + kCache16Count] = linear[index + kCache16Count];
+        }
+        sk_free(cache->fCache16);  // release the linear table; the mapped one replaces it
+        cache->fCache16 = cache->fCache16Storage;
+    }
+}
+
+const SkPMColor* SkGradientShaderBase::GradientShaderCache::getCache32() {
+    SkOnce(&fCache32Inited, &fCache32Mutex, SkGradientShaderBase::GradientShaderCache::initCache32,
+           this);
+    SkASSERT(fCache32);  // expected to be set by initCache32 once SkOnce returns
     return fCache32;
 }
 
+void SkGradientShaderBase::GradientShaderCache::initCache32(GradientShaderCache* cache) {
+    // One-time builder for the 32-bit table; getCache32() runs it through SkOnce.
+    const SkGradientShaderBase& shader = cache->fShader;
+
+    SkImageInfo info;
+    info.fWidth = kCache32Count;
+    info.fHeight = 4;   // one row per dither variant
+    info.fAlphaType = kPremul_SkAlphaType;
+    info.fColorType = kN32_SkColorType;
+
+    SkASSERT(NULL == cache->fCache32PixelRef);
+    cache->fCache32PixelRef = SkMallocPixelRef::NewAllocate(info, 0, NULL);
+    cache->fCache32 = (SkPMColor*)cache->fCache32PixelRef->getAddr();
+    if (2 == shader.fColorCount) {
+        Build32bitCache(cache->fCache32, shader.fOrigColors[0], shader.fOrigColors[1],
+                        kCache32Count, cache->fCacheAlpha, shader.fGradFlags);
+    } else {
+        int prevIndex = 0;
+        for (int i = 1; i < shader.fColorCount; i++) {
+            const int nextIndex = SkFixedToFFFF(shader.fRecs[i].fPos) >> kCache32Shift;
+            SkASSERT(nextIndex < kCache32Count);
+            if (nextIndex > prevIndex) {
+                Build32bitCache(cache->fCache32 + prevIndex, shader.fOrigColors[i-1],
+                                shader.fOrigColors[i], nextIndex - prevIndex + 1,
+                                cache->fCacheAlpha, shader.fGradFlags);
+            }
+            prevIndex = nextIndex;
+        }
+    }
+
+    if (SkUnitMapper* mapper = shader.fMapper) {
+        // Remap the linear table through the unit mapper into a fresh pixel ref.
+        SkMallocPixelRef* mappedPR = SkMallocPixelRef::NewAllocate(info, 0, NULL);
+        const SkPMColor* linear = cache->fCache32;
+        SkPMColor* mapped = (SkPMColor*)mappedPR->getAddr();
+        for (int i = 0; i < kCache32Count; i++) {
+            const int index = mapper->mapUnit16((i << 8) | i) >> 8;
+            for (int row = 0; row < 4; row++) {
+                mapped[i + kCache32Count*row] = linear[index + kCache32Count*row];
+            }
+        }
+        cache->fCache32PixelRef->unref();
+        cache->fCache32PixelRef = mappedPR;
+        cache->fCache32 = (SkPMColor*)mappedPR->getAddr();
+    }
+}
+
+/*
+ *  Returns the cache built for 'alpha', carrying a reference the caller must
+ *  release. Only the most recent alpha is kept, so back-to-back requests for
+ *  the same alpha share a single cache object.
+ */
+SkGradientShaderBase::GradientShaderCache* SkGradientShaderBase::refCache(U8CPU alpha) const {
+    SkAutoMutexAcquire lock(fCacheMutex);
+    if (!(fCache && fCache->getAlpha() == alpha)) {
+        fCache.reset(SkNEW_ARGS(GradientShaderCache, (alpha, *this)));
+    }
+    // Take the caller's ref while still holding the mutex: once it is released
+    // another thread may reset fCache and drop the object we are returning.
+    fCache.get()->ref();
+    return fCache;
+}
+
 /*
  *  Because our caller might rebuild the same (logically the same) gradient
  *  over and over, we'd like to return exactly the same "bitmap" if possible,
@@ -615,14 +624,14 @@
 void SkGradientShaderBase::getGradientTableBitmap(SkBitmap* bitmap) const {
     // our caller assumes no external alpha, so we ensure that our cache is
     // built with 0xFF
-    this->setCacheAlpha(0xFF);
+    SkAutoTUnref<GradientShaderCache> cache(this->refCache(0xFF));
 
     // don't have a way to put the mapper into our cache-key yet
     if (fMapper) {
-        // force our cahce32pixelref to be built
-        (void)this->getCache32();
+        // force our cache32pixelref to be built
+        (void)cache->getCache32();
         bitmap->setConfig(SkImageInfo::MakeN32Premul(kCache32Count, 1));
-        bitmap->setPixelRef(fCache32PixelRef);
+        bitmap->setPixelRef(cache->getCache32PixelRef());
         return;
     }
 
@@ -661,9 +670,9 @@
 
     if (!gCache->find(storage.get(), size, bitmap)) {
         // force our cahce32pixelref to be built
-        (void)this->getCache32();
+        (void)cache->getCache32();
         bitmap->setConfig(SkImageInfo::MakeN32Premul(kCache32Count, 1));
-        bitmap->setPixelRef(fCache32PixelRef);
+        bitmap->setPixelRef(cache->getCache32PixelRef());
 
         gCache->add(storage.get(), size, *bitmap);
     }
diff --git a/src/effects/gradients/SkGradientShaderPriv.h b/src/effects/gradients/SkGradientShaderPriv.h
index 02bb50b..5dec665 100644
--- a/src/effects/gradients/SkGradientShaderPriv.h
+++ b/src/effects/gradients/SkGradientShaderPriv.h
@@ -19,6 +19,7 @@
 #include "SkTemplates.h"
 #include "SkBitmapCache.h"
 #include "SkShader.h"
+#include "SkOnce.h"
 
 static inline void sk_memset32_dither(uint32_t dst[], uint32_t v0, uint32_t v1,
                                int count) {
@@ -101,8 +102,64 @@
     SkGradientShaderBase(const Descriptor& desc);
     virtual ~SkGradientShaderBase();
 
-    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
-    virtual uint32_t getFlags() SK_OVERRIDE { return fFlags; }
+    // Ref-counted color tables for one alpha; each is built on the first getCache16/32 call.
+    class GradientShaderCache : public SkRefCnt {
+    public:
+        GradientShaderCache(U8CPU alpha, const SkGradientShaderBase& shader);
+        ~GradientShaderCache();
+
+        const uint16_t*     getCache16();
+        const SkPMColor*    getCache32();
+
+        SkMallocPixelRef* getCache32PixelRef() const { return fCache32PixelRef; }
+
+        unsigned getAlpha() const { return fCacheAlpha; }
+
+    private:
+        // Working pointers. Each stays NULL until initCache16/32 builds its table.
+        uint16_t*   fCache16;
+        SkPMColor*  fCache32;
+
+        uint16_t*         fCache16Storage;    // Storage for fCache16, allocated on demand.
+        SkMallocPixelRef* fCache32PixelRef;   // Backing pixel ref for fCache32.
+        const unsigned    fCacheAlpha;        // Alpha the tables are built with; set once in the
+                                              // constructor. NOTE(review): no sentinel is stored
+                                              // in the code visible here, despite the wide type.
+
+        const SkGradientShaderBase& fShader;  // supplies colors, stops, mapper, grad flags
+
+        // SkOnce state: a flag and a mutex per table so each is initialized only once.
+        bool    fCache16Inited, fCache32Inited;
+        SkMutex fCache16Mutex, fCache32Mutex;
+
+        static void initCache16(GradientShaderCache* cache);
+        static void initCache32(GradientShaderCache* cache);
+
+        static void Build16bitCache(uint16_t[], SkColor c0, SkColor c1, int count);
+        static void Build32bitCache(SkPMColor[], SkColor c0, SkColor c1, int count,
+                                    U8CPU alpha, uint32_t gradFlags);
+    };
+
+    class GradientShaderBaseContext : public SkShader::Context {
+    public:
+        GradientShaderBaseContext(const SkGradientShaderBase& shader, const SkBitmap& device,
+                                  const SkPaint& paint, const SkMatrix& matrix);
+        ~GradientShaderBaseContext() {}
+
+        virtual uint32_t getFlags() const SK_OVERRIDE { return fFlags; }
+
+    protected:
+        SkMatrix    fDstToIndex;              // shader.fPtsToUnit concatenated with total inverse
+        SkMatrix::MapXYProc fDstToIndexProc;  // fDstToIndex.getMapXYProc()
+        uint8_t     fDstToIndexClass;         // ComputeMatrixClass(fDstToIndex)
+        uint8_t     fFlags;                   // computed in the constructor; see getFlags()
+
+        SkAutoTUnref<GradientShaderCache> fCache;  // ref obtained from shader.refCache()
+
+    private:
+        typedef SkShader::Context INHERITED;
+    };
+
     virtual bool isOpaque() const SK_OVERRIDE;
 
     void getGradientTableBitmap(SkBitmap*) const;
@@ -148,13 +205,9 @@
 
     SkUnitMapper* fMapper;
     SkMatrix    fPtsToUnit;     // set by subclass
-    SkMatrix    fDstToIndex;
-    SkMatrix::MapXYProc fDstToIndexProc;
     TileMode    fTileMode;
     TileProc    fTileProc;
     int         fColorCount;
-    uint8_t     fDstToIndexClass;
-    uint8_t     fFlags;
     uint8_t     fGradFlags;
 
     struct Rec {
@@ -163,9 +216,6 @@
     };
     Rec*        fRecs;
 
-    const uint16_t*     getCache16() const;
-    const SkPMColor*    getCache32() const;
-
     void commonAsAGradient(GradientInfo*, bool flipGrad = false) const;
 
     /*
@@ -191,20 +241,13 @@
         kStorageSize = kColorStorageCount * (sizeof(SkColor) + sizeof(Rec))
     };
     SkColor     fStorage[(kStorageSize + 3) >> 2];
-    SkColor*    fOrigColors; // original colors, before modulation by paint in setContext
+    SkColor*    fOrigColors; // original colors, before modulation by paint in context.
     bool        fColorsAreOpaque;
 
-    mutable uint16_t*   fCache16;   // working ptr. If this is NULL, we need to recompute the cache values
-    mutable SkPMColor*  fCache32;   // working ptr. If this is NULL, we need to recompute the cache values
+    GradientShaderCache* refCache(U8CPU alpha) const;
+    mutable SkMutex                           fCacheMutex;
+    mutable SkAutoTUnref<GradientShaderCache> fCache;
 
-    mutable uint16_t*   fCache16Storage;    // storage for fCache16, allocated on demand
-    mutable SkMallocPixelRef* fCache32PixelRef;
-    mutable unsigned    fCacheAlpha;        // the alpha value we used when we computed the cache. larger than 8bits so we can store uninitialized value
-
-    static void Build16bitCache(uint16_t[], SkColor c0, SkColor c1, int count);
-    static void Build32bitCache(SkPMColor[], SkColor c0, SkColor c1, int count,
-                                U8CPU alpha, uint32_t gradFlags);
-    void setCacheAlpha(U8CPU alpha) const;
     void initCommon();
 
     typedef SkShader INHERITED;
diff --git a/src/effects/gradients/SkLinearGradient.cpp b/src/effects/gradients/SkLinearGradient.cpp
index b24a634..e660d7c 100644
--- a/src/effects/gradients/SkLinearGradient.cpp
+++ b/src/effects/gradients/SkLinearGradient.cpp
@@ -71,12 +71,24 @@
     buffer.writePoint(fEnd);
 }
 
-bool SkLinearGradient::setContext(const SkBitmap& device, const SkPaint& paint,
-                                 const SkMatrix& matrix) {
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
+size_t SkLinearGradient::contextSize() const {
+    return sizeof(LinearGradientContext);
+}
+
+SkShader::Context* SkLinearGradient::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                   const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
     }
 
+    return SkNEW_PLACEMENT_ARGS(storage, LinearGradientContext, (*this, device, paint, matrix));
+}
+
+SkLinearGradient::LinearGradientContext::LinearGradientContext(
+        const SkLinearGradient& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix)
+{
     unsigned mask = SkMatrix::kTranslate_Mask | SkMatrix::kScale_Mask;
     if ((fDstToIndex.getType() & ~mask) == 0) {
         // when we dither, we are (usually) not const-in-Y
@@ -87,7 +99,6 @@
             fFlags |= SkShader::kConstInY16_Flag;
         }
     }
-    return true;
 }
 
 #define NO_CHECK_ITER               \
@@ -196,14 +207,16 @@
 
 }
 
-void SkLinearGradient::shadeSpan(int x, int y, SkPMColor* SK_RESTRICT dstC,
-                                int count) {
+void SkLinearGradient::LinearGradientContext::shadeSpan(int x, int y, SkPMColor* SK_RESTRICT dstC,
+                                                        int count) {
     SkASSERT(count > 0);
 
+    const SkLinearGradient& linearGradient = static_cast<const SkLinearGradient&>(fShader);
+
     SkPoint             srcPt;
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = fTileProc;
-    const SkPMColor* SK_RESTRICT cache = this->getCache32();
+    TileProc            proc = linearGradient.fTileProc;
+    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
     int                 toggle = init_dither_toggle(x, y);
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -223,12 +236,12 @@
         LinearShadeProc shadeProc = shadeSpan_linear_repeat;
         if (0 == dx) {
             shadeProc = shadeSpan_linear_vertical_lerp;
-        } else if (SkShader::kClamp_TileMode == fTileMode) {
+        } else if (SkShader::kClamp_TileMode == linearGradient.fTileMode) {
             shadeProc = shadeSpan_linear_clamp;
-        } else if (SkShader::kMirror_TileMode == fTileMode) {
+        } else if (SkShader::kMirror_TileMode == linearGradient.fTileMode) {
             shadeProc = shadeSpan_linear_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == linearGradient.fTileMode);
         }
         (*shadeProc)(proc, dx, fx, dstC, cache, toggle, count);
     } else {
@@ -381,14 +394,16 @@
     return SkAbs32(x) < (SK_Fixed1 >> 12);
 }
 
-void SkLinearGradient::shadeSpan16(int x, int y,
-                                  uint16_t* SK_RESTRICT dstC, int count) {
+void SkLinearGradient::LinearGradientContext::shadeSpan16(int x, int y,
+                                                          uint16_t* SK_RESTRICT dstC, int count) {
     SkASSERT(count > 0);
 
+    const SkLinearGradient& linearGradient = static_cast<const SkLinearGradient&>(fShader);
+
     SkPoint             srcPt;
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = fTileProc;
-    const uint16_t* SK_RESTRICT cache = this->getCache16();
+    TileProc            proc = linearGradient.fTileProc;
+    const uint16_t* SK_RESTRICT cache = fCache->getCache16();
     int                 toggle = init_dither_toggle16(x, y);
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -408,12 +423,12 @@
         LinearShade16Proc shadeProc = shadeSpan16_linear_repeat;
         if (fixed_nearly_zero(dx)) {
             shadeProc = shadeSpan16_linear_vertical;
-        } else if (SkShader::kClamp_TileMode == fTileMode) {
+        } else if (SkShader::kClamp_TileMode == linearGradient.fTileMode) {
             shadeProc = shadeSpan16_linear_clamp;
-        } else if (SkShader::kMirror_TileMode == fTileMode) {
+        } else if (SkShader::kMirror_TileMode == linearGradient.fTileMode) {
             shadeProc = shadeSpan16_linear_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == linearGradient.fTileMode);
         }
         (*shadeProc)(proc, dx, fx, dstC, cache, toggle, count);
     } else {
diff --git a/src/effects/gradients/SkLinearGradient.h b/src/effects/gradients/SkLinearGradient.h
index 013c449..8d80667 100644
--- a/src/effects/gradients/SkLinearGradient.h
+++ b/src/effects/gradients/SkLinearGradient.h
@@ -15,9 +15,23 @@
 public:
     SkLinearGradient(const SkPoint pts[2], const Descriptor&);
 
-    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
-    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
+                                             void* storage) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
+    class LinearGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
+    public:
+        LinearGradientContext(const SkLinearGradient& shader, const SkBitmap& device,
+                              const SkPaint& paint, const SkMatrix& matrix);
+        ~LinearGradientContext() {}
+
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+
+    private:
+        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
+    };
+
     virtual BitmapType asABitmap(SkBitmap*, SkMatrix*, TileMode*) const SK_OVERRIDE;
     virtual GradientType asAGradient(GradientInfo* info) const SK_OVERRIDE;
     virtual GrEffectRef* asNewEffect(GrContext* context, const SkPaint&) const SK_OVERRIDE;
diff --git a/src/effects/gradients/SkRadialGradient.cpp b/src/effects/gradients/SkRadialGradient.cpp
index 1b9e725..bc2ea3b 100644
--- a/src/effects/gradients/SkRadialGradient.cpp
+++ b/src/effects/gradients/SkRadialGradient.cpp
@@ -157,16 +157,36 @@
     rad_to_unit_matrix(center, radius, &fPtsToUnit);
 }
 
-void SkRadialGradient::shadeSpan16(int x, int y, uint16_t* dstCParam,
-                         int count) {
+size_t SkRadialGradient::contextSize() const {
+    return sizeof(RadialGradientContext);
+}
+
+SkShader::Context* SkRadialGradient::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                   const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
+    }
+
+    return SkNEW_PLACEMENT_ARGS(storage, RadialGradientContext, (*this, device, paint, matrix));
+}
+
+SkRadialGradient::RadialGradientContext::RadialGradientContext(
+        const SkRadialGradient& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix) {}
+
+void SkRadialGradient::RadialGradientContext::shadeSpan16(int x, int y, uint16_t* dstCParam,
+                                                          int count) {
     SkASSERT(count > 0);
 
+    const SkRadialGradient& radialGradient = static_cast<const SkRadialGradient&>(fShader);
+
     uint16_t* SK_RESTRICT dstC = dstCParam;
 
     SkPoint             srcPt;
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = fTileProc;
-    const uint16_t* SK_RESTRICT cache = this->getCache16();
+    TileProc            proc = radialGradient.fTileProc;
+    const uint16_t* SK_RESTRICT cache = fCache->getCache16();
     int                 toggle = init_dither_toggle16(x, y);
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -187,12 +207,12 @@
         }
 
         RadialShade16Proc shadeProc = shadeSpan16_radial_repeat;
-        if (SkShader::kClamp_TileMode == fTileMode) {
+        if (SkShader::kClamp_TileMode == radialGradient.fTileMode) {
             shadeProc = shadeSpan16_radial_clamp;
-        } else if (SkShader::kMirror_TileMode == fTileMode) {
+        } else if (SkShader::kMirror_TileMode == radialGradient.fTileMode) {
             shadeProc = shadeSpan16_radial_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == radialGradient.fTileMode);
         }
         (*shadeProc)(srcPt.fX, sdx, srcPt.fY, sdy, dstC,
                      cache, toggle, count);
@@ -389,14 +409,16 @@
 
 }  // namespace
 
-void SkRadialGradient::shadeSpan(int x, int y,
-                                SkPMColor* SK_RESTRICT dstC, int count) {
+void SkRadialGradient::RadialGradientContext::shadeSpan(int x, int y,
+                                                        SkPMColor* SK_RESTRICT dstC, int count) {
     SkASSERT(count > 0);
 
+    const SkRadialGradient& radialGradient = static_cast<const SkRadialGradient&>(fShader);
+
     SkPoint             srcPt;
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = fTileProc;
-    const SkPMColor* SK_RESTRICT cache = this->getCache32();
+    TileProc            proc = radialGradient.fTileProc;
+    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
     int toggle = init_dither_toggle(x, y);
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -416,12 +438,12 @@
         }
 
         RadialShadeProc shadeProc = shadeSpan_radial_repeat;
-        if (SkShader::kClamp_TileMode == fTileMode) {
+        if (SkShader::kClamp_TileMode == radialGradient.fTileMode) {
             shadeProc = shadeSpan_radial_clamp;
-        } else if (SkShader::kMirror_TileMode == fTileMode) {
+        } else if (SkShader::kMirror_TileMode == radialGradient.fTileMode) {
             shadeProc = shadeSpan_radial_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == radialGradient.fTileMode);
         }
         (*shadeProc)(srcPt.fX, sdx, srcPt.fY, sdy, dstC, cache, count, toggle);
     } else {    // perspective case
diff --git a/src/effects/gradients/SkRadialGradient.h b/src/effects/gradients/SkRadialGradient.h
index 4a72514..a3d04b1 100644
--- a/src/effects/gradients/SkRadialGradient.h
+++ b/src/effects/gradients/SkRadialGradient.h
@@ -14,10 +14,24 @@
 class SkRadialGradient : public SkGradientShaderBase {
 public:
     SkRadialGradient(const SkPoint& center, SkScalar radius, const Descriptor&);
-    virtual void shadeSpan(int x, int y, SkPMColor* dstC, int count)
-        SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t* dstCParam,
-                             int count) SK_OVERRIDE;
+
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
+                                             void* storage) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
+    class RadialGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
+    public:
+        RadialGradientContext(const SkRadialGradient& shader, const SkBitmap& device,
+                              const SkPaint& paint, const SkMatrix& matrix);
+        ~RadialGradientContext() {}
+
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+
+    private:
+        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
+    };
+
     virtual BitmapType asABitmap(SkBitmap* bitmap,
                                  SkMatrix* matrix,
                                  TileMode* xy) const SK_OVERRIDE;
diff --git a/src/effects/gradients/SkSweepGradient.cpp b/src/effects/gradients/SkSweepGradient.cpp
index 7024945..6dff1e7 100644
--- a/src/effects/gradients/SkSweepGradient.cpp
+++ b/src/effects/gradients/SkSweepGradient.cpp
@@ -52,6 +52,24 @@
     buffer.writePoint(fCenter);
 }
 
+size_t SkSweepGradient::contextSize() const {
+    return sizeof(SweepGradientContext);
+}
+
+SkShader::Context* SkSweepGradient::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                  const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
+    }
+
+    return SkNEW_PLACEMENT_ARGS(storage, SweepGradientContext, (*this, device, paint, matrix));
+}
+
+SkSweepGradient::SweepGradientContext::SweepGradientContext(
+        const SkSweepGradient& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix) {}
+
 //  returns angle in a circle [0..2PI) -> [0..255]
 static unsigned SkATan2_255(float y, float x) {
     //    static const float g255Over2PI = 255 / (2 * SK_ScalarPI);
@@ -69,11 +87,11 @@
     return ir;
 }
 
-void SkSweepGradient::shadeSpan(int x, int y, SkPMColor* SK_RESTRICT dstC,
-                               int count) {
+void SkSweepGradient::SweepGradientContext::shadeSpan(int x, int y, SkPMColor* SK_RESTRICT dstC,
+                                                      int count) {
     SkMatrix::MapXYProc proc = fDstToIndexProc;
     const SkMatrix&     matrix = fDstToIndex;
-    const SkPMColor* SK_RESTRICT cache = this->getCache32();
+    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
     int                 toggle = init_dither_toggle(x, y);
     SkPoint             srcPt;
 
@@ -111,11 +129,11 @@
     }
 }
 
-void SkSweepGradient::shadeSpan16(int x, int y, uint16_t* SK_RESTRICT dstC,
-                                 int count) {
+void SkSweepGradient::SweepGradientContext::shadeSpan16(int x, int y, uint16_t* SK_RESTRICT dstC,
+                                                        int count) {
     SkMatrix::MapXYProc proc = fDstToIndexProc;
     const SkMatrix&     matrix = fDstToIndex;
-    const uint16_t* SK_RESTRICT cache = this->getCache16();
+    const uint16_t* SK_RESTRICT cache = fCache->getCache16();
     int                 toggle = init_dither_toggle16(x, y);
     SkPoint             srcPt;
 
diff --git a/src/effects/gradients/SkSweepGradient.h b/src/effects/gradients/SkSweepGradient.h
index ca19da2..9998ed1 100644
--- a/src/effects/gradients/SkSweepGradient.h
+++ b/src/effects/gradients/SkSweepGradient.h
@@ -14,8 +14,23 @@
 class SkSweepGradient : public SkGradientShaderBase {
 public:
     SkSweepGradient(SkScalar cx, SkScalar cy, const Descriptor&);
-    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
+                                             void* storage) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
+    class SweepGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
+    public:
+        SweepGradientContext(const SkSweepGradient& shader, const SkBitmap& device,
+                             const SkPaint& paint, const SkMatrix& matrix);
+        ~SweepGradientContext() {}
+
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+
+    private:
+        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
+    };
 
     virtual BitmapType asABitmap(SkBitmap* bitmap,
                                  SkMatrix* matrix,
@@ -33,8 +48,9 @@
     virtual void flatten(SkWriteBuffer& buffer) const SK_OVERRIDE;
 
 private:
-    typedef SkGradientShaderBase INHERITED;
     const SkPoint fCenter;
+
+    typedef SkGradientShaderBase INHERITED;
 };
 
 #endif
diff --git a/src/effects/gradients/SkTwoPointConicalGradient.cpp b/src/effects/gradients/SkTwoPointConicalGradient.cpp
index 1e6a0d8..b7aba82 100644
--- a/src/effects/gradients/SkTwoPointConicalGradient.cpp
+++ b/src/effects/gradients/SkTwoPointConicalGradient.cpp
@@ -9,6 +9,18 @@
 
 #include "SkTwoPointConicalGradient_gpu.h"
 
+struct TwoPtRadialContext {
+    const TwoPtRadial&  fRec;
+    float               fRelX, fRelY;
+    const float         fIncX, fIncY;
+    float               fB;
+    const float         fDB;
+
+    TwoPtRadialContext(const TwoPtRadial& rec, SkScalar fx, SkScalar fy,
+                       SkScalar dfx, SkScalar dfy);
+    SkFixed nextT();
+};
+
 static int valid_divide(float numer, float denom, float* ratio) {
     SkASSERT(ratio);
     if (0 == denom) {
@@ -83,47 +95,48 @@
     fFlipped = flipped;
 }
 
-void TwoPtRadial::setup(SkScalar fx, SkScalar fy, SkScalar dfx, SkScalar dfy) {
-    fRelX = SkScalarToFloat(fx) - fCenterX;
-    fRelY = SkScalarToFloat(fy) - fCenterY;
-    fIncX = SkScalarToFloat(dfx);
-    fIncY = SkScalarToFloat(dfy);
-    fB = -2 * (fDCenterX * fRelX + fDCenterY * fRelY + fRDR);
-    fDB = -2 * (fDCenterX * fIncX + fDCenterY * fIncY);
-}
+TwoPtRadialContext::TwoPtRadialContext(const TwoPtRadial& rec, SkScalar fx, SkScalar fy,
+                                       SkScalar dfx, SkScalar dfy)
+    : fRec(rec)
+    , fRelX(SkScalarToFloat(fx) - rec.fCenterX)
+    , fRelY(SkScalarToFloat(fy) - rec.fCenterY)
+    , fIncX(SkScalarToFloat(dfx))
+    , fIncY(SkScalarToFloat(dfy))
+    , fB(-2 * (rec.fDCenterX * fRelX + rec.fDCenterY * fRelY + rec.fRDR))
+    , fDB(-2 * (rec.fDCenterX * fIncX + rec.fDCenterY * fIncY)) {}
 
-SkFixed TwoPtRadial::nextT() {
+SkFixed TwoPtRadialContext::nextT() {
     float roots[2];
 
-    float C = sqr(fRelX) + sqr(fRelY) - fRadius2;
-    int countRoots = find_quad_roots(fA, fB, C, roots, fFlipped);
+    float C = sqr(fRelX) + sqr(fRelY) - fRec.fRadius2;
+    int countRoots = find_quad_roots(fRec.fA, fB, C, roots, fRec.fFlipped);
 
     fRelX += fIncX;
     fRelY += fIncY;
     fB += fDB;
 
     if (0 == countRoots) {
-        return kDontDrawT;
+        return TwoPtRadial::kDontDrawT;
     }
 
     // Prefer the bigger t value if both give a radius(t) > 0
     // find_quad_roots returns the values sorted, so we start with the last
     float t = roots[countRoots - 1];
-    float r = lerp(fRadius, fDRadius, t);
+    float r = lerp(fRec.fRadius, fRec.fDRadius, t);
     if (r <= 0) {
         t = roots[0];   // might be the same as roots[countRoots-1]
-        r = lerp(fRadius, fDRadius, t);
+        r = lerp(fRec.fRadius, fRec.fDRadius, t);
         if (r <= 0) {
-            return kDontDrawT;
+            return TwoPtRadial::kDontDrawT;
         }
     }
     return SkFloatToFixed(t);
 }
 
-typedef void (*TwoPointConicalProc)(TwoPtRadial* rec, SkPMColor* dstC,
+typedef void (*TwoPointConicalProc)(TwoPtRadialContext* rec, SkPMColor* dstC,
                                     const SkPMColor* cache, int toggle, int count);
 
-static void twopoint_clamp(TwoPtRadial* rec, SkPMColor* SK_RESTRICT dstC,
+static void twopoint_clamp(TwoPtRadialContext* rec, SkPMColor* SK_RESTRICT dstC,
                            const SkPMColor* SK_RESTRICT cache, int toggle,
                            int count) {
     for (; count > 0; --count) {
@@ -140,7 +153,7 @@
     }
 }
 
-static void twopoint_repeat(TwoPtRadial* rec, SkPMColor* SK_RESTRICT dstC,
+static void twopoint_repeat(TwoPtRadialContext* rec, SkPMColor* SK_RESTRICT dstC,
                             const SkPMColor* SK_RESTRICT cache, int toggle,
                             int count) {
     for (; count > 0; --count) {
@@ -157,7 +170,7 @@
     }
 }
 
-static void twopoint_mirror(TwoPtRadial* rec, SkPMColor* SK_RESTRICT dstC,
+static void twopoint_mirror(TwoPtRadialContext* rec, SkPMColor* SK_RESTRICT dstC,
                             const SkPMColor* SK_RESTRICT cache, int toggle,
                             int count) {
     for (; count > 0; --count) {
@@ -203,8 +216,39 @@
     return false;
 }
 
-void SkTwoPointConicalGradient::shadeSpan(int x, int y, SkPMColor* dstCParam,
-                                          int count) {
+size_t SkTwoPointConicalGradient::contextSize() const {
+    return sizeof(TwoPointConicalGradientContext);
+}
+
+SkShader::Context* SkTwoPointConicalGradient::createContext(
+        const SkBitmap& device, const SkPaint& paint,
+        const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
+    }
+
+    return SkNEW_PLACEMENT_ARGS(storage, TwoPointConicalGradientContext,
+                                (*this, device, paint, matrix));
+}
+
+SkTwoPointConicalGradient::TwoPointConicalGradientContext::TwoPointConicalGradientContext(
+        const SkTwoPointConicalGradient& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix)
+{
+    // we don't have a span16 proc
+    fFlags &= ~kHasSpan16_Flag;
+
+    // in general, we might discard based on computed-radius, so clear
+    // this flag (todo: sometimes we can detect that we never discard...)
+    fFlags &= ~kOpaqueAlpha_Flag;
+}
+
+void SkTwoPointConicalGradient::TwoPointConicalGradientContext::shadeSpan(
+        int x, int y, SkPMColor* dstCParam, int count) {
+    const SkTwoPointConicalGradient& twoPointConicalGradient =
+            static_cast<const SkTwoPointConicalGradient&>(fShader);
+
     int toggle = init_dither_toggle(x, y);
 
     SkASSERT(count > 0);
@@ -213,15 +257,15 @@
 
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
 
-    const SkPMColor* SK_RESTRICT cache = this->getCache32();
+    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
 
     TwoPointConicalProc shadeProc = twopoint_repeat;
-    if (SkShader::kClamp_TileMode == fTileMode) {
+    if (SkShader::kClamp_TileMode == twoPointConicalGradient.fTileMode) {
         shadeProc = twopoint_clamp;
-    } else if (SkShader::kMirror_TileMode == fTileMode) {
+    } else if (SkShader::kMirror_TileMode == twoPointConicalGradient.fTileMode) {
         shadeProc = twopoint_mirror;
     } else {
-        SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
+        SkASSERT(SkShader::kRepeat_TileMode == twoPointConicalGradient.fTileMode);
     }
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -242,16 +286,16 @@
             dy = fDstToIndex.getSkewY();
         }
 
-        fRec.setup(fx, fy, dx, dy);
-        (*shadeProc)(&fRec, dstC, cache, toggle, count);
+        TwoPtRadialContext rec(twoPointConicalGradient.fRec, fx, fy, dx, dy);
+        (*shadeProc)(&rec, dstC, cache, toggle, count);
     } else {    // perspective case
         SkScalar dstX = SkIntToScalar(x) + SK_ScalarHalf;
         SkScalar dstY = SkIntToScalar(y) + SK_ScalarHalf;
         for (; count > 0; --count) {
             SkPoint srcPt;
             dstProc(fDstToIndex, dstX, dstY, &srcPt);
-            fRec.setup(srcPt.fX, srcPt.fY, 0, 0);
-            (*shadeProc)(&fRec, dstC, cache, toggle, 1);
+            TwoPtRadialContext rec(twoPointConicalGradient.fRec, srcPt.fX, srcPt.fY, 0, 0);
+            (*shadeProc)(&rec, dstC, cache, toggle, 1);
 
             dstX += SK_Scalar1;
             toggle = next_dither_toggle(toggle);
@@ -260,23 +304,6 @@
     }
 }
 
-bool SkTwoPointConicalGradient::setContext(const SkBitmap& device,
-                                           const SkPaint& paint,
-                                           const SkMatrix& matrix) {
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
-    }
-
-    // we don't have a span16 proc
-    fFlags &= ~kHasSpan16_Flag;
-
-    // in general, we might discard based on computed-radius, so clear
-    // this flag (todo: sometimes we can detect that we never discard...)
-    fFlags &= ~kOpaqueAlpha_Flag;
-
-    return true;
-}
-
 SkShader::BitmapType SkTwoPointConicalGradient::asABitmap(
     SkBitmap* bitmap, SkMatrix* matrix, SkShader::TileMode* xy) const {
     SkPoint diff = fCenter2 - fCenter1;
diff --git a/src/effects/gradients/SkTwoPointConicalGradient.h b/src/effects/gradients/SkTwoPointConicalGradient.h
index b2e258e..80aa6fa 100644
--- a/src/effects/gradients/SkTwoPointConicalGradient.h
+++ b/src/effects/gradients/SkTwoPointConicalGradient.h
@@ -11,6 +11,8 @@
 
 #include "SkGradientShaderPriv.h"
 
+// TODO(dominikg): Worth making it truly immutable (i.e. set values in constructor)?
+// Should only be initialized once via init(). Immutable afterwards.
 struct TwoPtRadial {
     enum {
         kDontDrawT  = 0x80000000
@@ -29,13 +31,6 @@
               const SkPoint& center1, SkScalar rad1,
               bool flipped);
 
-    // used by setup and nextT
-    float   fRelX, fRelY, fIncX, fIncY;
-    float   fB, fDB;
-
-    void setup(SkScalar fx, SkScalar fy, SkScalar dfx, SkScalar dfy);
-    SkFixed nextT();
-
     static bool DontDrawT(SkFixed t) {
         return kDontDrawT == (uint32_t)t;
     }
@@ -51,11 +46,24 @@
                               const SkPoint& end, SkScalar endRadius,
                               bool flippedGrad, const Descriptor&);
 
-    virtual void shadeSpan(int x, int y, SkPMColor* dstCParam,
-                           int count) SK_OVERRIDE;
-    virtual bool setContext(const SkBitmap& device,
-                            const SkPaint& paint,
-                            const SkMatrix& matrix) SK_OVERRIDE;
+
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
+                                             void* storage) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
+    class TwoPointConicalGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
+    public:
+        TwoPointConicalGradientContext(const SkTwoPointConicalGradient& shader,
+                                       const SkBitmap& device,
+                                       const SkPaint& paint,
+                                       const SkMatrix& matrix);
+        ~TwoPointConicalGradientContext() {}
+
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+
+    private:
+        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
+    };
 
     virtual BitmapType asABitmap(SkBitmap* bitmap,
                                  SkMatrix* matrix,
diff --git a/src/effects/gradients/SkTwoPointRadialGradient.cpp b/src/effects/gradients/SkTwoPointRadialGradient.cpp
index e1359b1..a598c6e 100644
--- a/src/effects/gradients/SkTwoPointRadialGradient.cpp
+++ b/src/effects/gradients/SkTwoPointRadialGradient.cpp
@@ -220,23 +220,60 @@
     return kRadial2_GradientType;
 }
 
-void SkTwoPointRadialGradient::shadeSpan(int x, int y, SkPMColor* dstCParam,
-                                         int count) {
+size_t SkTwoPointRadialGradient::contextSize() const {
+    return sizeof(TwoPointRadialGradientContext);
+}
+
+bool SkTwoPointRadialGradient::validContext(const SkBitmap& device, const SkPaint& paint,
+                                            const SkMatrix& matrix, SkMatrix* totalInverse) const {
+    // For now, we might have divided by zero, so detect that.
+    if (0 == fDiffRadius) {
+        return false;
+    }
+
+    return this->INHERITED::validContext(device, paint, matrix, totalInverse);
+}
+
+SkShader::Context* SkTwoPointRadialGradient::createContext(
+        const SkBitmap& device, const SkPaint& paint,
+        const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
+    }
+
+    return SkNEW_PLACEMENT_ARGS(storage, TwoPointRadialGradientContext,
+                                (*this, device, paint, matrix));
+}
+
+SkTwoPointRadialGradient::TwoPointRadialGradientContext::TwoPointRadialGradientContext(
+        const SkTwoPointRadialGradient& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix)
+{
+    // we don't have a span16 proc
+    fFlags &= ~kHasSpan16_Flag;
+}
+
+void SkTwoPointRadialGradient::TwoPointRadialGradientContext::shadeSpan(
+        int x, int y, SkPMColor* dstCParam, int count) {
     SkASSERT(count > 0);
 
+    const SkTwoPointRadialGradient& twoPointRadialGradient =
+            static_cast<const SkTwoPointRadialGradient&>(fShader);
+
     SkPMColor* SK_RESTRICT dstC = dstCParam;
 
     // Zero difference between radii:  fill with transparent black.
-    if (fDiffRadius == 0) {
+    if (twoPointRadialGradient.fDiffRadius == 0) {
       sk_bzero(dstC, count * sizeof(*dstC));
       return;
     }
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = fTileProc;
-    const SkPMColor* SK_RESTRICT cache = this->getCache32();
+    TileProc            proc = twoPointRadialGradient.fTileProc;
+    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
 
-    SkScalar foura = fA * 4;
-    bool posRoot = fDiffRadius < 0;
+    SkScalar foura = twoPointRadialGradient.fA * 4;
+    bool posRoot = twoPointRadialGradient.fDiffRadius < 0;
     if (fDstToIndexClass != kPerspective_MatrixClass) {
         SkPoint srcPt;
         dstProc(fDstToIndex, SkIntToScalar(x) + SK_ScalarHalf,
@@ -254,21 +291,23 @@
             dx = fDstToIndex.getScaleX();
             dy = fDstToIndex.getSkewY();
         }
-        SkScalar b = (SkScalarMul(fDiff.fX, fx) +
-                     SkScalarMul(fDiff.fY, fy) - fStartRadius) * 2;
-        SkScalar db = (SkScalarMul(fDiff.fX, dx) +
-                      SkScalarMul(fDiff.fY, dy)) * 2;
+        SkScalar b = (SkScalarMul(twoPointRadialGradient.fDiff.fX, fx) +
+                     SkScalarMul(twoPointRadialGradient.fDiff.fY, fy) -
+                     twoPointRadialGradient.fStartRadius) * 2;
+        SkScalar db = (SkScalarMul(twoPointRadialGradient.fDiff.fX, dx) +
+                      SkScalarMul(twoPointRadialGradient.fDiff.fY, dy)) * 2;
 
         TwoPointRadialShadeProc shadeProc = shadeSpan_twopoint_repeat;
-        if (SkShader::kClamp_TileMode == fTileMode) {
+        if (SkShader::kClamp_TileMode == twoPointRadialGradient.fTileMode) {
             shadeProc = shadeSpan_twopoint_clamp;
-        } else if (SkShader::kMirror_TileMode == fTileMode) {
+        } else if (SkShader::kMirror_TileMode == twoPointRadialGradient.fTileMode) {
             shadeProc = shadeSpan_twopoint_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == twoPointRadialGradient.fTileMode);
         }
         (*shadeProc)(fx, dx, fy, dy, b, db,
-                     fSr2D2, foura, fOneOverTwoA, posRoot,
+                     twoPointRadialGradient.fSr2D2, foura,
+                     twoPointRadialGradient.fOneOverTwoA, posRoot,
                      dstC, cache, count);
     } else {    // perspective case
         SkScalar dstX = SkIntToScalar(x);
@@ -278,10 +317,11 @@
             dstProc(fDstToIndex, dstX, dstY, &srcPt);
             SkScalar fx = srcPt.fX;
             SkScalar fy = srcPt.fY;
-            SkScalar b = (SkScalarMul(fDiff.fX, fx) +
-                         SkScalarMul(fDiff.fY, fy) - fStartRadius) * 2;
-            SkFixed t = two_point_radial(b, fx, fy, fSr2D2, foura,
-                                         fOneOverTwoA, posRoot);
+            SkScalar b = (SkScalarMul(twoPointRadialGradient.fDiff.fX, fx) +
+                         SkScalarMul(twoPointRadialGradient.fDiff.fY, fy) -
+                         twoPointRadialGradient.fStartRadius) * 2;
+            SkFixed t = two_point_radial(b, fx, fy, twoPointRadialGradient.fSr2D2, foura,
+                                         twoPointRadialGradient.fOneOverTwoA, posRoot);
             SkFixed index = proc(t);
             SkASSERT(index <= 0xFFFF);
             *dstC++ = cache[index >> SkGradientShaderBase::kCache32Shift];
@@ -290,23 +330,6 @@
     }
 }
 
-bool SkTwoPointRadialGradient::setContext( const SkBitmap& device,
-                                          const SkPaint& paint,
-                                          const SkMatrix& matrix){
-    // For now, we might have divided by zero, so detect that
-    if (0 == fDiffRadius) {
-        return false;
-    }
-
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
-    }
-
-    // we don't have a span16 proc
-    fFlags &= ~kHasSpan16_Flag;
-    return true;
-}
-
 #ifndef SK_IGNORE_TO_STRING
 void SkTwoPointRadialGradient::toString(SkString* str) const {
     str->append("SkTwoPointRadialGradient: (");
diff --git a/src/effects/gradients/SkTwoPointRadialGradient.h b/src/effects/gradients/SkTwoPointRadialGradient.h
index ee1b49e..9ba89f2 100644
--- a/src/effects/gradients/SkTwoPointRadialGradient.h
+++ b/src/effects/gradients/SkTwoPointRadialGradient.h
@@ -23,11 +23,26 @@
     virtual GradientType asAGradient(GradientInfo* info) const SK_OVERRIDE;
     virtual GrEffectRef* asNewEffect(GrContext* context, const SkPaint&) const SK_OVERRIDE;
 
-    virtual void shadeSpan(int x, int y, SkPMColor* dstCParam,
-                           int count) SK_OVERRIDE;
-    virtual bool setContext(const SkBitmap& device,
-                            const SkPaint& paint,
-                            const SkMatrix& matrix) SK_OVERRIDE;
+    // Replaces setContext(): per-draw state now lives in the Context made by createContext().
+    virtual size_t contextSize() const SK_OVERRIDE;  // presumably size of 'storage'; TODO confirm
+    virtual bool validContext(const SkBitmap&, const SkPaint&,
+                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
+                                             void* storage) const SK_OVERRIDE;
+    // Per-draw shading context for this gradient; shadeSpan() now lives here, not on the shader.
+    class TwoPointRadialGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
+    public:
+        TwoPointRadialGradientContext(const SkTwoPointRadialGradient& shader,
+                                      const SkBitmap& device,
+                                      const SkPaint& paint,
+                                      const SkMatrix& matrix);
+        ~TwoPointRadialGradientContext() {}
+        // Writes the gradient colors for a span into dstC.
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+
+    private:
+        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
+    };
 
     SkScalar getCenterX1() const { return fDiff.length(); }
     SkScalar getStartRadius() const { return fStartRadius; }
@@ -41,7 +56,6 @@
     virtual void flatten(SkWriteBuffer& buffer) const SK_OVERRIDE;
 
 private:
-    typedef SkGradientShaderBase INHERITED;
     const SkPoint fCenter1;
     const SkPoint fCenter2;
     const SkScalar fRadius1;
@@ -50,6 +64,8 @@
     SkScalar fStartRadius, fDiffRadius, fSr2D2, fA, fOneOverTwoA;
 
     void init();
+
+    typedef SkGradientShaderBase INHERITED;
 };
 
 #endif