Revert of Extract most of the mutable state of SkShader into a separate Context object. (https://codereview.chromium.org/207683004/)

Reason for revert:
This is blocking the DEPS roll into Chromium. Failures can be seen here:

http://build.chromium.org/p/tryserver.chromium/builders/android_dbg/builds/174333

Original issue's description:
> Extract most of the mutable state of SkShader into a separate Context object.
>
> SkShader currently stores some state during draw calls via setContext(...).
> Move that mutable state into a separate SkShader::Context class that is
> constructed on demand for the duration of the draw.
>
> Calls to setContext() are replaced with createContext() which returns a context
> corresponding to the shader object or NULL if the parameters to createContext
> are invalid.
>
> TEST=out/Debug/dm
> BUG=skia:1976
>
> Committed: http://code.google.com/p/skia/source/detail?r=14216
>
> Committed: http://code.google.com/p/skia/source/detail?r=14323

R=scroggo@google.com, skyostil@chromium.org, tomhudson@chromium.org, senorblanco@chromium.org, reed@google.com, bungeman@google.com, dominikg@chromium.org
TBR=bungeman@google.com, dominikg@chromium.org, reed@google.com, scroggo@google.com, senorblanco@chromium.org, skyostil@chromium.org, tomhudson@chromium.org
NOTREECHECKS=true
NOTRY=true
BUG=skia:1976

Author: bsalomon@google.com

Review URL: https://codereview.chromium.org/249643002

git-svn-id: http://skia.googlecode.com/svn/trunk@14326 2bbb7eff-a529-9590-31e7-b0007b416f81
diff --git a/include/core/SkColorShader.h b/include/core/SkColorShader.h
index 56e5add..975156c 100644
--- a/include/core/SkColorShader.h
+++ b/include/core/SkColorShader.h
@@ -30,35 +30,16 @@
     */
     SkColorShader(SkColor c);
 
+    virtual ~SkColorShader();
+
+    virtual uint32_t getFlags() SK_OVERRIDE;
+    virtual uint8_t getSpan16Alpha() const SK_OVERRIDE;
     virtual bool isOpaque() const SK_OVERRIDE;
-
-    virtual SkShader::Context* createContext(const SkBitmap& device,
-                                             const SkPaint& paint,
-                                             const SkMatrix& matrix,
-                                             void* storage) const SK_OVERRIDE;
-
-    virtual size_t contextSize() const SK_OVERRIDE {
-        return sizeof(ColorShaderContext);
-    }
-
-    class ColorShaderContext : public SkShader::Context {
-    public:
-        ColorShaderContext(const SkColorShader& shader, const SkBitmap& device,
-                           const SkPaint& paint, const SkMatrix& matrix);
-
-        virtual uint32_t getFlags() const SK_OVERRIDE;
-        virtual uint8_t getSpan16Alpha() const SK_OVERRIDE;
-        virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE;
-        virtual void shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
-        virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) SK_OVERRIDE;
-
-    private:
-        SkPMColor   fPMColor;
-        uint32_t    fFlags;
-        uint16_t    fColor16;
-
-        typedef SkShader::Context INHERITED;
-    };
+    virtual bool setContext(const SkBitmap& device, const SkPaint& paint,
+                            const SkMatrix& matrix) SK_OVERRIDE;
+    virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE;
+    virtual void shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
+    virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) SK_OVERRIDE;
 
     // we return false for this, use asAGradient
     virtual BitmapType asABitmap(SkBitmap* outTexture,
@@ -75,7 +56,11 @@
     virtual void flatten(SkWriteBuffer&) const SK_OVERRIDE;
 
 private:
+
     SkColor     fColor;         // ignored if fInheritColor is true
+    SkPMColor   fPMColor;       // cached after setContext()
+    uint32_t    fFlags;         // cached after setContext()
+    uint16_t    fColor16;       // cached after setContext()
     SkBool8     fInheritColor;
 
     typedef SkShader INHERITED;
diff --git a/include/core/SkComposeShader.h b/include/core/SkComposeShader.h
index d42da0c..b54e5ef 100644
--- a/include/core/SkComposeShader.h
+++ b/include/core/SkComposeShader.h
@@ -34,38 +34,10 @@
     SkComposeShader(SkShader* sA, SkShader* sB, SkXfermode* mode = NULL);
     virtual ~SkComposeShader();
 
-    virtual bool validContext(const SkBitmap&, const SkPaint&,
-                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
-    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
-                                             const SkMatrix&, void*) const SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
-
-    class ComposeShaderContext : public SkShader::Context {
-    public:
-        // When this object gets destroyed, it will call contextA and contextB's destructor
-        // but it will NOT free the memory.
-        ComposeShaderContext(const SkComposeShader&, const SkBitmap&,
-                             const SkPaint&, const SkMatrix&,
-                             SkShader::Context* contextA, SkShader::Context* contextB);
-
-        SkShader::Context* getShaderContextA() const { return fShaderContextA; }
-        SkShader::Context* getShaderContextB() const { return fShaderContextB; }
-
-        virtual ~ComposeShaderContext();
-
-        virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
-
-    private:
-        SkShader::Context* fShaderContextA;
-        SkShader::Context* fShaderContextB;
-
-        typedef SkShader::Context INHERITED;
-    };
-
-#ifdef SK_DEBUG
-    SkShader* getShaderA() { return fShaderA; }
-    SkShader* getShaderB() { return fShaderB; }
-#endif
+    virtual bool setContext(const SkBitmap&, const SkPaint&,
+                            const SkMatrix&) SK_OVERRIDE;
+    virtual void endContext() SK_OVERRIDE;
+    virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkComposeShader)
@@ -75,6 +47,7 @@
     virtual void flatten(SkWriteBuffer&) const SK_OVERRIDE;
 
 private:
+
     SkShader*   fShaderA;
     SkShader*   fShaderB;
     SkXfermode* fMode;
diff --git a/include/core/SkEmptyShader.h b/include/core/SkEmptyShader.h
index 7494eff..d2ebb61 100644
--- a/include/core/SkEmptyShader.h
+++ b/include/core/SkEmptyShader.h
@@ -15,28 +15,20 @@
 
 /**
  *  \class SkEmptyShader
- *  A Shader that always draws nothing. Its createContext always returns NULL.
+ *  A Shader that always draws nothing. Its setContext always returns false,
+ *  so it never expects that its shadeSpan() methods will get called.
  */
 class SK_API SkEmptyShader : public SkShader {
 public:
     SkEmptyShader() {}
 
-    virtual size_t contextSize() const SK_OVERRIDE {
-        // Even though createContext returns NULL we have to return a value of at least
-        // sizeof(SkShader::Context) to satisfy SkSmallAllocator.
-        return sizeof(SkShader::Context);
-    }
-
-    virtual bool validContext(const SkBitmap&, const SkPaint&,
-                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE {
-        return false;
-    }
-
-    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
-                                             const SkMatrix&, void*) const SK_OVERRIDE {
-        // validContext returns false.
-        return NULL;
-    }
+    virtual uint32_t getFlags() SK_OVERRIDE;
+    virtual uint8_t getSpan16Alpha() const SK_OVERRIDE;
+    virtual bool setContext(const SkBitmap&, const SkPaint&,
+                            const SkMatrix&) SK_OVERRIDE;
+    virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE;
+    virtual void shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
+    virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) SK_OVERRIDE;
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkEmptyShader)
diff --git a/include/core/SkShader.h b/include/core/SkShader.h
index cc2cc75..6566e69 100644
--- a/include/core/SkShader.h
+++ b/include/core/SkShader.h
@@ -38,7 +38,7 @@
     virtual ~SkShader();
 
     /**
-     *  Returns true if the local matrix is not an identity matrix.
+     * Returns true if the local matrix is not an identity matrix.
      */
     bool hasLocalMatrix() const { return !fLocalMatrix.isIdentity(); }
 
@@ -96,7 +96,7 @@
         */
         kIntrinsicly16_Flag = 0x04,
 
-        /** set if the spans only vary in X (const in Y).
+        /** set (after setContext) if the spans only vary in X (const in Y).
             e.g. an Nx1 bitmap that is being tiled in Y, or a linear-gradient
             that varies from left-to-right. This flag specifies this for
             shadeSpan().
@@ -112,111 +112,84 @@
     };
 
     /**
+     *  Called sometimes before drawing with this shader. Return the type of
+     *  alpha your shader will return. The default implementation returns 0.
+     *  Your subclass should override if it can (even sometimes) report a
+     *  non-zero value, since that will enable various blitters to perform
+     *  faster.
+     */
+    virtual uint32_t getFlags() { return 0; }
+
+    /**
      *  Returns true if the shader is guaranteed to produce only opaque
      *  colors, subject to the SkPaint using the shader to apply an opaque
      *  alpha value. Subclasses should override this to allow some
-     *  optimizations.
+     *  optimizations.  isOpaque() can be called at any time, unlike getFlags,
+     *  which only works properly when the context is set.
      */
     virtual bool isOpaque() const { return false; }
 
-    class Context : public ::SkNoncopyable {
-    public:
-        Context(const SkShader& shader, const SkBitmap& device,
-                const SkPaint& paint, const SkMatrix& matrix);
-
-        virtual ~Context();
-
-        /**
-         *  Called sometimes before drawing with this shader. Return the type of
-         *  alpha your shader will return. The default implementation returns 0.
-         *  Your subclass should override if it can (even sometimes) report a
-         *  non-zero value, since that will enable various blitters to perform
-         *  faster.
-         */
-        virtual uint32_t getFlags() const { return 0; }
-
-        /**
-         *  Return the alpha associated with the data returned by shadeSpan16(). If
-         *  kHasSpan16_Flag is not set, this value is meaningless.
-         */
-        virtual uint8_t getSpan16Alpha() const { return fPaintAlpha; }
-
-        /**
-         *  Called for each span of the object being drawn. Your subclass should
-         *  set the appropriate colors (with premultiplied alpha) that correspond
-         *  to the specified device coordinates.
-         */
-        virtual void shadeSpan(int x, int y, SkPMColor[], int count) = 0;
-
-        typedef void (*ShadeProc)(void* ctx, int x, int y, SkPMColor[], int count);
-        virtual ShadeProc asAShadeProc(void** ctx);
-
-        /**
-         *  Called only for 16bit devices when getFlags() returns
-         *  kOpaqueAlphaFlag | kHasSpan16_Flag
-         */
-        virtual void shadeSpan16(int x, int y, uint16_t[], int count);
-
-        /**
-         *  Similar to shadeSpan, but only returns the alpha-channel for a span.
-         *  The default implementation calls shadeSpan() and then extracts the alpha
-         *  values from the returned colors.
-         */
-        virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count);
-
-        /**
-         *  Helper function that returns true if this shader's shadeSpan16() method
-         *  can be called.
-         */
-        bool canCallShadeSpan16() {
-            return SkShader::CanCallShadeSpan16(this->getFlags());
-        }
-
-    protected:
-        // Reference to shader, so we don't have to dupe information.
-        const SkShader& fShader;
-
-        enum MatrixClass {
-            kLinear_MatrixClass,            // no perspective
-            kFixedStepInX_MatrixClass,      // fast perspective, need to call fixedStepInX() each
-                                            // scanline
-            kPerspective_MatrixClass        // slow perspective, need to mappoints each pixel
-        };
-        static MatrixClass ComputeMatrixClass(const SkMatrix&);
-
-        uint8_t             getPaintAlpha() const { return fPaintAlpha; }
-        const SkMatrix&     getTotalInverse() const { return fTotalInverse; }
-        MatrixClass         getInverseClass() const { return (MatrixClass)fTotalInverseClass; }
-
-    private:
-        SkMatrix            fTotalInverse;
-        uint8_t             fPaintAlpha;
-        uint8_t             fTotalInverseClass;
-
-        typedef SkNoncopyable INHERITED;
-    };
+    /**
+     *  Return the alpha associated with the data returned by shadeSpan16(). If
+     *  kHasSpan16_Flag is not set, this value is meaningless.
+     */
+    virtual uint8_t getSpan16Alpha() const { return fPaintAlpha; }
 
     /**
-     *  Subclasses should be sure to call their INHERITED::validContext() if
-     *  they override this method.
+     *  Called once before drawing, with the current paint and device matrix.
+     *  Return true if your shader supports these parameters, or false if not.
+     *  If false is returned, nothing will be drawn. If true is returned, then
+     *  a balancing call to endContext() will be made before the next call to
+     *  setContext.
+     *
+     *  Subclasses should be sure to call their INHERITED::setContext() if they
+     *  override this method.
      */
-    virtual bool validContext(const SkBitmap& device, const SkPaint& paint,
-                              const SkMatrix& matrix, SkMatrix* totalInverse = NULL) const;
+    virtual bool setContext(const SkBitmap& device, const SkPaint& paint,
+                            const SkMatrix& matrix);
 
     /**
-     *  Create the actual object that does the shading.
-     *  Returns NULL if validContext() returns false.
-     *  Size of storage must be >= contextSize.
+     *  Assuming setContext returned true, endContext() will be called when
+     *  the draw using the shader has completed. It is an error for setContext
+     *  to be called twice without an intervening call to endContext().
+     *
+     *  Subclasses should be sure to call their INHERITED::endContext() if they
+     *  override this method.
      */
-    virtual Context* createContext(const SkBitmap& device,
-                                   const SkPaint& paint,
-                                   const SkMatrix& matrix,
-                                   void* storage) const = 0;
+    virtual void endContext();
+
+    SkDEBUGCODE(bool setContextHasBeenCalled() const { return SkToBool(fInSetContext); })
 
     /**
-     *  Return the size of a Context returned by createContext.
+     *  Called for each span of the object being drawn. Your subclass should
+     *  set the appropriate colors (with premultiplied alpha) that correspond
+     *  to the specified device coordinates.
      */
-    virtual size_t contextSize() const = 0;
+    virtual void shadeSpan(int x, int y, SkPMColor[], int count) = 0;
+
+    typedef void (*ShadeProc)(void* ctx, int x, int y, SkPMColor[], int count);
+    virtual ShadeProc asAShadeProc(void** ctx);
+
+    /**
+     *  Called only for 16bit devices when getFlags() returns
+     *  kOpaqueAlphaFlag | kHasSpan16_Flag
+     */
+    virtual void shadeSpan16(int x, int y, uint16_t[], int count);
+
+    /**
+     *  Similar to shadeSpan, but only returns the alpha-channel for a span.
+     *  The default implementation calls shadeSpan() and then extracts the alpha
+     *  values from the returned colors.
+     */
+    virtual void shadeSpanAlpha(int x, int y, uint8_t alpha[], int count);
+
+    /**
+     *  Helper function that returns true if this shader's shadeSpan16() method
+     *  can be called.
+     */
+    bool canCallShadeSpan16() {
+        return SkShader::CanCallShadeSpan16(this->getFlags());
+    }
 
     /**
      *  Helper to check the flags to know if it is legal to call shadeSpan16()
@@ -349,7 +322,7 @@
      *  The incoming color to the effect has r=g=b=a all extracted from the SkPaint's alpha.
      *  The output color should be the computed SkShader premul color modulated by the incoming
      *  color. The GrContext may be used by the effect to create textures. The GPU device does not
-     *  call createContext. Instead we pass the SkPaint here in case the shader needs paint info.
+     *  call setContext. Instead we pass the SkPaint here in case the shader needs paint info.
      */
     virtual GrEffectRef* asNewEffect(GrContext* context, const SkPaint& paint) const;
 
@@ -389,14 +362,26 @@
     SK_DEFINE_FLATTENABLE_TYPE(SkShader)
 
 protected:
+    enum MatrixClass {
+        kLinear_MatrixClass,            // no perspective
+        kFixedStepInX_MatrixClass,      // fast perspective, need to call fixedStepInX() each scanline
+        kPerspective_MatrixClass        // slow perspective, need to map points for each pixel
+    };
+    static MatrixClass ComputeMatrixClass(const SkMatrix&);
+
+    // These can be called by your subclass after setContext() has been called
+    uint8_t             getPaintAlpha() const { return fPaintAlpha; }
+    const SkMatrix&     getTotalInverse() const { return fTotalInverse; }
+    MatrixClass         getInverseClass() const { return (MatrixClass)fTotalInverseClass; }
 
     SkShader(SkReadBuffer& );
     virtual void flatten(SkWriteBuffer&) const SK_OVERRIDE;
-
 private:
     SkMatrix            fLocalMatrix;
-
-    bool computeTotalInverse(const SkMatrix& matrix, SkMatrix* totalInverse) const;
+    SkMatrix            fTotalInverse;
+    uint8_t             fPaintAlpha;
+    uint8_t             fTotalInverseClass;
+    SkDEBUGCODE(SkBool8 fInSetContext;)
 
     typedef SkFlattenable INHERITED;
 };
diff --git a/include/effects/SkPerlinNoiseShader.h b/include/effects/SkPerlinNoiseShader.h
index 5b27029..dfd5a8c 100644
--- a/include/effects/SkPerlinNoiseShader.h
+++ b/include/effects/SkPerlinNoiseShader.h
@@ -72,32 +72,10 @@
     }
 
 
-    virtual SkShader::Context* createContext(
-        const SkBitmap& device, const SkPaint& paint,
-        const SkMatrix& matrix, void* storage) const SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
-
-    class PerlinNoiseShaderContext : public SkShader::Context {
-    public:
-        PerlinNoiseShaderContext(const SkPerlinNoiseShader& shader, const SkBitmap& device,
-                                 const SkPaint& paint, const SkMatrix& matrix);
-        virtual ~PerlinNoiseShaderContext() {}
-
-        virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
-        virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
-
-    private:
-        SkPMColor shade(const SkPoint& point, StitchData& stitchData) const;
-        SkScalar calculateTurbulenceValueForPoint(
-            int channel, const PaintingData& paintingData,
-            StitchData& stitchData, const SkPoint& point) const;
-        SkScalar noise2D(int channel, const PaintingData& paintingData,
-                         const StitchData& stitchData, const SkPoint& noiseVector) const;
-
-        SkMatrix fMatrix;
-
-        typedef SkShader::Context INHERITED;
-    };
+    virtual bool setContext(const SkBitmap& device, const SkPaint& paint,
+                            const SkMatrix& matrix);
+    virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
+    virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
 
     virtual GrEffectRef* asNewEffect(GrContext* context, const SkPaint&) const SK_OVERRIDE;
 
@@ -114,6 +92,14 @@
                         const SkISize* tileSize);
     virtual ~SkPerlinNoiseShader();
 
+    SkScalar noise2D(int channel, const PaintingData& paintingData,
+                     const StitchData& stitchData, const SkPoint& noiseVector) const;
+
+    SkScalar calculateTurbulenceValueForPoint(int channel, const PaintingData& paintingData,
+                                              StitchData& stitchData, const SkPoint& point) const;
+
+    SkPMColor shade(const SkPoint& point, StitchData& stitchData) const;
+
     // TODO (scroggo): Once all SkShaders are created from a factory, and we have removed the
     // constructor that creates SkPerlinNoiseShader from an SkReadBuffer, several fields can
     // be made constant.
@@ -124,6 +110,8 @@
     /*const*/ SkScalar                  fSeed;
     /*const*/ SkISize                   fTileSize;
     /*const*/ bool                      fStitchTiles;
+    // TODO (scroggo): Once setContext creates a new object, place this on that object.
+    SkMatrix fMatrix;
 
     PaintingData* fPaintingData;
 
diff --git a/include/effects/SkTransparentShader.h b/include/effects/SkTransparentShader.h
index 790e5ae..7428d44 100644
--- a/include/effects/SkTransparentShader.h
+++ b/include/effects/SkTransparentShader.h
@@ -14,31 +14,21 @@
 public:
     SkTransparentShader() {}
 
-    virtual SkShader::Context* createContext(const SkBitmap& device, const SkPaint& paint,
-                                             const SkMatrix& matrix, void* storage) const
-            SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
-
-    class TransparentShaderContext : public SkShader::Context {
-    public:
-        TransparentShaderContext(const SkTransparentShader& shader, const SkBitmap& device,
-                                 const SkPaint& paint, const SkMatrix& matrix);
-        virtual ~TransparentShaderContext();
-
-        virtual uint32_t getFlags() const SK_OVERRIDE;
-        virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
-        virtual void shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
-
-    private:
-        const SkBitmap* fDevice;
-
-        typedef SkShader::Context INHERITED;
-    };
+    virtual uint32_t getFlags() SK_OVERRIDE;
+    virtual bool    setContext(const SkBitmap& device,
+                               const SkPaint& paint,
+                               const SkMatrix& matrix) SK_OVERRIDE;
+    virtual void    shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
+    virtual void    shadeSpan16(int x, int y, uint16_t span[], int count) SK_OVERRIDE;
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkTransparentShader)
 
 private:
+    // These fields are cached by the call to setContext().
+    const SkBitmap* fDevice;
+    uint8_t         fAlpha;
+
     SkTransparentShader(SkReadBuffer& buffer) : INHERITED(buffer) {}
 
     typedef SkShader INHERITED;
diff --git a/src/core/SkBitmapProcShader.cpp b/src/core/SkBitmapProcShader.cpp
index 5f5eb18..a397b78 100644
--- a/src/core/SkBitmapProcShader.cpp
+++ b/src/core/SkBitmapProcShader.cpp
@@ -34,16 +34,18 @@
 SkBitmapProcShader::SkBitmapProcShader(const SkBitmap& src,
                                        TileMode tmx, TileMode tmy) {
     fRawBitmap = src;
-    fTileModeX = (uint8_t)tmx;
-    fTileModeY = (uint8_t)tmy;
+    fState.fTileModeX = (uint8_t)tmx;
+    fState.fTileModeY = (uint8_t)tmy;
+    fFlags = 0; // computed in setContext
 }
 
 SkBitmapProcShader::SkBitmapProcShader(SkReadBuffer& buffer)
         : INHERITED(buffer) {
     buffer.readBitmap(&fRawBitmap);
     fRawBitmap.setImmutable();
-    fTileModeX = buffer.readUInt();
-    fTileModeY = buffer.readUInt();
+    fState.fTileModeX = buffer.readUInt();
+    fState.fTileModeY = buffer.readUInt();
+    fFlags = 0; // computed in setContext
 }
 
 SkShader::BitmapType SkBitmapProcShader::asABitmap(SkBitmap* texture,
@@ -56,8 +58,8 @@
         texM->reset();
     }
     if (xy) {
-        xy[0] = (TileMode)fTileModeX;
-        xy[1] = (TileMode)fTileModeY;
+        xy[0] = (TileMode)fState.fTileModeX;
+        xy[1] = (TileMode)fState.fTileModeY;
     }
     return kDefault_BitmapType;
 }
@@ -66,8 +68,8 @@
     this->INHERITED::flatten(buffer);
 
     buffer.writeBitmap(fRawBitmap);
-    buffer.writeUInt(fTileModeX);
-    buffer.writeUInt(fTileModeY);
+    buffer.writeUInt(fState.fTileModeX);
+    buffer.writeUInt(fState.fTileModeY);
 }
 
 static bool only_scale_and_translate(const SkMatrix& matrix) {
@@ -96,67 +98,25 @@
     return true;
 }
 
-bool SkBitmapProcShader::validInternal(const SkBitmap& device,
-                                       const SkPaint& paint,
-                                       const SkMatrix& matrix,
-                                       SkMatrix* totalInverse,
-                                       SkBitmapProcState* state) const {
+bool SkBitmapProcShader::setContext(const SkBitmap& device,
+                                    const SkPaint& paint,
+                                    const SkMatrix& matrix) {
     if (!fRawBitmap.getTexture() && !valid_for_drawing(fRawBitmap)) {
         return false;
     }
 
-    // Make sure we can use totalInverse as a cache.
-    SkMatrix totalInverseLocal;
-    if (NULL == totalInverse) {
-        totalInverse = &totalInverseLocal;
-    }
-
-    // Do this first, so we know the matrix can be inverted.
-    if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
+    // do this first, so we have a correct inverse matrix
+    if (!this->INHERITED::setContext(device, paint, matrix)) {
         return false;
     }
 
-    SkASSERT(state);
-    state->fTileModeX = fTileModeX;
-    state->fTileModeY = fTileModeY;
-    state->fOrigBitmap = fRawBitmap;
-    return state->chooseProcs(*totalInverse, paint);
-}
-
-bool SkBitmapProcShader::validContext(const SkBitmap& device,
-                                      const SkPaint& paint,
-                                      const SkMatrix& matrix,
-                                      SkMatrix* totalInverse) const {
-    SkBitmapProcState state;
-    return this->validInternal(device, paint, matrix, totalInverse, &state);
-}
-
-SkShader::Context* SkBitmapProcShader::createContext(const SkBitmap& device, const SkPaint& paint,
-                                                     const SkMatrix& matrix, void* storage) const {
-    void* stateStorage = (char*)storage + sizeof(BitmapProcShaderContext);
-    SkBitmapProcState* state = SkNEW_PLACEMENT(stateStorage, SkBitmapProcState);
-    if (!this->validInternal(device, paint, matrix, NULL, state)) {
-        state->~SkBitmapProcState();
-        return NULL;
+    fState.fOrigBitmap = fRawBitmap;
+    if (!fState.chooseProcs(this->getTotalInverse(), paint)) {
+        this->INHERITED::endContext();
+        return false;
     }
 
-    return SkNEW_PLACEMENT_ARGS(storage, BitmapProcShaderContext,
-                                (*this, device, paint, matrix, state));
-}
-
-size_t SkBitmapProcShader::contextSize() const {
-    // The SkBitmapProcState is stored outside of the context object, with the context holding
-    // a pointer to it.
-    return sizeof(BitmapProcShaderContext) + sizeof(SkBitmapProcState);
-}
-
-SkBitmapProcShader::BitmapProcShaderContext::BitmapProcShaderContext(
-        const SkBitmapProcShader& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix, SkBitmapProcState* state)
-    : INHERITED(shader, device, paint, matrix)
-    , fState(state)
-{
-    const SkBitmap& bitmap = *fState->fBitmap;
+    const SkBitmap& bitmap = *fState.fBitmap;
     bool bitmapIsOpaque = bitmap.isOpaque();
 
     // update fFlags
@@ -197,12 +157,12 @@
     }
 
     fFlags = flags;
+    return true;
 }
 
-SkBitmapProcShader::BitmapProcShaderContext::~BitmapProcShaderContext() {
-    // The bitmap proc state has been created outside of the context on memory that will be freed
-    // elsewhere. Only call the destructor but leave the freeing of the memory to the caller.
-    fState->~SkBitmapProcState();
+void SkBitmapProcShader::endContext() {
+    fState.endContext();
+    this->INHERITED::endContext();
 }
 
 #define BUF_MAX     128
@@ -216,9 +176,8 @@
     #define TEST_BUFFER_EXTRA   0
 #endif
 
-void SkBitmapProcShader::BitmapProcShaderContext::shadeSpan(int x, int y, SkPMColor dstC[],
-                                                            int count) {
-    const SkBitmapProcState& state = *fState;
+void SkBitmapProcShader::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
+    const SkBitmapProcState& state = fState;
     if (state.getShaderProc32()) {
         state.getShaderProc32()(state, x, y, dstC, count);
         return;
@@ -227,7 +186,7 @@
     uint32_t buffer[BUF_MAX + TEST_BUFFER_EXTRA];
     SkBitmapProcState::MatrixProc   mproc = state.getMatrixProc();
     SkBitmapProcState::SampleProc32 sproc = state.getSampleProc32();
-    int max = state.maxCountForBufferSize(sizeof(buffer[0]) * BUF_MAX);
+    int max = fState.maxCountForBufferSize(sizeof(buffer[0]) * BUF_MAX);
 
     SkASSERT(state.fBitmap->getPixels());
     SkASSERT(state.fBitmap->pixelRef() == NULL ||
@@ -261,17 +220,16 @@
     }
 }
 
-SkShader::Context::ShadeProc SkBitmapProcShader::BitmapProcShaderContext::asAShadeProc(void** ctx) {
-    if (fState->getShaderProc32()) {
-        *ctx = fState;
-        return (ShadeProc)fState->getShaderProc32();
+SkShader::ShadeProc SkBitmapProcShader::asAShadeProc(void** ctx) {
+    if (fState.getShaderProc32()) {
+        *ctx = &fState;
+        return (ShadeProc)fState.getShaderProc32();
     }
     return NULL;
 }
 
-void SkBitmapProcShader::BitmapProcShaderContext::shadeSpan16(int x, int y, uint16_t dstC[],
-                                                              int count) {
-    const SkBitmapProcState& state = *fState;
+void SkBitmapProcShader::shadeSpan16(int x, int y, uint16_t dstC[], int count) {
+    const SkBitmapProcState& state = fState;
     if (state.getShaderProc16()) {
         state.getShaderProc16()(state, x, y, dstC, count);
         return;
@@ -280,7 +238,7 @@
     uint32_t buffer[BUF_MAX];
     SkBitmapProcState::MatrixProc   mproc = state.getMatrixProc();
     SkBitmapProcState::SampleProc16 sproc = state.getSampleProc16();
-    int max = state.maxCountForBufferSize(sizeof(buffer));
+    int max = fState.maxCountForBufferSize(sizeof(buffer));
 
     SkASSERT(state.fBitmap->getPixels());
     SkASSERT(state.fBitmap->pixelRef() == NULL ||
@@ -384,8 +342,8 @@
     str->append("BitmapShader: (");
 
     str->appendf("(%s, %s)",
-                 gTileModeName[fTileModeX],
-                 gTileModeName[fTileModeY]);
+                 gTileModeName[fState.fTileModeX],
+                 gTileModeName[fState.fTileModeY]);
 
     str->append(" ");
     fRawBitmap.toString(str);
@@ -426,8 +384,8 @@
     matrix.preConcat(lmInverse);
 
     SkShader::TileMode tm[] = {
-        (TileMode)fTileModeX,
-        (TileMode)fTileModeY,
+        (TileMode)fState.fTileModeX,
+        (TileMode)fState.fTileModeY,
     };
 
     // Must set wrap and filter on the sampler before requesting a texture. In two places below
diff --git a/src/core/SkBitmapProcShader.h b/src/core/SkBitmapProcShader.h
index e0c78b8..8e225a5 100644
--- a/src/core/SkBitmapProcShader.h
+++ b/src/core/SkBitmapProcShader.h
@@ -20,16 +20,14 @@
 
     // overrides from SkShader
     virtual bool isOpaque() const SK_OVERRIDE;
+    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
+    virtual void endContext() SK_OVERRIDE;
+    virtual uint32_t getFlags() SK_OVERRIDE { return fFlags; }
+    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+    virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
+    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
     virtual BitmapType asABitmap(SkBitmap*, SkMatrix*, TileMode*) const SK_OVERRIDE;
 
-    virtual bool validContext(const SkBitmap& device,
-                              const SkPaint& paint,
-                              const SkMatrix& matrix,
-                              SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
-    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
-                                             const SkMatrix&, void* storage) const SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
-
     static bool CanDo(const SkBitmap&, TileMode tx, TileMode ty);
 
     SK_TO_STRING_OVERRIDE()
@@ -39,54 +37,22 @@
     GrEffectRef* asNewEffect(GrContext*, const SkPaint&) const SK_OVERRIDE;
 #endif
 
-    class BitmapProcShaderContext : public SkShader::Context {
-    public:
-        // The context takes ownership of the state. It will call its destructor
-        // but will NOT free the memory.
-        BitmapProcShaderContext(const SkBitmapProcShader& shader,
-                                const SkBitmap& device,
-                                const SkPaint& paint,
-                                const SkMatrix& matrix,
-                                SkBitmapProcState* state);
-        virtual ~BitmapProcShaderContext();
-
-        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-        virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
-        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
-
-        virtual uint32_t getFlags() const SK_OVERRIDE { return fFlags; }
-
-    private:
-        SkBitmapProcState*  fState;
-        uint32_t            fFlags;
-
-        typedef SkShader::Context INHERITED;
-    };
-
 protected:
     SkBitmapProcShader(SkReadBuffer& );
     virtual void flatten(SkWriteBuffer&) const SK_OVERRIDE;
 
-    SkBitmap    fRawBitmap;   // experimental for RLE encoding
-    uint8_t     fTileModeX, fTileModeY;
+    SkBitmap          fRawBitmap;   // experimental for RLE encoding
+    SkBitmapProcState fState;
+    uint32_t          fFlags;
 
 private:
-    bool validInternal(const SkBitmap& device, const SkPaint& paint,
-                       const SkMatrix& matrix, SkMatrix* totalInverse,
-                       SkBitmapProcState* state) const;
-
     typedef SkShader INHERITED;
 };
 
-// Commonly used allocator. It currently is only used to allocate up to 3 objects. The total
-// bytes requested is calculated using one of our large shaders, its context size plus the size of
-// an Sk3DBlitter in SkDraw.cpp
-// Note that some contexts may contain other contexts (e.g. for compose shaders), but we've not
-// yet found a situation where the size below isn't big enough.
-typedef SkSmallAllocator<3, sizeof(SkBitmapProcShader) +
-                            sizeof(SkBitmapProcShader::BitmapProcShaderContext) +
-                            sizeof(SkBitmapProcState) +
-                            sizeof(void*) * 2> SkTBlitterAllocator;
+// Commonly used allocator. It currently is only used to allocate up to 2 objects. The total
+// bytes requested is calculated using one of our large shaders plus the size of an Sk3DBlitter
+// in SkDraw.cpp
+typedef SkSmallAllocator<2, sizeof(SkBitmapProcShader) + sizeof(void*) * 2> SkTBlitterAllocator;
 
 // If alloc is non-NULL, it will be used to allocate the returned SkShader, and MUST outlive
 // the SkShader.
diff --git a/src/core/SkBitmapProcState.cpp b/src/core/SkBitmapProcState.cpp
index eecfbbc..be87d83 100644
--- a/src/core/SkBitmapProcState.cpp
+++ b/src/core/SkBitmapProcState.cpp
@@ -360,6 +360,17 @@
     return true;
 }
 
+void SkBitmapProcState::endContext() {
+    SkDELETE(fBitmapFilter);
+    fBitmapFilter = NULL;
+    fScaledBitmap.reset();
+
+    if (fScaledCacheID) {
+        SkScaledImageCache::Unlock(fScaledCacheID);
+        fScaledCacheID = NULL;
+    }
+}
+
 SkBitmapProcState::~SkBitmapProcState() {
     if (fScaledCacheID) {
         SkScaledImageCache::Unlock(fScaledCacheID);
@@ -388,7 +399,6 @@
     }
     // The above logic should have always assigned fBitmap, but in case it
     // didn't, we check for that now...
-    // TODO(dominikg): Ask humper@ if we can just use an SkASSERT(fBitmap)?
     if (NULL == fBitmap) {
         return false;
     }
@@ -477,7 +487,6 @@
     // shader will perform.
 
     fMatrixProc = this->chooseMatrixProc(trivialMatrix);
-    // TODO(dominikg): SkASSERT(fMatrixProc) instead? chooseMatrixProc never returns NULL.
     if (NULL == fMatrixProc) {
         return false;
     }
@@ -519,7 +528,6 @@
                 fPaintPMColor = SkPreMultiplyColor(paint.getColor());
                 break;
             default:
-                // TODO(dominikg): Should we ever get here? SkASSERT(false) instead?
                 return false;
         }
 
diff --git a/src/core/SkBitmapProcState.h b/src/core/SkBitmapProcState.h
index 663bcb8..d5a354e 100644
--- a/src/core/SkBitmapProcState.h
+++ b/src/core/SkBitmapProcState.h
@@ -89,6 +89,12 @@
     uint8_t             fTileModeY;         // CONSTRUCTOR
     uint8_t             fFilterLevel;       // chooseProcs
 
+    /** The shader will let us know when we can release some of our resources
+      * like scaled bitmaps.
+      */
+
+    void endContext();
+
     /** Platforms implement this, and can optionally overwrite only the
         following fields:
 
diff --git a/src/core/SkBlitter.cpp b/src/core/SkBlitter.cpp
index 7243f52..52a05ed 100644
--- a/src/core/SkBlitter.cpp
+++ b/src/core/SkBlitter.cpp
@@ -26,15 +26,6 @@
 
 bool SkBlitter::isNullBlitter() const { return false; }
 
-bool SkBlitter::resetShaderContext(const SkBitmap& device, const SkPaint& paint,
-                                   const SkMatrix& matrix) {
-    return true;
-}
-
-SkShader::Context* SkBlitter::getShaderContext() const {
-    return NULL;
-}
-
 const SkBitmap* SkBlitter::justAnOpaqueColor(uint32_t* value) {
     return NULL;
 }
@@ -577,149 +568,102 @@
 public:
     Sk3DShader(SkShader* proxy) : fProxy(proxy) {
         SkSafeRef(proxy);
+        fMask = NULL;
     }
 
     virtual ~Sk3DShader() {
         SkSafeUnref(fProxy);
     }
 
-    virtual size_t contextSize() const SK_OVERRIDE {
-        size_t size = sizeof(Sk3DShaderContext);
-        if (fProxy) {
-            size += fProxy->contextSize();
-        }
-        return size;
-    }
+    void setMask(const SkMask* mask) { fMask = mask; }
 
-    virtual bool validContext(const SkBitmap& device, const SkPaint& paint,
-                              const SkMatrix& matrix, SkMatrix* totalInverse = NULL) const
-            SK_OVERRIDE
-    {
-        if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
+    virtual bool setContext(const SkBitmap& device, const SkPaint& paint,
+                            const SkMatrix& matrix) SK_OVERRIDE {
+        if (!this->INHERITED::setContext(device, paint, matrix)) {
             return false;
         }
         if (fProxy) {
-            return fProxy->validContext(device, paint, matrix);
+            if (!fProxy->setContext(device, paint, matrix)) {
+                // must keep our set/end context calls balanced
+                this->INHERITED::endContext();
+                return false;
+            }
+        } else {
+            fPMColor = SkPreMultiplyColor(paint.getColor());
         }
         return true;
     }
 
-    virtual SkShader::Context* createContext(const SkBitmap& device,
-                                             const SkPaint& paint,
-                                             const SkMatrix& matrix,
-                                             void* storage) const SK_OVERRIDE
-    {
-        if (!this->validContext(device, paint, matrix)) {
-            return NULL;
-        }
-
-        SkShader::Context* proxyContext;
+    virtual void endContext() SK_OVERRIDE {
         if (fProxy) {
-            char* proxyContextStorage = (char*) storage + sizeof(Sk3DShaderContext);
-            proxyContext = fProxy->createContext(device, paint, matrix, proxyContextStorage);
-            SkASSERT(proxyContext);
-        } else {
-            proxyContext = NULL;
+            fProxy->endContext();
         }
-        return SkNEW_PLACEMENT_ARGS(storage, Sk3DShaderContext, (*this, device, paint, matrix,
-                                                                 proxyContext));
+        this->INHERITED::endContext();
     }
 
-    class Sk3DShaderContext : public SkShader::Context {
-    public:
-        // Calls proxyContext's destructor but will NOT free its memory.
-        Sk3DShaderContext(const Sk3DShader& shader, const SkBitmap& device, const SkPaint& paint,
-                          const SkMatrix& matrix, SkShader::Context* proxyContext)
-            : INHERITED(shader, device, paint, matrix)
-            , fMask(NULL)
-            , fProxyContext(proxyContext)
-        {
-            if (!fProxyContext) {
-                fPMColor = SkPreMultiplyColor(paint.getColor());
-            }
+    virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE {
+        if (fProxy) {
+            fProxy->shadeSpan(x, y, span, count);
         }
 
-        virtual ~Sk3DShaderContext() {
-            if (fProxyContext) {
-                fProxyContext->~Context();
+        if (fMask == NULL) {
+            if (fProxy == NULL) {
+                sk_memset32(span, fPMColor, count);
             }
+            return;
         }
 
-        void setMask(const SkMask* mask) { fMask = mask; }
+        SkASSERT(fMask->fBounds.contains(x, y));
+        SkASSERT(fMask->fBounds.contains(x + count - 1, y));
 
-        virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE {
-            if (fProxyContext) {
-                fProxyContext->shadeSpan(x, y, span, count);
-            }
+        size_t          size = fMask->computeImageSize();
+        const uint8_t*  alpha = fMask->getAddr8(x, y);
+        const uint8_t*  mulp = alpha + size;
+        const uint8_t*  addp = mulp + size;
 
-            if (fMask == NULL) {
-                if (fProxyContext == NULL) {
-                    sk_memset32(span, fPMColor, count);
-                }
-                return;
-            }
+        if (fProxy) {
+            for (int i = 0; i < count; i++) {
+                if (alpha[i]) {
+                    SkPMColor c = span[i];
+                    if (c) {
+                        unsigned a = SkGetPackedA32(c);
+                        unsigned r = SkGetPackedR32(c);
+                        unsigned g = SkGetPackedG32(c);
+                        unsigned b = SkGetPackedB32(c);
 
-            SkASSERT(fMask->fBounds.contains(x, y));
-            SkASSERT(fMask->fBounds.contains(x + count - 1, y));
-
-            size_t          size = fMask->computeImageSize();
-            const uint8_t*  alpha = fMask->getAddr8(x, y);
-            const uint8_t*  mulp = alpha + size;
-            const uint8_t*  addp = mulp + size;
-
-            if (fProxyContext) {
-                for (int i = 0; i < count; i++) {
-                    if (alpha[i]) {
-                        SkPMColor c = span[i];
-                        if (c) {
-                            unsigned a = SkGetPackedA32(c);
-                            unsigned r = SkGetPackedR32(c);
-                            unsigned g = SkGetPackedG32(c);
-                            unsigned b = SkGetPackedB32(c);
-
-                            unsigned mul = SkAlpha255To256(mulp[i]);
-                            unsigned add = addp[i];
-
-                            r = SkFastMin32(SkAlphaMul(r, mul) + add, a);
-                            g = SkFastMin32(SkAlphaMul(g, mul) + add, a);
-                            b = SkFastMin32(SkAlphaMul(b, mul) + add, a);
-
-                            span[i] = SkPackARGB32(a, r, g, b);
-                        }
-                    } else {
-                        span[i] = 0;
-                    }
-                }
-            } else {    // color
-                unsigned a = SkGetPackedA32(fPMColor);
-                unsigned r = SkGetPackedR32(fPMColor);
-                unsigned g = SkGetPackedG32(fPMColor);
-                unsigned b = SkGetPackedB32(fPMColor);
-                for (int i = 0; i < count; i++) {
-                    if (alpha[i]) {
                         unsigned mul = SkAlpha255To256(mulp[i]);
                         unsigned add = addp[i];
 
-                        span[i] = SkPackARGB32( a,
-                                        SkFastMin32(SkAlphaMul(r, mul) + add, a),
-                                        SkFastMin32(SkAlphaMul(g, mul) + add, a),
-                                        SkFastMin32(SkAlphaMul(b, mul) + add, a));
-                    } else {
-                        span[i] = 0;
+                        r = SkFastMin32(SkAlphaMul(r, mul) + add, a);
+                        g = SkFastMin32(SkAlphaMul(g, mul) + add, a);
+                        b = SkFastMin32(SkAlphaMul(b, mul) + add, a);
+
+                        span[i] = SkPackARGB32(a, r, g, b);
                     }
+                } else {
+                    span[i] = 0;
+                }
+            }
+        } else {    // color
+            unsigned a = SkGetPackedA32(fPMColor);
+            unsigned r = SkGetPackedR32(fPMColor);
+            unsigned g = SkGetPackedG32(fPMColor);
+            unsigned b = SkGetPackedB32(fPMColor);
+            for (int i = 0; i < count; i++) {
+                if (alpha[i]) {
+                    unsigned mul = SkAlpha255To256(mulp[i]);
+                    unsigned add = addp[i];
+
+                    span[i] = SkPackARGB32( a,
+                                    SkFastMin32(SkAlphaMul(r, mul) + add, a),
+                                    SkFastMin32(SkAlphaMul(g, mul) + add, a),
+                                    SkFastMin32(SkAlphaMul(b, mul) + add, a));
+                } else {
+                    span[i] = 0;
                 }
             }
         }
-
-    private:
-        // Unowned.
-        const SkMask*       fMask;
-        // Memory is unowned, but we need to call the destructor.
-        SkShader::Context*  fProxyContext;
-        SkPMColor           fPMColor;
-
-        typedef SkShader::Context INHERITED;
-    };
+    }
 
 #ifndef SK_IGNORE_TO_STRING
     virtual void toString(SkString* str) const SK_OVERRIDE {
@@ -741,30 +685,29 @@
 protected:
     Sk3DShader(SkReadBuffer& buffer) : INHERITED(buffer) {
         fProxy = buffer.readShader();
-        // Leaving this here until we bump the picture version, though this
-        // shader should never be recorded.
-        buffer.readColor();
+        fPMColor = buffer.readColor();
+        fMask = NULL;
     }
 
     virtual void flatten(SkWriteBuffer& buffer) const SK_OVERRIDE {
         this->INHERITED::flatten(buffer);
         buffer.writeFlattenable(fProxy);
-        // Leaving this here until we bump the picture version, though this
-        // shader should never be recorded.
-        buffer.writeColor(SkColor());
+        buffer.writeColor(fPMColor);
     }
 
 private:
     SkShader*       fProxy;
+    SkPMColor       fPMColor;
+    const SkMask*   fMask;
 
     typedef SkShader INHERITED;
 };
 
 class Sk3DBlitter : public SkBlitter {
 public:
-    Sk3DBlitter(SkBlitter* proxy, Sk3DShader::Sk3DShaderContext* shaderContext)
+    Sk3DBlitter(SkBlitter* proxy, Sk3DShader* shader)
         : fProxy(proxy)
-        , f3DShaderContext(shaderContext)
+        , f3DShader(SkRef(shader))
     {}
 
     virtual void blitH(int x, int y, int width) {
@@ -786,22 +729,22 @@
 
     virtual void blitMask(const SkMask& mask, const SkIRect& clip) {
         if (mask.fFormat == SkMask::k3D_Format) {
-            f3DShaderContext->setMask(&mask);
+            f3DShader->setMask(&mask);
 
             ((SkMask*)&mask)->fFormat = SkMask::kA8_Format;
             fProxy->blitMask(mask, clip);
             ((SkMask*)&mask)->fFormat = SkMask::k3D_Format;
 
-            f3DShaderContext->setMask(NULL);
+            f3DShader->setMask(NULL);
         } else {
             fProxy->blitMask(mask, clip);
         }
     }
 
 private:
-    // Both pointers are unowned. They will be deleted by SkSmallAllocator.
-    SkBlitter*                     fProxy;
-    Sk3DShader::Sk3DShaderContext* f3DShaderContext;
+    // fProxy is unowned. It will be deleted by SkSmallAllocator.
+    SkBlitter*               fProxy;
+    SkAutoTUnref<Sk3DShader> f3DShader;
 };
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -811,7 +754,8 @@
 static bool just_solid_color(const SkPaint& paint) {
     if (paint.getAlpha() == 0xFF && paint.getColorFilter() == NULL) {
         SkShader* shader = paint.getShader();
-        if (NULL == shader) {
+        if (NULL == shader ||
+            (shader->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
             return true;
         }
     }
@@ -949,22 +893,16 @@
     }
 
     /*
-     *  We create a SkShader::Context object, and store it on the blitter.
+     *  We need to have balanced calls to the shader:
+     *      setContext
+     *      endContext
+     *  We make the first call here, in case it fails we can abort the draw.
+     *  The endContext() call is made by the blitter (assuming setContext did
+     *  not fail) in its destructor.
      */
-    SkShader::Context* shaderContext;
-    if (shader) {
-        // Try to create the ShaderContext
-        void* storage = allocator->reserveT<SkShader::Context>(shader->contextSize());
-        shaderContext = shader->createContext(device, *paint, matrix, storage);
-        if (!shaderContext) {
-            allocator->freeLast();
-            blitter = allocator->createT<SkNullBlitter>();
-            return blitter;
-        }
-        SkASSERT(shaderContext);
-        SkASSERT((void*) shaderContext == storage);
-    } else {
-        shaderContext = NULL;
+    if (shader && !shader->setContext(device, *paint, matrix)) {
+        blitter = allocator->createT<SkNullBlitter>();
+        return blitter;
     }
 
 
@@ -975,20 +913,19 @@
                 SkASSERT(NULL == paint->getXfermode());
                 blitter = allocator->createT<SkA8_Coverage_Blitter>(device, *paint);
             } else if (shader) {
-                blitter = allocator->createT<SkA8_Shader_Blitter>(device, *paint, shaderContext);
+                blitter = allocator->createT<SkA8_Shader_Blitter>(device, *paint);
             } else {
                 blitter = allocator->createT<SkA8_Blitter>(device, *paint);
             }
             break;
 
         case kRGB_565_SkColorType:
-            blitter = SkBlitter_ChooseD565(device, *paint, shaderContext, allocator);
+            blitter = SkBlitter_ChooseD565(device, *paint, allocator);
             break;
 
         case kN32_SkColorType:
             if (shader) {
-                blitter = allocator->createT<SkARGB32_Shader_Blitter>(
-                        device, *paint, shaderContext);
+                blitter = allocator->createT<SkARGB32_Shader_Blitter>(device, *paint);
             } else if (paint->getColor() == SK_ColorBLACK) {
                 blitter = allocator->createT<SkARGB32_Black_Blitter>(device, *paint);
             } else if (paint->getAlpha() == 0xFF) {
@@ -1007,9 +944,7 @@
     if (shader3D) {
         SkBlitter* innerBlitter = blitter;
         // innerBlitter was allocated by allocator, which will delete it.
-        // We know shaderContext is of type Sk3DShaderContext because it belongs to shader3D.
-        blitter = allocator->createT<Sk3DBlitter>(innerBlitter,
-                static_cast<Sk3DShader::Sk3DShaderContext*>(shaderContext));
+        blitter = allocator->createT<Sk3DBlitter>(innerBlitter, shader3D);
     }
     return blitter;
 }
@@ -1021,36 +956,18 @@
 
 ///////////////////////////////////////////////////////////////////////////////
 
-SkShaderBlitter::SkShaderBlitter(const SkBitmap& device, const SkPaint& paint,
-                                 SkShader::Context* shaderContext)
-        : INHERITED(device)
-        , fShader(paint.getShader())
-        , fShaderContext(shaderContext) {
+SkShaderBlitter::SkShaderBlitter(const SkBitmap& device, const SkPaint& paint)
+        : INHERITED(device) {
+    fShader = paint.getShader();
     SkASSERT(fShader);
-    SkASSERT(fShaderContext);
+    SkASSERT(fShader->setContextHasBeenCalled());
 
     fShader->ref();
-    fShaderFlags = fShaderContext->getFlags();
+    fShaderFlags = fShader->getFlags();
 }
 
 SkShaderBlitter::~SkShaderBlitter() {
+    SkASSERT(fShader->setContextHasBeenCalled());
+    fShader->endContext();
     fShader->unref();
 }
-
-bool SkShaderBlitter::resetShaderContext(const SkBitmap& device, const SkPaint& paint,
-                                         const SkMatrix& matrix) {
-    if (!fShader->validContext(device, paint, matrix)) {
-        return false;
-    }
-
-    // Only destroy the old context if we have a new one. We need to ensure to have a
-    // live context in fShaderContext because the storage is owned by an SkSmallAllocator
-    // outside of this class.
-    // The new context will be of the same size as the old one because we use the same
-    // shader to create it. It is therefore safe to re-use the storage.
-    fShaderContext->~Context();
-    fShaderContext = fShader->createContext(device, paint, matrix, (void*)fShaderContext);
-    SkASSERT(fShaderContext);
-
-    return true;
-}
diff --git a/src/core/SkBlitter.h b/src/core/SkBlitter.h
index f76839e..d19a34b 100644
--- a/src/core/SkBlitter.h
+++ b/src/core/SkBlitter.h
@@ -61,13 +61,6 @@
      */
     virtual bool isNullBlitter() const;
 
-    /**
-     *  Special methods for SkShaderBlitter. On all other classes this is a no-op.
-     */
-    virtual bool resetShaderContext(const SkBitmap& device, const SkPaint& paint,
-                                    const SkMatrix& matrix);
-    virtual SkShader::Context* getShaderContext() const;
-
     ///@name non-virtual helpers
     void blitMaskRegion(const SkMask& mask, const SkRegion& clip);
     void blitRectRegion(const SkIRect& rect, const SkRegion& clip);
diff --git a/src/core/SkBlitter_A8.cpp b/src/core/SkBlitter_A8.cpp
index 11f4259..983a226 100644
--- a/src/core/SkBlitter_A8.cpp
+++ b/src/core/SkBlitter_A8.cpp
@@ -228,12 +228,11 @@
 
 ///////////////////////////////////////////////////////////////////////
 
-SkA8_Shader_Blitter::SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
-                                         SkShader::Context* shaderContext)
-    : INHERITED(device, paint, shaderContext) {
+SkA8_Shader_Blitter::SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint)
+    : INHERITED(device, paint) {
     if ((fXfermode = paint.getXfermode()) != NULL) {
         fXfermode->ref();
-        SkASSERT(fShaderContext);
+        SkASSERT(fShader);
     }
 
     int width = device.width();
@@ -251,14 +250,13 @@
              (unsigned)(x + width) <= (unsigned)fDevice.width());
 
     uint8_t* device = fDevice.getAddr8(x, y);
-    SkShader::Context* shaderContext = fShaderContext;
 
-    if ((shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag) && !fXfermode) {
+    if ((fShader->getFlags() & SkShader::kOpaqueAlpha_Flag) && !fXfermode) {
         memset(device, 0xFF, width);
     } else {
         SkPMColor*  span = fBuffer;
 
-        shaderContext->shadeSpan(x, y, span, width);
+        fShader->shadeSpan(x, y, span, width);
         if (fXfermode) {
             fXfermode->xferA8(device, span, width, NULL);
         } else {
@@ -284,12 +282,12 @@
 
 void SkA8_Shader_Blitter::blitAntiH(int x, int y, const SkAlpha antialias[],
                                     const int16_t runs[]) {
-    SkShader::Context* shaderContext = fShaderContext;
-    SkXfermode*        mode = fXfermode;
-    uint8_t*           aaExpand = fAAExpand;
-    SkPMColor*         span = fBuffer;
-    uint8_t*           device = fDevice.getAddr8(x, y);
-    int                opaque = shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag;
+    SkShader*   shader = fShader;
+    SkXfermode* mode = fXfermode;
+    uint8_t*    aaExpand = fAAExpand;
+    SkPMColor*  span = fBuffer;
+    uint8_t*    device = fDevice.getAddr8(x, y);
+    int         opaque = fShader->getFlags() & SkShader::kOpaqueAlpha_Flag;
 
     for (;;) {
         int count = *runs;
@@ -301,7 +299,7 @@
             if (opaque && aa == 255 && mode == NULL) {
                 memset(device, 0xFF, count);
             } else {
-                shaderContext->shadeSpan(x, y, span, count);
+                shader->shadeSpan(x, y, span, count);
                 if (mode) {
                     memset(aaExpand, aa, count);
                     mode->xferA8(device, span, count, aaExpand);
@@ -331,12 +329,11 @@
     int height = clip.height();
     uint8_t* device = fDevice.getAddr8(x, y);
     const uint8_t* alpha = mask.getAddr8(x, y);
-    SkShader::Context* shaderContext = fShaderContext;
 
     SkPMColor*  span = fBuffer;
 
     while (--height >= 0) {
-        shaderContext->shadeSpan(x, y, span, width);
+        fShader->shadeSpan(x, y, span, width);
         if (fXfermode) {
             fXfermode->xferA8(device, span, width, alpha);
         } else {
diff --git a/src/core/SkBlitter_ARGB32.cpp b/src/core/SkBlitter_ARGB32.cpp
index 118a1d1..d4bec1b 100644
--- a/src/core/SkBlitter_ARGB32.cpp
+++ b/src/core/SkBlitter_ARGB32.cpp
@@ -275,16 +275,14 @@
 }
 
 SkARGB32_Shader_Blitter::SkARGB32_Shader_Blitter(const SkBitmap& device,
-        const SkPaint& paint, SkShader::Context* shaderContext)
-    : INHERITED(device, paint, shaderContext)
-{
+                            const SkPaint& paint) : INHERITED(device, paint) {
     fBuffer = (SkPMColor*)sk_malloc_throw(device.width() * (sizeof(SkPMColor)));
 
     fXfermode = paint.getXfermode();
     SkSafeRef(fXfermode);
 
     int flags = 0;
-    if (!(shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
+    if (!(fShader->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
         flags |= SkBlitRow::kSrcPixelAlpha_Flag32;
     }
     // we call this on the output from the shader
@@ -294,7 +292,7 @@
 
     fShadeDirectlyIntoDevice = false;
     if (fXfermode == NULL) {
-        if (shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag) {
+        if (fShader->getFlags() & SkShader::kOpaqueAlpha_Flag) {
             fShadeDirectlyIntoDevice = true;
         }
     } else {
@@ -307,7 +305,7 @@
         }
     }
 
-    fConstInY = SkToBool(shaderContext->getFlags() & SkShader::kConstInY32_Flag);
+    fConstInY = SkToBool(fShader->getFlags() & SkShader::kConstInY32_Flag);
 }
 
 SkARGB32_Shader_Blitter::~SkARGB32_Shader_Blitter() {
@@ -321,10 +319,10 @@
     uint32_t*   device = fDevice.getAddr32(x, y);
 
     if (fShadeDirectlyIntoDevice) {
-        fShaderContext->shadeSpan(x, y, device, width);
+        fShader->shadeSpan(x, y, device, width);
     } else {
         SkPMColor*  span = fBuffer;
-        fShaderContext->shadeSpan(x, y, span, width);
+        fShader->shadeSpan(x, y, span, width);
         if (fXfermode) {
             fXfermode->xfer32(device, span, width, NULL);
         } else {
@@ -337,22 +335,22 @@
     SkASSERT(x >= 0 && y >= 0 &&
              x + width <= fDevice.width() && y + height <= fDevice.height());
 
-    uint32_t*          device = fDevice.getAddr32(x, y);
-    size_t             deviceRB = fDevice.rowBytes();
-    SkShader::Context* shaderContext = fShaderContext;
-    SkPMColor*         span = fBuffer;
+    uint32_t*   device = fDevice.getAddr32(x, y);
+    size_t      deviceRB = fDevice.rowBytes();
+    SkShader*   shader = fShader;
+    SkPMColor*  span = fBuffer;
 
     if (fConstInY) {
         if (fShadeDirectlyIntoDevice) {
             // shade the first row directly into the device
-            shaderContext->shadeSpan(x, y, device, width);
+            fShader->shadeSpan(x, y, device, width);
             span = device;
             while (--height > 0) {
                 device = (uint32_t*)((char*)device + deviceRB);
                 memcpy(device, span, width << 2);
             }
         } else {
-            shaderContext->shadeSpan(x, y, span, width);
+            fShader->shadeSpan(x, y, span, width);
             SkXfermode* xfer = fXfermode;
             if (xfer) {
                 do {
@@ -374,7 +372,7 @@
 
     if (fShadeDirectlyIntoDevice) {
         void* ctx;
-        SkShader::Context::ShadeProc shadeProc = shaderContext->asAShadeProc(&ctx);
+        SkShader::ShadeProc shadeProc = fShader->asAShadeProc(&ctx);
         if (shadeProc) {
             do {
                 shadeProc(ctx, x, y, device, width);
@@ -383,7 +381,7 @@
             } while (--height > 0);
         } else {
             do {
-                shaderContext->shadeSpan(x, y, device, width);
+                shader->shadeSpan(x, y, device, width);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
             } while (--height > 0);
@@ -392,7 +390,7 @@
         SkXfermode* xfer = fXfermode;
         if (xfer) {
             do {
-                shaderContext->shadeSpan(x, y, span, width);
+                shader->shadeSpan(x, y, span, width);
                 xfer->xfer32(device, span, width, NULL);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
@@ -400,7 +398,7 @@
         } else {
             SkBlitRow::Proc32 proc = fProc32;
             do {
-                shaderContext->shadeSpan(x, y, span, width);
+                shader->shadeSpan(x, y, span, width);
                 proc(device, span, width, 255);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
@@ -411,9 +409,9 @@
 
 void SkARGB32_Shader_Blitter::blitAntiH(int x, int y, const SkAlpha antialias[],
                                         const int16_t runs[]) {
-    SkPMColor*         span = fBuffer;
-    uint32_t*          device = fDevice.getAddr32(x, y);
-    SkShader::Context* shaderContext = fShaderContext;
+    SkPMColor*  span = fBuffer;
+    uint32_t*   device = fDevice.getAddr32(x, y);
+    SkShader*   shader = fShader;
 
     if (fXfermode && !fShadeDirectlyIntoDevice) {
         for (;;) {
@@ -424,7 +422,7 @@
                 break;
             int aa = *antialias;
             if (aa) {
-                shaderContext->shadeSpan(x, y, span, count);
+                shader->shadeSpan(x, y, span, count);
                 if (aa == 255) {
                     xfer->xfer32(device, span, count, NULL);
                 } else {
@@ -440,7 +438,7 @@
             x += count;
         }
     } else if (fShadeDirectlyIntoDevice ||
-               (shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
+               (fShader->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
         for (;;) {
             int count = *runs;
             if (count <= 0) {
@@ -450,9 +448,9 @@
             if (aa) {
                 if (aa == 255) {
                     // cool, have the shader draw right into the device
-                    shaderContext->shadeSpan(x, y, device, count);
+                    shader->shadeSpan(x, y, device, count);
                 } else {
-                    shaderContext->shadeSpan(x, y, span, count);
+                    shader->shadeSpan(x, y, span, count);
                     fProc32Blend(device, span, count, aa);
                 }
             }
@@ -469,7 +467,7 @@
             }
             int aa = *antialias;
             if (aa) {
-                shaderContext->shadeSpan(x, y, span, count);
+                fShader->shadeSpan(x, y, span, count);
                 if (aa == 255) {
                     fProc32(device, span, count, 255);
                 } else {
@@ -493,11 +491,10 @@
 
     SkASSERT(mask.fBounds.contains(clip));
 
-    SkShader::Context*  shaderContext = fShaderContext;
     SkBlitMask::RowProc proc = NULL;
     if (!fXfermode) {
         unsigned flags = 0;
-        if (shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag) {
+        if (fShader->getFlags() & SkShader::kOpaqueAlpha_Flag) {
             flags |= SkBlitMask::kSrcIsOpaque_RowFlag;
         }
         proc = SkBlitMask::RowFactory(SkBitmap::kARGB_8888_Config, mask.fFormat,
@@ -518,13 +515,14 @@
     const uint8_t* maskRow = (const uint8_t*)mask.getAddr(x, y);
     const size_t maskRB = mask.fRowBytes;
 
+    SkShader* shader = fShader;
     SkPMColor* span = fBuffer;
 
     if (fXfermode) {
         SkASSERT(SkMask::kA8_Format == mask.fFormat);
         SkXfermode* xfer = fXfermode;
         do {
-            shaderContext->shadeSpan(x, y, span, width);
+            shader->shadeSpan(x, y, span, width);
             xfer->xfer32((SkPMColor*)dstRow, span, width, maskRow);
             dstRow += dstRB;
             maskRow += maskRB;
@@ -532,7 +530,7 @@
         } while (--height > 0);
     } else {
         do {
-            shaderContext->shadeSpan(x, y, span, width);
+            shader->shadeSpan(x, y, span, width);
             proc(dstRow, maskRow, span, width);
             dstRow += dstRB;
             maskRow += maskRB;
@@ -544,13 +542,13 @@
 void SkARGB32_Shader_Blitter::blitV(int x, int y, int height, SkAlpha alpha) {
     SkASSERT(x >= 0 && y >= 0 && y + height <= fDevice.height());
 
-    uint32_t*          device = fDevice.getAddr32(x, y);
-    size_t             deviceRB = fDevice.rowBytes();
-    SkShader::Context* shaderContext = fShaderContext;
+    uint32_t*   device = fDevice.getAddr32(x, y);
+    size_t      deviceRB = fDevice.rowBytes();
+    SkShader*   shader = fShader;
 
     if (fConstInY) {
         SkPMColor c;
-        shaderContext->shadeSpan(x, y, &c, 1);
+        fShader->shadeSpan(x, y, &c, 1);
 
         if (fShadeDirectlyIntoDevice) {
             if (255 == alpha) {
@@ -584,7 +582,7 @@
 
     if (fShadeDirectlyIntoDevice) {
         void* ctx;
-        SkShader::Context::ShadeProc shadeProc = shaderContext->asAShadeProc(&ctx);
+        SkShader::ShadeProc shadeProc = fShader->asAShadeProc(&ctx);
         if (255 == alpha) {
             if (shadeProc) {
                 do {
@@ -594,7 +592,7 @@
                 } while (--height > 0);
             } else {
                 do {
-                    shaderContext->shadeSpan(x, y, device, 1);
+                    shader->shadeSpan(x, y, device, 1);
                     y += 1;
                     device = (uint32_t*)((char*)device + deviceRB);
                 } while (--height > 0);
@@ -610,7 +608,7 @@
                 } while (--height > 0);
             } else {
                 do {
-                    shaderContext->shadeSpan(x, y, &c, 1);
+                    shader->shadeSpan(x, y, &c, 1);
                     *device = SkFourByteInterp(c, *device, alpha);
                     y += 1;
                     device = (uint32_t*)((char*)device + deviceRB);
@@ -622,7 +620,7 @@
         SkXfermode* xfer = fXfermode;
         if (xfer) {
             do {
-                shaderContext->shadeSpan(x, y, span, 1);
+                shader->shadeSpan(x, y, span, 1);
                 xfer->xfer32(device, span, 1, &alpha);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
@@ -630,7 +628,7 @@
         } else {
             SkBlitRow::Proc32 proc = (255 == alpha) ? fProc32 : fProc32Blend;
             do {
-                shaderContext->shadeSpan(x, y, span, 1);
+                shader->shadeSpan(x, y, span, 1);
                 proc(device, span, 1, alpha);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
diff --git a/src/core/SkBlitter_RGB16.cpp b/src/core/SkBlitter_RGB16.cpp
index e22aac4..21b5a16 100644
--- a/src/core/SkBlitter_RGB16.cpp
+++ b/src/core/SkBlitter_RGB16.cpp
@@ -107,8 +107,7 @@
 
 class SkRGB16_Shader_Blitter : public SkShaderBlitter {
 public:
-    SkRGB16_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
-                           SkShader::Context* shaderContext);
+    SkRGB16_Shader_Blitter(const SkBitmap& device, const SkPaint& paint);
     virtual ~SkRGB16_Shader_Blitter();
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha* antialias,
@@ -130,8 +129,7 @@
 // used only if the shader can perform shadSpan16
 class SkRGB16_Shader16_Blitter : public SkRGB16_Shader_Blitter {
 public:
-    SkRGB16_Shader16_Blitter(const SkBitmap& device, const SkPaint& paint,
-                             SkShader::Context* shaderContext);
+    SkRGB16_Shader16_Blitter(const SkBitmap& device, const SkPaint& paint);
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha* antialias,
                            const int16_t* runs);
@@ -143,8 +141,7 @@
 
 class SkRGB16_Shader_Xfermode_Blitter : public SkShaderBlitter {
 public:
-    SkRGB16_Shader_Xfermode_Blitter(const SkBitmap& device, const SkPaint& paint,
-                                    SkShader::Context* shaderContext);
+    SkRGB16_Shader_Xfermode_Blitter(const SkBitmap& device, const SkPaint& paint);
     virtual ~SkRGB16_Shader_Xfermode_Blitter();
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha* antialias,
@@ -682,9 +679,8 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 SkRGB16_Shader16_Blitter::SkRGB16_Shader16_Blitter(const SkBitmap& device,
-                                                   const SkPaint& paint,
-                                                   SkShader::Context* shaderContext)
-    : SkRGB16_Shader_Blitter(device, paint, shaderContext) {
+                                                   const SkPaint& paint)
+    : SkRGB16_Shader_Blitter(device, paint) {
     SkASSERT(SkShader::CanCallShadeSpan16(fShaderFlags));
 }
 
@@ -692,28 +688,28 @@
     SkASSERT(x + width <= fDevice.width());
 
     uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
-    SkShader::Context*    shaderContext = fShaderContext;
+    SkShader*   shader = fShader;
 
-    int alpha = shaderContext->getSpan16Alpha();
+    int alpha = shader->getSpan16Alpha();
     if (0xFF == alpha) {
-        shaderContext->shadeSpan16(x, y, device, width);
+        shader->shadeSpan16(x, y, device, width);
     } else {
         uint16_t* span16 = (uint16_t*)fBuffer;
-        shaderContext->shadeSpan16(x, y, span16, width);
+        shader->shadeSpan16(x, y, span16, width);
         SkBlendRGB16(span16, device, SkAlpha255To256(alpha), width);
     }
 }
 
 void SkRGB16_Shader16_Blitter::blitRect(int x, int y, int width, int height) {
-    SkShader::Context* shaderContext = fShaderContext;
-    uint16_t*          dst = fDevice.getAddr16(x, y);
-    size_t             dstRB = fDevice.rowBytes();
-    int                alpha = shaderContext->getSpan16Alpha();
+    SkShader*   shader = fShader;
+    uint16_t*   dst = fDevice.getAddr16(x, y);
+    size_t      dstRB = fDevice.rowBytes();
+    int         alpha = shader->getSpan16Alpha();
 
     if (0xFF == alpha) {
         if (fShaderFlags & SkShader::kConstInY16_Flag) {
             // have the shader blit directly into the device the first time
-            shaderContext->shadeSpan16(x, y, dst, width);
+            shader->shadeSpan16(x, y, dst, width);
             // and now just memcpy that line on the subsequent lines
             if (--height > 0) {
                 const uint16_t* orig = dst;
@@ -724,7 +720,7 @@
             }
         } else {    // need to call shadeSpan16 for every line
             do {
-                shaderContext->shadeSpan16(x, y, dst, width);
+                shader->shadeSpan16(x, y, dst, width);
                 y += 1;
                 dst = (uint16_t*)((char*)dst + dstRB);
             } while (--height);
@@ -733,14 +729,14 @@
         int scale = SkAlpha255To256(alpha);
         uint16_t* span16 = (uint16_t*)fBuffer;
         if (fShaderFlags & SkShader::kConstInY16_Flag) {
-            shaderContext->shadeSpan16(x, y, span16, width);
+            shader->shadeSpan16(x, y, span16, width);
             do {
                 SkBlendRGB16(span16, dst, scale, width);
                 dst = (uint16_t*)((char*)dst + dstRB);
             } while (--height);
         } else {
             do {
-                shaderContext->shadeSpan16(x, y, span16, width);
+                shader->shadeSpan16(x, y, span16, width);
                 SkBlendRGB16(span16, dst, scale, width);
                 y += 1;
                 dst = (uint16_t*)((char*)dst + dstRB);
@@ -752,11 +748,11 @@
 void SkRGB16_Shader16_Blitter::blitAntiH(int x, int y,
                                          const SkAlpha* SK_RESTRICT antialias,
                                          const int16_t* SK_RESTRICT runs) {
-    SkShader::Context*     shaderContext = fShaderContext;
+    SkShader*   shader = fShader;
     SkPMColor* SK_RESTRICT span = fBuffer;
-    uint16_t* SK_RESTRICT  device = fDevice.getAddr16(x, y);
+    uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
 
-    int alpha = shaderContext->getSpan16Alpha();
+    int alpha = shader->getSpan16Alpha();
     uint16_t* span16 = (uint16_t*)span;
 
     if (0xFF == alpha) {
@@ -770,9 +766,9 @@
             int aa = *antialias;
             if (aa == 255) {
                 // go direct to the device!
-                shaderContext->shadeSpan16(x, y, device, count);
+                shader->shadeSpan16(x, y, device, count);
             } else if (aa) {
-                shaderContext->shadeSpan16(x, y, span16, count);
+                shader->shadeSpan16(x, y, span16, count);
                 SkBlendRGB16(span16, device, SkAlpha255To256(aa), count);
             }
             device += count;
@@ -791,7 +787,7 @@
 
             int aa = SkAlphaMul(*antialias, alpha);
             if (aa) {
-                shaderContext->shadeSpan16(x, y, span16, count);
+                shader->shadeSpan16(x, y, span16, count);
                 SkBlendRGB16(span16, device, SkAlpha255To256(aa), count);
             }
 
@@ -806,9 +802,8 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 SkRGB16_Shader_Blitter::SkRGB16_Shader_Blitter(const SkBitmap& device,
-                                               const SkPaint& paint,
-                                               SkShader::Context* shaderContext)
-: INHERITED(device, paint, shaderContext) {
+                                               const SkPaint& paint)
+: INHERITED(device, paint) {
     SkASSERT(paint.getXfermode() == NULL);
 
     fBuffer = (SkPMColor*)sk_malloc_throw(device.width() * sizeof(SkPMColor));
@@ -839,20 +834,20 @@
 void SkRGB16_Shader_Blitter::blitH(int x, int y, int width) {
     SkASSERT(x + width <= fDevice.width());
 
-    fShaderContext->shadeSpan(x, y, fBuffer, width);
+    fShader->shadeSpan(x, y, fBuffer, width);
     // shaders take care of global alpha, so we pass 0xFF (should be ignored)
     fOpaqueProc(fDevice.getAddr16(x, y), fBuffer, width, 0xFF, x, y);
 }
 
 void SkRGB16_Shader_Blitter::blitRect(int x, int y, int width, int height) {
-    SkShader::Context* shaderContext = fShaderContext;
-    SkBlitRow::Proc    proc = fOpaqueProc;
-    SkPMColor*         buffer = fBuffer;
-    uint16_t*          dst = fDevice.getAddr16(x, y);
-    size_t             dstRB = fDevice.rowBytes();
+    SkShader*       shader = fShader;
+    SkBlitRow::Proc proc = fOpaqueProc;
+    SkPMColor*      buffer = fBuffer;
+    uint16_t*       dst = fDevice.getAddr16(x, y);
+    size_t          dstRB = fDevice.rowBytes();
 
     if (fShaderFlags & SkShader::kConstInY32_Flag) {
-        shaderContext->shadeSpan(x, y, buffer, width);
+        shader->shadeSpan(x, y, buffer, width);
         do {
             proc(dst, buffer, width, 0xFF, x, y);
             y += 1;
@@ -860,7 +855,7 @@
         } while (--height);
     } else {
         do {
-            shaderContext->shadeSpan(x, y, buffer, width);
+            shader->shadeSpan(x, y, buffer, width);
             proc(dst, buffer, width, 0xFF, x, y);
             y += 1;
             dst = (uint16_t*)((char*)dst + dstRB);
@@ -885,9 +880,9 @@
 void SkRGB16_Shader_Blitter::blitAntiH(int x, int y,
                                        const SkAlpha* SK_RESTRICT antialias,
                                        const int16_t* SK_RESTRICT runs) {
-    SkShader::Context*     shaderContext = fShaderContext;
+    SkShader*   shader = fShader;
     SkPMColor* SK_RESTRICT span = fBuffer;
-    uint16_t* SK_RESTRICT  device = fDevice.getAddr16(x, y);
+    uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
 
     for (;;) {
         int count = *runs;
@@ -906,7 +901,7 @@
         int nonZeroCount = count + count_nonzero_span(runs + count, antialias + count);
 
         SkASSERT(nonZeroCount <= fDevice.width()); // don't overrun fBuffer
-        shaderContext->shadeSpan(x, y, span, nonZeroCount);
+        shader->shadeSpan(x, y, span, nonZeroCount);
 
         SkPMColor* localSpan = span;
         for (;;) {
@@ -933,9 +928,8 @@
 ///////////////////////////////////////////////////////////////////////
 
 SkRGB16_Shader_Xfermode_Blitter::SkRGB16_Shader_Xfermode_Blitter(
-                                const SkBitmap& device, const SkPaint& paint,
-                                SkShader::Context* shaderContext)
-: INHERITED(device, paint, shaderContext) {
+                                const SkBitmap& device, const SkPaint& paint)
+: INHERITED(device, paint) {
     fXfermode = paint.getXfermode();
     SkASSERT(fXfermode);
     fXfermode->ref();
@@ -956,18 +950,18 @@
     uint16_t*   device = fDevice.getAddr16(x, y);
     SkPMColor*  span = fBuffer;
 
-    fShaderContext->shadeSpan(x, y, span, width);
+    fShader->shadeSpan(x, y, span, width);
     fXfermode->xfer16(device, span, width, NULL);
 }
 
 void SkRGB16_Shader_Xfermode_Blitter::blitAntiH(int x, int y,
                                 const SkAlpha* SK_RESTRICT antialias,
                                 const int16_t* SK_RESTRICT runs) {
-    SkShader::Context*     shaderContext = fShaderContext;
-    SkXfermode*            mode = fXfermode;
+    SkShader*   shader = fShader;
+    SkXfermode* mode = fXfermode;
     SkPMColor* SK_RESTRICT span = fBuffer;
-    uint8_t* SK_RESTRICT   aaExpand = fAAExpand;
-    uint16_t* SK_RESTRICT  device = fDevice.getAddr16(x, y);
+    uint8_t* SK_RESTRICT aaExpand = fAAExpand;
+    uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
 
     for (;;) {
         int count = *runs;
@@ -987,7 +981,7 @@
                                                       antialias + count);
 
         SkASSERT(nonZeroCount <= fDevice.width()); // don't overrun fBuffer
-        shaderContext->shadeSpan(x, y, span, nonZeroCount);
+        shader->shadeSpan(x, y, span, nonZeroCount);
 
         x += nonZeroCount;
         SkPMColor* localSpan = span;
@@ -1018,7 +1012,6 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 SkBlitter* SkBlitter_ChooseD565(const SkBitmap& device, const SkPaint& paint,
-        SkShader::Context* shaderContext,
         SkTBlitterAllocator* allocator) {
     SkASSERT(allocator != NULL);
 
@@ -1030,14 +1023,12 @@
     SkASSERT(NULL == mode || NULL != shader);
 
     if (shader) {
-        SkASSERT(shaderContext != NULL);
         if (mode) {
-            blitter = allocator->createT<SkRGB16_Shader_Xfermode_Blitter>(device, paint,
-                                                                          shaderContext);
-        } else if (shaderContext->canCallShadeSpan16()) {
-            blitter = allocator->createT<SkRGB16_Shader16_Blitter>(device, paint, shaderContext);
+            blitter = allocator->createT<SkRGB16_Shader_Xfermode_Blitter>(device, paint);
+        } else if (shader->canCallShadeSpan16()) {
+            blitter = allocator->createT<SkRGB16_Shader16_Blitter>(device, paint);
         } else {
-            blitter = allocator->createT<SkRGB16_Shader_Blitter>(device, paint, shaderContext);
+            blitter = allocator->createT<SkRGB16_Shader_Blitter>(device, paint);
         }
     } else {
         // no shader, no xfermode, (and we always ignore colorfilter)
diff --git a/src/core/SkCanvas.cpp b/src/core/SkCanvas.cpp
index e3451cd..d839971 100644
--- a/src/core/SkCanvas.cpp
+++ b/src/core/SkCanvas.cpp
@@ -91,10 +91,32 @@
 };
 #endif
 
+class AutoCheckNoSetContext {
+public:
+    AutoCheckNoSetContext(const SkPaint& paint) : fPaint(paint) {
+        this->assertNoSetContext(fPaint);
+    }
+    ~AutoCheckNoSetContext() {
+        this->assertNoSetContext(fPaint);
+    }
+
+private:
+    const SkPaint& fPaint;
+
+    void assertNoSetContext(const SkPaint& paint) {
+        SkShader* s = paint.getShader();
+        if (s) {
+            SkASSERT(!s->setContextHasBeenCalled());
+        }
+    }
+};
+
 #define CHECK_LOCKCOUNT_BALANCE(bitmap)  AutoCheckLockCountBalance clcb(bitmap)
+#define CHECK_SHADER_NOSETCONTEXT(paint) AutoCheckNoSetContext     cshsc(paint)
 
 #else
     #define CHECK_LOCKCOUNT_BALANCE(bitmap)
+    #define CHECK_SHADER_NOSETCONTEXT(paint)
 #endif
 
 typedef SkTLazy<SkPaint> SkLazyPaint;
@@ -1918,6 +1940,8 @@
 }
 
 void SkCanvas::internalDrawPaint(const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+
     LOOPER_BEGIN(paint, SkDrawFilter::kPaint_Type, NULL)
 
     while (iter.next()) {
@@ -1933,6 +1957,8 @@
         return;
     }
 
+    CHECK_SHADER_NOSETCONTEXT(paint);
+
     SkRect r, storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -1960,6 +1986,8 @@
 }
 
 void SkCanvas::drawRect(const SkRect& r, const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -1979,6 +2007,8 @@
 }
 
 void SkCanvas::drawOval(const SkRect& oval, const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -1998,6 +2028,8 @@
 }
 
 void SkCanvas::drawRRect(const SkRRect& rrect, const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2028,6 +2060,8 @@
 
 void SkCanvas::onDrawDRRect(const SkRRect& outer, const SkRRect& inner,
                             const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2047,6 +2081,8 @@
 }
 
 void SkCanvas::drawPath(const SkPath& path, const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+
     if (!path.isFinite()) {
         return;
     }
@@ -2322,6 +2358,8 @@
 
 void SkCanvas::onDrawText(const void* text, size_t byteLength, SkScalar x, SkScalar y,
                           const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
 
     while (iter.next()) {
@@ -2336,8 +2374,10 @@
 
 void SkCanvas::onDrawPosText(const void* text, size_t byteLength, const SkPoint pos[],
                              const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+    
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
-
+    
     while (iter.next()) {
         SkDeviceFilteredPaint dfp(iter.fDevice, looper.paint());
         iter.fDevice->drawPosText(iter, text, byteLength, &pos->fX, 0, 2,
@@ -2349,8 +2389,10 @@
 
 void SkCanvas::onDrawPosTextH(const void* text, size_t byteLength, const SkScalar xpos[],
                               SkScalar constY, const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+    
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
-
+    
     while (iter.next()) {
         SkDeviceFilteredPaint dfp(iter.fDevice, looper.paint());
         iter.fDevice->drawPosText(iter, text, byteLength, xpos, constY, 1,
@@ -2362,8 +2404,10 @@
 
 void SkCanvas::onDrawTextOnPath(const void* text, size_t byteLength, const SkPath& path,
                                 const SkMatrix* matrix, const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+    
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
-
+    
     while (iter.next()) {
         iter.fDevice->drawTextOnPath(iter, text, byteLength, path,
                                      matrix, looper.paint());
@@ -2395,6 +2439,8 @@
                             const SkColor colors[], SkXfermode* xmode,
                             const uint16_t indices[], int indexCount,
                             const SkPaint& paint) {
+    CHECK_SHADER_NOSETCONTEXT(paint);
+
     LOOPER_BEGIN(paint, SkDrawFilter::kPath_Type, NULL)
 
     while (iter.next()) {
diff --git a/src/core/SkComposeShader.cpp b/src/core/SkComposeShader.cpp
index 77bc46f..f53eedf 100644
--- a/src/core/SkComposeShader.cpp
+++ b/src/core/SkComposeShader.cpp
@@ -45,10 +45,6 @@
     fShaderA->unref();
 }
 
-size_t SkComposeShader::contextSize() const {
-    return sizeof(ComposeShaderContext) + fShaderA->contextSize() + fShaderB->contextSize();
-}
-
 class SkAutoAlphaRestore {
 public:
     SkAutoAlphaRestore(SkPaint* paint, uint8_t newAlpha) {
@@ -73,16 +69,17 @@
     buffer.writeFlattenable(fMode);
 }
 
-/*  We call validContext/createContext on our two worker shaders.
-    However, we always let them see opaque alpha, and if the paint
-    really is translucent, then we apply that after the fact.
+/*  We call setContext on our two worker shaders. However, we
+    always let them see opaque alpha, and if the paint really
+    is translucent, then we apply that after the fact.
 
+    We need to keep the calls to setContext/endContext balanced, since if we
+    return false, our endContext() will not be called.
  */
-bool SkComposeShader::validContext(const SkBitmap& device,
-                                   const SkPaint& paint,
-                                   const SkMatrix& matrix,
-                                   SkMatrix* totalInverse) const {
-    if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
+bool SkComposeShader::setContext(const SkBitmap& device,
+                                 const SkPaint& paint,
+                                 const SkMatrix& matrix) {
+    if (!this->INHERITED::setContext(device, paint, matrix)) {
         return false;
     }
 
@@ -93,62 +90,38 @@
 
     tmpM.setConcat(matrix, this->getLocalMatrix());
 
-    return fShaderA->validContext(device, paint, tmpM) &&
-           fShaderB->validContext(device, paint, tmpM);
-}
-
-SkShader::Context* SkComposeShader::createContext(const SkBitmap& device, const SkPaint& paint,
-                                                  const SkMatrix& matrix, void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
-    }
-
-    // we preconcat our localMatrix (if any) with the device matrix
-    // before calling our sub-shaders
-
-    SkMatrix tmpM;
-
-    tmpM.setConcat(matrix, this->getLocalMatrix());
-
     SkAutoAlphaRestore  restore(const_cast<SkPaint*>(&paint), 0xFF);
 
-    char* aStorage = (char*) storage + sizeof(ComposeShaderContext);
-    char* bStorage = aStorage + fShaderA->contextSize();
-
-    SkShader::Context* contextA = fShaderA->createContext(device, paint, tmpM, aStorage);
-    SkShader::Context* contextB = fShaderB->createContext(device, paint, tmpM, bStorage);
-
-    // Both functions must succeed; otherwise validContext should have returned
-    // false.
-    SkASSERT(contextA);
-    SkASSERT(contextB);
-
-    return SkNEW_PLACEMENT_ARGS(storage, ComposeShaderContext,
-                                (*this, device, paint, matrix, contextA, contextB));
+    bool setContextA = fShaderA->setContext(device, paint, tmpM);
+    bool setContextB = fShaderB->setContext(device, paint, tmpM);
+    if (!setContextA || !setContextB) {
+        if (setContextB) {
+            fShaderB->endContext();
+        }
+        else if (setContextA) {
+            fShaderA->endContext();
+        }
+        this->INHERITED::endContext();
+        return false;
+    }
+    return true;
 }
 
-SkComposeShader::ComposeShaderContext::ComposeShaderContext(
-        const SkComposeShader& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix,
-        SkShader::Context* contextA, SkShader::Context* contextB)
-    : INHERITED(shader, device, paint, matrix)
-    , fShaderContextA(contextA)
-    , fShaderContextB(contextB) {}
-
-SkComposeShader::ComposeShaderContext::~ComposeShaderContext() {
-    fShaderContextA->~Context();
-    fShaderContextB->~Context();
+void SkComposeShader::endContext() {
+    fShaderB->endContext();
+    fShaderA->endContext();
+    this->INHERITED::endContext();
 }
 
 // larger is better (fewer times we have to loop), but we shouldn't
 // take up too much stack-space (each element is 4 bytes)
 #define TMP_COLOR_COUNT     64
 
-void SkComposeShader::ComposeShaderContext::shadeSpan(int x, int y, SkPMColor result[], int count) {
-    SkShader::Context* shaderContextA = fShaderContextA;
-    SkShader::Context* shaderContextB = fShaderContextB;
-    SkXfermode*        mode = static_cast<const SkComposeShader&>(fShader).fMode;
-    unsigned           scale = SkAlpha255To256(this->getPaintAlpha());
+void SkComposeShader::shadeSpan(int x, int y, SkPMColor result[], int count) {
+    SkShader*   shaderA = fShaderA;
+    SkShader*   shaderB = fShaderB;
+    SkXfermode* mode = fMode;
+    unsigned    scale = SkAlpha255To256(this->getPaintAlpha());
 
     SkPMColor   tmp[TMP_COLOR_COUNT];
 
@@ -161,8 +134,8 @@
                 n = TMP_COLOR_COUNT;
             }
 
-            shaderContextA->shadeSpan(x, y, result, n);
-            shaderContextB->shadeSpan(x, y, tmp, n);
+            shaderA->shadeSpan(x, y, result, n);
+            shaderB->shadeSpan(x, y, tmp, n);
 
             if (256 == scale) {
                 for (int i = 0; i < n; i++) {
@@ -186,8 +159,8 @@
                 n = TMP_COLOR_COUNT;
             }
 
-            shaderContextA->shadeSpan(x, y, result, n);
-            shaderContextB->shadeSpan(x, y, tmp, n);
+            shaderA->shadeSpan(x, y, result, n);
+            shaderB->shadeSpan(x, y, tmp, n);
             mode->xfer32(result, tmp, n, NULL);
 
             if (256 == scale) {
diff --git a/src/core/SkCoreBlitters.h b/src/core/SkCoreBlitters.h
index 2d22d38..2851840 100644
--- a/src/core/SkCoreBlitters.h
+++ b/src/core/SkCoreBlitters.h
@@ -27,29 +27,12 @@
 
 class SkShaderBlitter : public SkRasterBlitter {
 public:
-    /**
-      *  The storage for shaderContext is owned by the caller, but the object itself is not.
-      *  The blitter only ensures that the storage always holds a live object, but it may
-      *  exchange that object.
-      */
-    SkShaderBlitter(const SkBitmap& device, const SkPaint& paint,
-                    SkShader::Context* shaderContext);
+    SkShaderBlitter(const SkBitmap& device, const SkPaint& paint);
     virtual ~SkShaderBlitter();
 
-    /**
-      *  Create a new shader context and uses it instead of the old one if successful.
-      *  Will create the context at the same location as the old one (this is safe
-      *  because the shader itself is unchanged).
-      */
-    virtual bool resetShaderContext(const SkBitmap& device, const SkPaint& paint,
-                                    const SkMatrix& matrix) SK_OVERRIDE;
-
-    virtual SkShader::Context* getShaderContext() const SK_OVERRIDE { return fShaderContext; }
-
 protected:
-    uint32_t            fShaderFlags;
-    const SkShader*     fShader;
-    SkShader::Context*  fShaderContext;
+    uint32_t    fShaderFlags;
+    SkShader*   fShader;
 
 private:
     // illegal
@@ -92,8 +75,7 @@
 
 class SkA8_Shader_Blitter : public SkShaderBlitter {
 public:
-    SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
-                        SkShader::Context* shaderContext);
+    SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint);
     virtual ~SkA8_Shader_Blitter();
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha antialias[], const int16_t runs[]);
@@ -159,8 +141,7 @@
 
 class SkARGB32_Shader_Blitter : public SkShaderBlitter {
 public:
-    SkARGB32_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
-                            SkShader::Context* shaderContext);
+    SkARGB32_Shader_Blitter(const SkBitmap& device, const SkPaint& paint);
     virtual ~SkARGB32_Shader_Blitter();
     virtual void blitH(int x, int y, int width) SK_OVERRIDE;
     virtual void blitV(int x, int y, int height, SkAlpha alpha) SK_OVERRIDE;
@@ -198,7 +179,6 @@
  */
 
 SkBlitter* SkBlitter_ChooseD565(const SkBitmap& device, const SkPaint& paint,
-                                SkShader::Context* shaderContext,
                                 SkTBlitterAllocator* allocator);
 
 #endif
diff --git a/src/core/SkDraw.cpp b/src/core/SkDraw.cpp
index 6ddd0d2..7eb0be6 100644
--- a/src/core/SkDraw.cpp
+++ b/src/core/SkDraw.cpp
@@ -2354,26 +2354,9 @@
 public:
     SkTriColorShader() {}
 
-    virtual SkShader::Context* createContext(
-            const SkBitmap&, const SkPaint&, const SkMatrix&, void*) const SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
+    bool setup(const SkPoint pts[], const SkColor colors[], int, int, int);
 
-    class TriColorShaderContext : public SkShader::Context {
-    public:
-        TriColorShaderContext(const SkTriColorShader& shader, const SkBitmap& device,
-                              const SkPaint& paint, const SkMatrix& matrix);
-        virtual ~TriColorShaderContext();
-
-        bool setup(const SkPoint pts[], const SkColor colors[], int, int, int);
-
-        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-
-    private:
-        SkMatrix    fDstToUnit;
-        SkPMColor   fColors[3];
-
-        typedef SkShader::Context INHERITED;
-    };
+    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkTriColorShader)
@@ -2382,20 +2365,14 @@
     SkTriColorShader(SkReadBuffer& buffer) : SkShader(buffer) {}
 
 private:
+    SkMatrix    fDstToUnit;
+    SkPMColor   fColors[3];
+
     typedef SkShader INHERITED;
 };
 
-SkShader::Context* SkTriColorShader::createContext(const SkBitmap& device, const SkPaint& paint,
-                                                   const SkMatrix& matrix, void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
-    }
-
-    return SkNEW_PLACEMENT_ARGS(storage, TriColorShaderContext, (*this, device, paint, matrix));
-}
-
-bool SkTriColorShader::TriColorShaderContext::setup(const SkPoint pts[], const SkColor colors[],
-                                                    int index0, int index1, int index2) {
+bool SkTriColorShader::setup(const SkPoint pts[], const SkColor colors[],
+                             int index0, int index1, int index2) {
 
     fColors[0] = SkPreMultiplyColor(colors[index0]);
     fColors[1] = SkPreMultiplyColor(colors[index1]);
@@ -2430,18 +2407,7 @@
     return SkAlpha255To256(scale);
 }
 
-
-SkTriColorShader::TriColorShaderContext::TriColorShaderContext(
-        const SkTriColorShader& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix) {}
-
-SkTriColorShader::TriColorShaderContext::~TriColorShaderContext() {}
-
-size_t SkTriColorShader::contextSize() const {
-    return sizeof(TriColorShaderContext);
-}
-void SkTriColorShader::TriColorShaderContext::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
+void SkTriColorShader::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
     SkPoint src;
 
     for (int i = 0; i < count; i++) {
@@ -2526,7 +2492,6 @@
     }
 
     // setup the custom shader (if needed)
-    SkAutoTUnref<SkComposeShader> composeShader;
     if (NULL != colors) {
         if (NULL == textures) {
             // just colors (no texture)
@@ -2539,8 +2504,9 @@
                 xmode = SkXfermode::Create(SkXfermode::kModulate_Mode);
                 releaseMode = true;
             }
-            composeShader.reset(SkNEW_ARGS(SkComposeShader, (&triShader, shader, xmode)));
-            p.setShader(composeShader);
+            SkShader* compose = SkNEW_ARGS(SkComposeShader,
+                                           (&triShader, shader, xmode));
+            p.setShader(compose)->unref();
             if (releaseMode) {
                 xmode->unref();
             }
@@ -2548,7 +2514,9 @@
     }
 
     SkAutoBlitterChoose blitter(*fBitmap, *fMatrix, p);
-    // Abort early if we failed to create a shader context.
+    // important that we abort early, as below we may manipulate the shader
+    // and that is only valid if the shader returned true from setContext.
+    // If it returned false, then our blitter will be the NullBlitter.
     if (blitter->isNullBlitter()) {
         return;
     }
@@ -2564,38 +2532,30 @@
             savedLocalM = shader->getLocalMatrix();
         }
 
+        // setContext has already been called and verified to return true
+        // by the constructor of SkAutoBlitterChoose
+        bool prevContextSuccess = true;
         while (vertProc(&state)) {
             if (NULL != textures) {
                 if (texture_to_matrix(state, vertices, textures, &tempM)) {
                     tempM.postConcat(savedLocalM);
                     shader->setLocalMatrix(tempM);
-                    if (!blitter->resetShaderContext(*fBitmap, p, *fMatrix)) {
+                    // Need to call setContext again since we changed the local matrix.
+                    // However, we also need to balance each such call with a
+                    // call to endContext, which requires tracking the result of
+                    // the previous call to setContext.
+                    if (prevContextSuccess) {
+                        shader->endContext();
+                    }
+                    prevContextSuccess = shader->setContext(*fBitmap, p, *fMatrix);
+                    if (!prevContextSuccess) {
                         continue;
                     }
                 }
             }
             if (NULL != colors) {
-                // Find the context for triShader.
-                SkTriColorShader::TriColorShaderContext* triColorShaderContext;
-
-                SkShader::Context* shaderContext = blitter->getShaderContext();
-                SkASSERT(shaderContext);
-                if (p.getShader() == &triShader) {
-                    triColorShaderContext =
-                            static_cast<SkTriColorShader::TriColorShaderContext*>(shaderContext);
-                } else {
-                    // The shader is a compose shader and triShader is its first shader.
-                    SkASSERT(p.getShader() == composeShader);
-                    SkASSERT(composeShader->getShaderA() == &triShader);
-                    SkComposeShader::ComposeShaderContext* composeShaderContext =
-                            static_cast<SkComposeShader::ComposeShaderContext*>(shaderContext);
-                    SkShader::Context* shaderContextA = composeShaderContext->getShaderContextA();
-                    triColorShaderContext =
-                            static_cast<SkTriColorShader::TriColorShaderContext*>(shaderContextA);
-                }
-
-                if (!triColorShaderContext->setup(vertices, colors,
-                                                  state.f0, state.f1, state.f2)) {
+                if (!triShader.setup(vertices, colors,
+                                     state.f0, state.f1, state.f2)) {
                     continue;
                 }
             }
@@ -2610,6 +2570,13 @@
         if (NULL != shader) {
             shader->setLocalMatrix(savedLocalM);
         }
+
+        // If the final call to setContext fails, we must make it succeed so that the
+        // call to endContext in the destructor for SkAutoBlitterChoose is balanced.
+        if (!prevContextSuccess) {
+            prevContextSuccess = shader->setContext(*fBitmap, paint, SkMatrix::I());
+            SkASSERT(prevContextSuccess);
+        }
     } else {
         // no colors[] and no texture
         HairProc hairProc = ChooseHairProc(paint.isAntiAlias());
diff --git a/src/core/SkFilterShader.cpp b/src/core/SkFilterShader.cpp
index 5c5e8f3..5896191 100644
--- a/src/core/SkFilterShader.cpp
+++ b/src/core/SkFilterShader.cpp
@@ -38,11 +38,9 @@
     buffer.writeFlattenable(fFilter);
 }
 
-uint32_t SkFilterShader::FilterShaderContext::getFlags() const {
-    const SkFilterShader& filterShader = static_cast<const SkFilterShader&>(fShader);
-
-    uint32_t shaderF = fShaderContext->getFlags();
-    uint32_t filterF = filterShader.fFilter->getFlags();
+uint32_t SkFilterShader::getFlags() {
+    uint32_t shaderF = fShader->getFlags();
+    uint32_t filterF = fFilter->getFlags();
 
     // if the filter doesn't support 16bit, clear the matching bit in the shader
     if (!(filterF & SkColorFilter::kHasFilter16_Flag)) {
@@ -55,62 +53,38 @@
     return shaderF;
 }
 
-SkShader::Context* SkFilterShader::createContext(const SkBitmap& device,
-                                                 const SkPaint& paint,
-                                                 const SkMatrix& matrix,
-                                                 void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
+bool SkFilterShader::setContext(const SkBitmap& device,
+                                const SkPaint& paint,
+                                const SkMatrix& matrix) {
+    // we need to keep the setContext/endContext calls balanced. If we return
+    // false, our endContext() will not be called.
+
+    if (!this->INHERITED::setContext(device, paint, matrix)) {
+        return false;
     }
-
-    char* shaderContextStorage = (char*)storage + sizeof(FilterShaderContext);
-    SkShader::Context* shaderContext = fShader->createContext(device, paint, matrix,
-                                                              shaderContextStorage);
-    SkASSERT(shaderContext);
-
-    return SkNEW_PLACEMENT_ARGS(storage, FilterShaderContext,
-                                (*this, shaderContext, device, paint, matrix));
+    if (!fShader->setContext(device, paint, matrix)) {
+        this->INHERITED::endContext();
+        return false;
+    }
+    return true;
 }
 
-size_t SkFilterShader::contextSize() const {
-    return sizeof(FilterShaderContext) + fShader->contextSize();
+void SkFilterShader::endContext() {
+    fShader->endContext();
+    this->INHERITED::endContext();
 }
 
-bool SkFilterShader::validContext(const SkBitmap& device,
-                                  const SkPaint& paint,
-                                  const SkMatrix& matrix,
-                                  SkMatrix* totalInverse) const {
-    return this->INHERITED::validContext(device, paint, matrix, totalInverse) &&
-           fShader->validContext(device, paint, matrix);
+void SkFilterShader::shadeSpan(int x, int y, SkPMColor result[], int count) {
+    fShader->shadeSpan(x, y, result, count);
+    fFilter->filterSpan(result, count, result);
 }
 
-SkFilterShader::FilterShaderContext::FilterShaderContext(const SkFilterShader& filterShader,
-                                                         SkShader::Context* shaderContext,
-                                                         const SkBitmap& device,
-                                                         const SkPaint& paint,
-                                                         const SkMatrix& matrix)
-    : INHERITED(filterShader, device, paint, matrix)
-    , fShaderContext(shaderContext) {}
+void SkFilterShader::shadeSpan16(int x, int y, uint16_t result[], int count) {
+    SkASSERT(fShader->getFlags() & SkShader::kHasSpan16_Flag);
+    SkASSERT(fFilter->getFlags() & SkColorFilter::kHasFilter16_Flag);
 
-SkFilterShader::FilterShaderContext::~FilterShaderContext() {
-    fShaderContext->~Context();
-}
-
-void SkFilterShader::FilterShaderContext::shadeSpan(int x, int y, SkPMColor result[], int count) {
-    const SkFilterShader& filterShader = static_cast<const SkFilterShader&>(fShader);
-
-    fShaderContext->shadeSpan(x, y, result, count);
-    filterShader.fFilter->filterSpan(result, count, result);
-}
-
-void SkFilterShader::FilterShaderContext::shadeSpan16(int x, int y, uint16_t result[], int count) {
-    const SkFilterShader& filterShader = static_cast<const SkFilterShader&>(fShader);
-
-    SkASSERT(fShaderContext->getFlags() & SkShader::kHasSpan16_Flag);
-    SkASSERT(filterShader.fFilter->getFlags() & SkColorFilter::kHasFilter16_Flag);
-
-    fShaderContext->shadeSpan16(x, y, result, count);
-    filterShader.fFilter->filterSpan16(result, count, result);
+    fShader->shadeSpan16(x, y, result, count);
+    fFilter->filterSpan16(result, count, result);
 }
 
 #ifndef SK_IGNORE_TO_STRING
diff --git a/src/core/SkFilterShader.h b/src/core/SkFilterShader.h
index 4ef4577..11add0c 100644
--- a/src/core/SkFilterShader.h
+++ b/src/core/SkFilterShader.h
@@ -17,29 +17,12 @@
     SkFilterShader(SkShader* shader, SkColorFilter* filter);
     virtual ~SkFilterShader();
 
-    virtual bool validContext(const SkBitmap&, const SkPaint&,
-                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
-    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
-                                             const SkMatrix&, void* storage) const SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
-
-    class FilterShaderContext : public SkShader::Context {
-    public:
-        // Takes ownership of shaderContext and calls its destructor.
-        FilterShaderContext(const SkFilterShader& filterShader, SkShader::Context* shaderContext,
-                            const SkBitmap& device, const SkPaint& paint, const SkMatrix& matrix);
-        virtual ~FilterShaderContext();
-
-        virtual uint32_t getFlags() const SK_OVERRIDE;
-
-        virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
-        virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
-
-    private:
-        SkShader::Context* fShaderContext;
-
-        typedef SkShader::Context INHERITED;
-    };
+    virtual uint32_t getFlags() SK_OVERRIDE;
+    virtual bool setContext(const SkBitmap&, const SkPaint&,
+                            const SkMatrix&) SK_OVERRIDE;
+    virtual void endContext() SK_OVERRIDE;
+    virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
+    virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkFilterShader)
diff --git a/src/core/SkPictureShader.cpp b/src/core/SkPictureShader.cpp
index dc5c90b..bf31285 100644
--- a/src/core/SkPictureShader.cpp
+++ b/src/core/SkPictureShader.cpp
@@ -49,7 +49,7 @@
     fPicture->flatten(buffer);
 }
 
-SkShader* SkPictureShader::refBitmapShader(const SkMatrix& matrix) const {
+bool SkPictureShader::buildBitmapShader(const SkMatrix& matrix) const {
     SkASSERT(fPicture && fPicture->width() > 0 && fPicture->height() > 0);
 
     SkMatrix m;
@@ -70,20 +70,17 @@
 
     SkISize tileSize = scaledSize.toRound();
     if (tileSize.isEmpty()) {
-        return NULL;
+        return false;
     }
 
     // The actual scale, compensating for rounding.
     SkSize tileScale = SkSize::Make(SkIntToScalar(tileSize.width()) / fPicture->width(),
                                     SkIntToScalar(tileSize.height()) / fPicture->height());
 
-    SkAutoMutexAcquire ama(fCachedBitmapShaderMutex);
-
-    if (!fCachedBitmapShader || tileScale != fCachedTileScale ||
-        this->getLocalMatrix() != fCachedLocalMatrix) {
+    if (!fCachedShader || tileScale != fCachedTileScale) {
         SkBitmap bm;
         if (!bm.allocN32Pixels(tileSize.width(), tileSize.height())) {
-            return NULL;
+            return false;
         }
         bm.eraseColor(SK_ColorTRANSPARENT);
 
@@ -91,91 +88,66 @@
         canvas.scale(tileScale.width(), tileScale.height());
         canvas.drawPicture(*fPicture);
 
-        fCachedBitmapShader.reset(CreateBitmapShader(bm, fTmx, fTmy));
+        fCachedShader.reset(CreateBitmapShader(bm, fTmx, fTmy));
         fCachedTileScale = tileScale;
-        fCachedLocalMatrix = this->getLocalMatrix();
-
-        SkMatrix shaderMatrix = this->getLocalMatrix();
-        shaderMatrix.preScale(1 / tileScale.width(), 1 / tileScale.height());
-        fCachedBitmapShader->setLocalMatrix(shaderMatrix);
     }
 
-    // Increment the ref counter inside the mutex to ensure the returned pointer is still valid.
-    // Otherwise, the pointer may have been overwritten on a different thread before the object's
-    // ref count was incremented.
-    fCachedBitmapShader.get()->ref();
-    return fCachedBitmapShader;
+    SkMatrix shaderMatrix = this->getLocalMatrix();
+    shaderMatrix.preScale(1 / tileScale.width(), 1 / tileScale.height());
+    fCachedShader->setLocalMatrix(shaderMatrix);
+
+    return true;
 }
 
-SkShader* SkPictureShader::validInternal(const SkBitmap& device, const SkPaint& paint,
-                                         const SkMatrix& matrix, SkMatrix* totalInverse) const {
-    if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
-        return NULL;
+bool SkPictureShader::setContext(const SkBitmap& device,
+                                 const SkPaint& paint,
+                                 const SkMatrix& matrix) {
+    if (!this->buildBitmapShader(matrix)) {
+        return false;
     }
 
-    SkAutoTUnref<SkShader> bitmapShader(this->refBitmapShader(matrix));
-    if (!bitmapShader || !bitmapShader->validContext(device, paint, matrix)) {
-        return NULL;
+    if (!this->INHERITED::setContext(device, paint, matrix)) {
+        return false;
     }
 
-    return bitmapShader.detach();
-}
-
-bool SkPictureShader::validContext(const SkBitmap& device, const SkPaint& paint,
-                                   const SkMatrix& matrix, SkMatrix* totalInverse) const {
-    SkAutoTUnref<SkShader> shader(this->validInternal(device, paint, matrix, totalInverse));
-    return shader != NULL;
-}
-
-SkShader::Context* SkPictureShader::createContext(const SkBitmap& device, const SkPaint& paint,
-                                                  const SkMatrix& matrix, void* storage) const {
-    SkAutoTUnref<SkShader> bitmapShader(this->validInternal(device, paint, matrix, NULL));
-    if (!bitmapShader) {
-        return NULL;
+    SkASSERT(fCachedShader);
+    if (!fCachedShader->setContext(device, paint, matrix)) {
+        this->INHERITED::endContext();
+        return false;
     }
 
-    return SkNEW_PLACEMENT_ARGS(storage, PictureShaderContext,
-                                (*this, device, paint, matrix, bitmapShader.detach()));
+    return true;
 }
 
-size_t SkPictureShader::contextSize() const {
-    return sizeof(PictureShaderContext);
+void SkPictureShader::endContext() {
+    SkASSERT(fCachedShader);
+    fCachedShader->endContext();
+
+    this->INHERITED::endContext();
 }
 
-SkPictureShader::PictureShaderContext::PictureShaderContext(
-        const SkPictureShader& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix, SkShader* bitmapShader)
-    : INHERITED(shader, device, paint, matrix)
-    , fBitmapShader(bitmapShader)
-{
-    SkASSERT(fBitmapShader);
-    fBitmapShaderContextStorage = sk_malloc_throw(fBitmapShader->contextSize());
-    fBitmapShaderContext = fBitmapShader->createContext(
-            device, paint, matrix, fBitmapShaderContextStorage);
-    SkASSERT(fBitmapShaderContext);
+uint32_t SkPictureShader::getFlags() {
+    if (NULL != fCachedShader) {
+        return fCachedShader->getFlags();
+    }
+    return 0;
 }
 
-SkPictureShader::PictureShaderContext::~PictureShaderContext() {
-    fBitmapShaderContext->~Context();
-    sk_free(fBitmapShaderContextStorage);
+SkShader::ShadeProc SkPictureShader::asAShadeProc(void** ctx) {
+    if (fCachedShader) {
+        return fCachedShader->asAShadeProc(ctx);
+    }
+    return NULL;
 }
 
-uint32_t SkPictureShader::PictureShaderContext::getFlags() const {
-    return fBitmapShaderContext->getFlags();
+void SkPictureShader::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
+    SkASSERT(fCachedShader);
+    fCachedShader->shadeSpan(x, y, dstC, count);
 }
 
-SkShader::Context::ShadeProc SkPictureShader::PictureShaderContext::asAShadeProc(void** ctx) {
-    return fBitmapShaderContext->asAShadeProc(ctx);
-}
-
-void SkPictureShader::PictureShaderContext::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
-    SkASSERT(fBitmapShaderContext);
-    fBitmapShaderContext->shadeSpan(x, y, dstC, count);
-}
-
-void SkPictureShader::PictureShaderContext::shadeSpan16(int x, int y, uint16_t dstC[], int count) {
-    SkASSERT(fBitmapShaderContext);
-    fBitmapShaderContext->shadeSpan16(x, y, dstC, count);
+void SkPictureShader::shadeSpan16(int x, int y, uint16_t dstC[], int count) {
+    SkASSERT(fCachedShader);
+    fCachedShader->shadeSpan16(x, y, dstC, count);
 }
 
 #ifndef SK_IGNORE_TO_STRING
@@ -196,10 +168,10 @@
 
 #if SK_SUPPORT_GPU
 GrEffectRef* SkPictureShader::asNewEffect(GrContext* context, const SkPaint& paint) const {
-    SkAutoTUnref<SkShader> bitmapShader(this->refBitmapShader(context->getMatrix()));
-    if (!bitmapShader) {
+    if (!this->buildBitmapShader(context->getMatrix())) {
         return NULL;
     }
-    return bitmapShader->asNewEffect(context, paint);
+    SkASSERT(fCachedShader);
+    return fCachedShader->asNewEffect(context, paint);
 }
 #endif
diff --git a/src/core/SkPictureShader.h b/src/core/SkPictureShader.h
index d1be059..ea74b56 100644
--- a/src/core/SkPictureShader.h
+++ b/src/core/SkPictureShader.h
@@ -24,33 +24,13 @@
     static SkPictureShader* Create(SkPicture*, TileMode, TileMode);
     virtual ~SkPictureShader();
 
-    virtual bool validContext(const SkBitmap&, const SkPaint&,
-                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
-    virtual SkShader::Context* createContext(const SkBitmap& device, const SkPaint& paint,
-                                             const SkMatrix& matrix, void* storage) const
-            SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
+    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
+    virtual void endContext() SK_OVERRIDE;
+    virtual uint32_t getFlags() SK_OVERRIDE;
 
-    class PictureShaderContext : public SkShader::Context {
-    public:
-        PictureShaderContext(const SkPictureShader& shader, const SkBitmap& device,
-                             const SkPaint& paint, const SkMatrix& matrix,
-                             SkShader* bitmapShader);
-        virtual ~PictureShaderContext();
-
-        virtual uint32_t getFlags() const SK_OVERRIDE;
-
-        virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
-        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
-
-    private:
-        SkAutoTUnref<SkShader>  fBitmapShader;
-        SkShader::Context*      fBitmapShaderContext;
-        void*                   fBitmapShaderContextStorage;
-
-        typedef SkShader::Context INHERITED;
-    };
+    virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
+    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkPictureShader)
@@ -66,18 +46,13 @@
 private:
     SkPictureShader(SkPicture*, TileMode, TileMode);
 
-    SkShader* validInternal(const SkBitmap& device, const SkPaint& paint,
-                            const SkMatrix& matrix, SkMatrix* totalInverse) const;
-
-    SkShader* refBitmapShader(const SkMatrix&) const;
+    bool buildBitmapShader(const SkMatrix&) const;
 
     SkPicture*  fPicture;
     TileMode    fTmx, fTmy;
 
-    mutable SkMutex                 fCachedBitmapShaderMutex;
-    mutable SkAutoTUnref<SkShader>  fCachedBitmapShader;
+    mutable SkAutoTUnref<SkShader>  fCachedShader;
     mutable SkSize                  fCachedTileScale;
-    mutable SkMatrix                fCachedLocalMatrix;
 
     typedef SkShader INHERITED;
 };
diff --git a/src/core/SkShader.cpp b/src/core/SkShader.cpp
index 40e52a0..e337b7d 100644
--- a/src/core/SkShader.cpp
+++ b/src/core/SkShader.cpp
@@ -17,6 +17,7 @@
 
 SkShader::SkShader() {
     fLocalMatrix.reset();
+    SkDEBUGCODE(fInSetContext = false;)
 }
 
 SkShader::SkShader(SkReadBuffer& buffer)
@@ -26,9 +27,12 @@
     } else {
         fLocalMatrix.reset();
     }
+
+    SkDEBUGCODE(fInSetContext = false;)
 }
 
 SkShader::~SkShader() {
+    SkASSERT(!fInSetContext);
 }
 
 void SkShader::flatten(SkWriteBuffer& buffer) const {
@@ -40,48 +44,39 @@
     }
 }
 
-bool SkShader::computeTotalInverse(const SkMatrix& matrix, SkMatrix* totalInverse) const {
+bool SkShader::setContext(const SkBitmap& device,
+                          const SkPaint& paint,
+                          const SkMatrix& matrix) {
+    SkASSERT(!this->setContextHasBeenCalled());
+
     const SkMatrix* m = &matrix;
     SkMatrix        total;
 
+    fPaintAlpha = paint.getAlpha();
     if (this->hasLocalMatrix()) {
         total.setConcat(matrix, this->getLocalMatrix());
         m = &total;
     }
-
-    return m->invert(totalInverse);
+    if (m->invert(&fTotalInverse)) {
+        fTotalInverseClass = (uint8_t)ComputeMatrixClass(fTotalInverse);
+        SkDEBUGCODE(fInSetContext = true;)
+        return true;
+    }
+    return false;
 }
 
-bool SkShader::validContext(const SkBitmap& device,
-                            const SkPaint& paint,
-                            const SkMatrix& matrix,
-                            SkMatrix* totalInverse) const {
-    return this->computeTotalInverse(matrix, totalInverse);
+void SkShader::endContext() {
+    SkASSERT(fInSetContext);
+    SkDEBUGCODE(fInSetContext = false;)
 }
 
-SkShader::Context::Context(const SkShader& shader, const SkBitmap& device,
-                           const SkPaint& paint, const SkMatrix& matrix)
-    : fShader(shader)
-{
-    SkASSERT(fShader.validContext(device, paint, matrix));
-
-    // Because the context parameters must be valid at this point, we know that the matrix is
-    // invertible.
-    SkAssertResult(fShader.computeTotalInverse(matrix, &fTotalInverse));
-    fTotalInverseClass = (uint8_t)ComputeMatrixClass(fTotalInverse);
-
-    fPaintAlpha = paint.getAlpha();
-}
-
-SkShader::Context::~Context() {}
-
-SkShader::Context::ShadeProc SkShader::Context::asAShadeProc(void** ctx) {
+SkShader::ShadeProc SkShader::asAShadeProc(void** ctx) {
     return NULL;
 }
 
 #include "SkColorPriv.h"
 
-void SkShader::Context::shadeSpan16(int x, int y, uint16_t span16[], int count) {
+void SkShader::shadeSpan16(int x, int y, uint16_t span16[], int count) {
     SkASSERT(span16);
     SkASSERT(count > 0);
     SkASSERT(this->canCallShadeSpan16());
@@ -99,7 +94,7 @@
     #define SkU32BitShiftToByteOffset(shift)    ((shift) >> 3)
 #endif
 
-void SkShader::Context::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
+void SkShader::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
     SkASSERT(count > 0);
 
     SkPMColor   colors[kTempColorCount];
@@ -153,7 +148,7 @@
 #endif
 }
 
-SkShader::Context::MatrixClass SkShader::Context::ComputeMatrixClass(const SkMatrix& mat) {
+SkShader::MatrixClass SkShader::ComputeMatrixClass(const SkMatrix& mat) {
     MatrixClass mc = kLinear_MatrixClass;
 
     if (mat.hasPerspective()) {
@@ -168,7 +163,8 @@
 
 //////////////////////////////////////////////////////////////////////////////
 
-SkShader::BitmapType SkShader::asABitmap(SkBitmap*, SkMatrix*, TileMode*) const {
+SkShader::BitmapType SkShader::asABitmap(SkBitmap*, SkMatrix*,
+                                         TileMode*) const {
     return kNone_BitmapType;
 }
 
@@ -203,16 +199,19 @@
 #include "SkColorShader.h"
 #include "SkUtils.h"
 
-SkColorShader::SkColorShader()
-    : fColor()
-    , fInheritColor(true) {
+SkColorShader::SkColorShader() {
+    fFlags = 0;
+    fInheritColor = true;
 }
 
-SkColorShader::SkColorShader(SkColor c)
-    : fColor(c)
-    , fInheritColor(false) {
+SkColorShader::SkColorShader(SkColor c) {
+    fFlags = 0;
+    fColor = c;
+    fInheritColor = false;
 }
 
+SkColorShader::~SkColorShader() {}
+
 bool SkColorShader::isOpaque() const {
     if (fInheritColor) {
         return true; // using paint's alpha
@@ -221,6 +220,8 @@
 }
 
 SkColorShader::SkColorShader(SkReadBuffer& b) : INHERITED(b) {
+    fFlags = 0; // computed in setContext
+
     fInheritColor = b.readBool();
     if (fInheritColor) {
         return;
@@ -237,43 +238,32 @@
     buffer.writeColor(fColor);
 }
 
-uint32_t SkColorShader::ColorShaderContext::getFlags() const {
+uint32_t SkColorShader::getFlags() {
     return fFlags;
 }
 
-uint8_t SkColorShader::ColorShaderContext::getSpan16Alpha() const {
+uint8_t SkColorShader::getSpan16Alpha() const {
     return SkGetPackedA32(fPMColor);
 }
 
-SkShader::Context* SkColorShader::createContext(const SkBitmap& device, const SkPaint& paint,
-                                                const SkMatrix& matrix, void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
+bool SkColorShader::setContext(const SkBitmap& device, const SkPaint& paint,
+                               const SkMatrix& matrix) {
+    if (!this->INHERITED::setContext(device, paint, matrix)) {
+        return false;
     }
 
-    return SkNEW_PLACEMENT_ARGS(storage, ColorShaderContext, (*this, device, paint, matrix));
-}
-
-SkColorShader::ColorShaderContext::ColorShaderContext(const SkColorShader& shader,
-                                                      const SkBitmap& device,
-                                                      const SkPaint& paint,
-                                                      const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix)
-{
     unsigned a;
 
-    SkColor color;
-    if (shader.fInheritColor) {
-        color = paint.getColor();
-        a = SkColorGetA(color);
+    if (fInheritColor) {
+        fColor = paint.getColor();
+        a = SkColorGetA(fColor);
     } else {
-        color = shader.fColor;
-        a = SkAlphaMul(SkColorGetA(color), SkAlpha255To256(paint.getAlpha()));
+        a = SkAlphaMul(SkColorGetA(fColor), SkAlpha255To256(paint.getAlpha()));
     }
 
-    unsigned r = SkColorGetR(color);
-    unsigned g = SkColorGetG(color);
-    unsigned b = SkColorGetB(color);
+    unsigned r = SkColorGetR(fColor);
+    unsigned g = SkColorGetG(fColor);
+    unsigned b = SkColorGetB(fColor);
 
     // we want this before we apply any alpha
     fColor16 = SkPack888ToRGB16(r, g, b);
@@ -292,17 +282,19 @@
             fFlags |= kHasSpan16_Flag;
         }
     }
+
+    return true;
 }
 
-void SkColorShader::ColorShaderContext::shadeSpan(int x, int y, SkPMColor span[], int count) {
+void SkColorShader::shadeSpan(int x, int y, SkPMColor span[], int count) {
     sk_memset32(span, fPMColor, count);
 }
 
-void SkColorShader::ColorShaderContext::shadeSpan16(int x, int y, uint16_t span[], int count) {
+void SkColorShader::shadeSpan16(int x, int y, uint16_t span[], int count) {
     sk_memset16(span, fColor16, count);
 }
 
-void SkColorShader::ColorShaderContext::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
+void SkColorShader::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
     memset(alpha, SkGetPackedA32(fPMColor), count);
 }
 
@@ -342,9 +334,27 @@
 
 ///////////////////////////////////////////////////////////////////////////////
 
-#ifndef SK_IGNORE_TO_STRING
 #include "SkEmptyShader.h"
 
+uint32_t SkEmptyShader::getFlags() { return 0; }
+uint8_t SkEmptyShader::getSpan16Alpha() const { return 0; }
+
+bool SkEmptyShader::setContext(const SkBitmap&, const SkPaint&,
+                               const SkMatrix&) { return false; }
+
+void SkEmptyShader::shadeSpan(int x, int y, SkPMColor span[], int count) {
+    SkDEBUGFAIL("should never get called, since setContext() returned false");
+}
+
+void SkEmptyShader::shadeSpan16(int x, int y, uint16_t span[], int count) {
+    SkDEBUGFAIL("should never get called, since setContext() returned false");
+}
+
+void SkEmptyShader::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
+    SkDEBUGFAIL("should never get called, since setContext() returned false");
+}
+
+#ifndef SK_IGNORE_TO_STRING
 void SkEmptyShader::toString(SkString* str) const {
     str->append("SkEmptyShader: (");
 
diff --git a/src/core/SkSmallAllocator.h b/src/core/SkSmallAllocator.h
index 8d4b53a..655008b 100644
--- a/src/core/SkSmallAllocator.h
+++ b/src/core/SkSmallAllocator.h
@@ -117,12 +117,10 @@
             // but we're not sure we can catch all callers, so handle it but
             // assert false in debug mode.
             SkASSERT(false);
-            rec->fStorageSize = 0;
             rec->fHeapStorage = sk_malloc_throw(storageRequired);
             rec->fObj = static_cast<void*>(rec->fHeapStorage);
         } else {
             // There is space in fStorage.
-            rec->fStorageSize = storageRequired;
             rec->fHeapStorage = NULL;
             SkASSERT(SkIsAlign4(fStorageUsed));
             rec->fObj = static_cast<void*>(fStorage + (fStorageUsed / 4));
@@ -133,26 +131,11 @@
         return rec->fObj;
     }
 
-    /*
-     *  Free the memory reserved last without calling the destructor.
-     *  Can be used in a nested way, i.e. after reserving A and B, calling
-     *  freeLast once will free B and calling it again will free A.
-     */
-    void freeLast() {
-        SkASSERT(fNumObjects > 0);
-        Rec* rec = &fRecs[fNumObjects - 1];
-        sk_free(rec->fHeapStorage);
-        fStorageUsed -= rec->fStorageSize;
-
-        fNumObjects--;
-    }
-
 private:
     struct Rec {
-        size_t fStorageSize;  // 0 if allocated on heap
-        void*  fObj;
-        void*  fHeapStorage;
-        void   (*fKillProc)(void*);
+        void* fObj;
+        void* fHeapStorage;
+        void  (*fKillProc)(void*);
     };
 
     // Number of bytes used so far.
diff --git a/src/effects/SkPerlinNoiseShader.cpp b/src/effects/SkPerlinNoiseShader.cpp
index 5adb582..ed63faf 100644
--- a/src/effects/SkPerlinNoiseShader.cpp
+++ b/src/effects/SkPerlinNoiseShader.cpp
@@ -278,6 +278,7 @@
   , fStitchTiles(!fTileSize.isEmpty())
 {
     SkASSERT(numOctaves >= 0 && numOctaves < 256);
+    fMatrix.reset();
     fPaintingData = SkNEW_ARGS(PaintingData, (fTileSize, fSeed, fBaseFrequencyX, fBaseFrequencyY));
 }
 
@@ -292,6 +293,7 @@
     fStitchTiles    = buffer.readBool();
     fTileSize.fWidth  = buffer.readInt();
     fTileSize.fHeight = buffer.readInt();
+    fMatrix.reset();
     fPaintingData = SkNEW_ARGS(PaintingData, (fTileSize, fSeed, fBaseFrequencyX, fBaseFrequencyY));
     buffer.validate(perlin_noise_type_is_valid(fType) &&
                     (fNumOctaves >= 0) && (fNumOctaves <= 255) &&
@@ -315,9 +317,9 @@
     buffer.writeInt(fTileSize.fHeight);
 }
 
-SkScalar SkPerlinNoiseShader::PerlinNoiseShaderContext::noise2D(
-        int channel, const PaintingData& paintingData,
-        const StitchData& stitchData, const SkPoint& noiseVector) const {
+SkScalar SkPerlinNoiseShader::noise2D(int channel, const PaintingData& paintingData,
+                                      const StitchData& stitchData,
+                                      const SkPoint& noiseVector) const {
     struct Noise {
         int noisePositionIntegerValue;
         SkScalar noisePositionFractionValue;
@@ -331,9 +333,8 @@
     Noise noiseX(noiseVector.x());
     Noise noiseY(noiseVector.y());
     SkScalar u, v;
-    const SkPerlinNoiseShader& perlinNoiseShader = static_cast<const SkPerlinNoiseShader&>(fShader);
     // If stitching, adjust lattice points accordingly.
-    if (perlinNoiseShader.fStitchTiles) {
+    if (fStitchTiles) {
         noiseX.noisePositionIntegerValue =
             checkNoise(noiseX.noisePositionIntegerValue, stitchData.fWrapX, stitchData.fWidth);
         noiseY.noisePositionIntegerValue =
@@ -364,11 +365,11 @@
     return SkScalarInterp(a, b, sy);
 }
 
-SkScalar SkPerlinNoiseShader::PerlinNoiseShaderContext::calculateTurbulenceValueForPoint(
-        int channel, const PaintingData& paintingData,
-        StitchData& stitchData, const SkPoint& point) const {
-    const SkPerlinNoiseShader& perlinNoiseShader = static_cast<const SkPerlinNoiseShader&>(fShader);
-    if (perlinNoiseShader.fStitchTiles) {
+SkScalar SkPerlinNoiseShader::calculateTurbulenceValueForPoint(int channel,
+                                                               const PaintingData& paintingData,
+                                                               StitchData& stitchData,
+                                                               const SkPoint& point) const {
+    if (fStitchTiles) {
         // Set up TurbulenceInitial stitch values.
         stitchData = paintingData.fStitchDataInit;
     }
@@ -376,14 +377,14 @@
     SkPoint noiseVector(SkPoint::Make(SkScalarMul(point.x(), paintingData.fBaseFrequency.fX),
                                       SkScalarMul(point.y(), paintingData.fBaseFrequency.fY)));
     SkScalar ratio = SK_Scalar1;
-    for (int octave = 0; octave < perlinNoiseShader.fNumOctaves; ++octave) {
+    for (int octave = 0; octave < fNumOctaves; ++octave) {
         SkScalar noise = noise2D(channel, paintingData, stitchData, noiseVector);
         turbulenceFunctionResult += SkScalarDiv(
-            (perlinNoiseShader.fType == kFractalNoise_Type) ? noise : SkScalarAbs(noise), ratio);
+            (fType == kFractalNoise_Type) ? noise : SkScalarAbs(noise), ratio);
         noiseVector.fX *= 2;
         noiseVector.fY *= 2;
         ratio *= 2;
-        if (perlinNoiseShader.fStitchTiles) {
+        if (fStitchTiles) {
             // Update stitch values
             stitchData.fWidth  *= 2;
             stitchData.fWrapX   = stitchData.fWidth + kPerlinNoise;
@@ -394,7 +395,7 @@
 
     // The value of turbulenceFunctionResult comes from ((turbulenceFunctionResult) + 1) / 2
     // by fractalNoise and (turbulenceFunctionResult) by turbulence.
-    if (perlinNoiseShader.fType == kFractalNoise_Type) {
+    if (fType == kFractalNoise_Type) {
         turbulenceFunctionResult =
             SkScalarMul(turbulenceFunctionResult, SK_ScalarHalf) + SK_ScalarHalf;
     }
@@ -408,9 +409,7 @@
     return SkScalarPin(turbulenceFunctionResult, 0, SK_Scalar1);
 }
 
-SkPMColor SkPerlinNoiseShader::PerlinNoiseShaderContext::shade(
-        const SkPoint& point, StitchData& stitchData) const {
-    const SkPerlinNoiseShader& perlinNoiseShader = static_cast<const SkPerlinNoiseShader&>(fShader);
+SkPMColor SkPerlinNoiseShader::shade(const SkPoint& point, StitchData& stitchData) const {
     SkPoint newPoint;
     fMatrix.mapPoints(&newPoint, &point, 1);
     newPoint.fX = SkScalarRoundToScalar(newPoint.fX);
@@ -419,32 +418,15 @@
     U8CPU rgba[4];
     for (int channel = 3; channel >= 0; --channel) {
         rgba[channel] = SkScalarFloorToInt(255 *
-            calculateTurbulenceValueForPoint(channel, *perlinNoiseShader.fPaintingData,
-                                             stitchData, newPoint));
+            calculateTurbulenceValueForPoint(channel, *fPaintingData, stitchData, newPoint));
     }
     return SkPreMultiplyARGB(rgba[3], rgba[0], rgba[1], rgba[2]);
 }
 
-SkShader::Context* SkPerlinNoiseShader::createContext(const SkBitmap& device, const SkPaint& paint,
-                                                      const SkMatrix& matrix, void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
-    }
-
-    return SkNEW_PLACEMENT_ARGS(storage, PerlinNoiseShaderContext, (*this, device, paint, matrix));
-}
-
-size_t SkPerlinNoiseShader::contextSize() const {
-    return sizeof(PerlinNoiseShaderContext);
-}
-
-SkPerlinNoiseShader::PerlinNoiseShaderContext::PerlinNoiseShaderContext(
-        const SkPerlinNoiseShader& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix)
-{
+bool SkPerlinNoiseShader::setContext(const SkBitmap& device, const SkPaint& paint,
+                                     const SkMatrix& matrix) {
     SkMatrix newMatrix = matrix;
-    newMatrix.postConcat(shader.getLocalMatrix());
+    newMatrix.postConcat(getLocalMatrix());
     SkMatrix invMatrix;
     if (!newMatrix.invert(&invMatrix)) {
         invMatrix.reset();
@@ -455,10 +437,10 @@
     newMatrix.postConcat(invMatrix);
     newMatrix.postConcat(invMatrix);
     fMatrix = newMatrix;
+    return INHERITED::setContext(device, paint, matrix);
 }
 
-void SkPerlinNoiseShader::PerlinNoiseShaderContext::shadeSpan(
-        int x, int y, SkPMColor result[], int count) {
+void SkPerlinNoiseShader::shadeSpan(int x, int y, SkPMColor result[], int count) {
     SkPoint point = SkPoint::Make(SkIntToScalar(x), SkIntToScalar(y));
     StitchData stitchData;
     for (int i = 0; i < count; ++i) {
@@ -467,8 +449,7 @@
     }
 }
 
-void SkPerlinNoiseShader::PerlinNoiseShaderContext::shadeSpan16(
-        int x, int y, uint16_t result[], int count) {
+void SkPerlinNoiseShader::shadeSpan16(int x, int y, uint16_t result[], int count) {
     SkPoint point = SkPoint::Make(SkIntToScalar(x), SkIntToScalar(y));
     StitchData stitchData;
     DITHER_565_SCAN(y);
diff --git a/src/effects/SkTransparentShader.cpp b/src/effects/SkTransparentShader.cpp
index 0997e62..bd8b99a 100644
--- a/src/effects/SkTransparentShader.cpp
+++ b/src/effects/SkTransparentShader.cpp
@@ -11,40 +11,26 @@
 #include "SkColorPriv.h"
 #include "SkString.h"
 
-SkShader::Context* SkTransparentShader::createContext(const SkBitmap& device,
-                                                      const SkPaint& paint,
-                                                      const SkMatrix& matrix,
-                                                      void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
-    }
+bool SkTransparentShader::setContext(const SkBitmap& device,
+                                     const SkPaint& paint,
+                                     const SkMatrix& matrix) {
+    fDevice = &device;  // address only, no copy; NOTE(review): device presumably must outlive shadeSpan*() - confirm
+    fAlpha = paint.getAlpha();  // read later by getFlags() and shadeSpan()
 
-    return SkNEW_PLACEMENT_ARGS(storage, TransparentShaderContext, (*this, device, paint, matrix));
+    return this->INHERITED::setContext(device, paint, matrix);  // base-class result is returned unchanged
 }
 
-size_t SkTransparentShader::contextSize() const {
-    return sizeof(TransparentShaderContext);
-}
-
-SkTransparentShader::TransparentShaderContext::TransparentShaderContext(
-        const SkTransparentShader& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix)
-    , fDevice(&device) {}
-
-SkTransparentShader::TransparentShaderContext::~TransparentShaderContext() {}
-
-uint32_t SkTransparentShader::TransparentShaderContext::getFlags() const {
+uint32_t SkTransparentShader::getFlags() {
     uint32_t flags = this->INHERITED::getFlags();
 
     switch (fDevice->colorType()) {
         case kRGB_565_SkColorType:
             flags |= kHasSpan16_Flag;
-            if (this->getPaintAlpha() == 255)
+            if (fAlpha == 255)
                 flags |= kOpaqueAlpha_Flag;
             break;
         case kN32_SkColorType:
-            if (this->getPaintAlpha() == 255 && fDevice->isOpaque())
+            if (fAlpha == 255 && fDevice->isOpaque())
                 flags |= kOpaqueAlpha_Flag;
             break;
         default:
@@ -53,9 +39,8 @@
     return flags;
 }
 
-void SkTransparentShader::TransparentShaderContext::shadeSpan(int x, int y, SkPMColor span[],
-                                                              int count) {
-    unsigned scale = SkAlpha255To256(this->getPaintAlpha());
+void SkTransparentShader::shadeSpan(int x, int y, SkPMColor span[], int count) {
+    unsigned scale = SkAlpha255To256(fAlpha);
 
     switch (fDevice->colorType()) {
         case kN32_SkColorType:
@@ -78,7 +63,7 @@
                     span[i] = SkPixel16ToPixel32(src[i]);
                 }
             } else {
-                unsigned alpha = this->getPaintAlpha();
+                unsigned alpha = fAlpha;
                 for (int i = count - 1; i >= 0; --i) {
                     uint16_t c = src[i];
                     unsigned r = SkPacked16ToR32(c);
@@ -112,8 +97,7 @@
     }
 }
 
-void SkTransparentShader::TransparentShaderContext::shadeSpan16(int x, int y, uint16_t span[],
-                                                                int count) {
+void SkTransparentShader::shadeSpan16(int x, int y, uint16_t span[], int count) {
     SkASSERT(fDevice->colorType() == kRGB_565_SkColorType);
 
     uint16_t* src = fDevice->getAddr16(x, y);
diff --git a/src/effects/gradients/SkGradientShader.cpp b/src/effects/gradients/SkGradientShader.cpp
index 46e0c95..2e92076 100644
--- a/src/effects/gradients/SkGradientShader.cpp
+++ b/src/effects/gradients/SkGradientShader.cpp
@@ -15,6 +15,8 @@
 SkGradientShaderBase::SkGradientShaderBase(const Descriptor& desc) {
     SkASSERT(desc.fCount > 1);
 
+    fCacheAlpha = 256;  // init to a value that paint.getAlpha() can't return
+
     fMapper = desc.fMapper;
     SkSafeRef(fMapper);
     fGradFlags = SkToU8(desc.fGradFlags);
@@ -24,6 +26,10 @@
     fTileMode = desc.fTileMode;
     fTileProc = gTileProcs[desc.fTileMode];
 
+    fCache16 = fCache16Storage = NULL;
+    fCache32 = NULL;
+    fCache32PixelRef = NULL;
+
     /*  Note: we let the caller skip the first and/or last position.
         i.e. pos[0] = 0.3, pos[1] = 0.7
         In these cases, we insert dummy entries to ensure that the final data
@@ -140,8 +146,14 @@
 }
 
 SkGradientShaderBase::SkGradientShaderBase(SkReadBuffer& buffer) : INHERITED(buffer) {
+    fCacheAlpha = 256;
+
     fMapper = buffer.readUnitMapper();
 
+    fCache16 = fCache16Storage = NULL;
+    fCache32 = NULL;
+    fCache32PixelRef = NULL;
+
     int colorCount = fColorCount = buffer.getArrayCount();
     if (colorCount > kColorStorageCount) {
         size_t allocSize = (sizeof(SkColor) + sizeof(SkPMColor) + sizeof(Rec)) * colorCount;
@@ -176,6 +188,10 @@
 }
 
 SkGradientShaderBase::~SkGradientShaderBase() {
+    if (fCache16Storage) {
+        sk_free(fCache16Storage);
+    }
+    SkSafeUnref(fCache32PixelRef);
     if (fOrigColors != fStorage) {
         sk_free(fOrigColors);
     }
@@ -183,6 +199,7 @@
 }
 
 void SkGradientShaderBase::initCommon() {
+    fFlags = 0;
     unsigned colorAlpha = 0xFF;
     for (int i = 0; i < fColorCount; i++) {
         colorAlpha &= SkColorGetA(fOrigColors[i]);
@@ -250,50 +267,49 @@
     return fColorsAreOpaque;
 }
 
-SkGradientShaderBase::GradientShaderBaseContext::GradientShaderBaseContext(
-        const SkGradientShaderBase& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix)
-    , fCache(shader.refCache(getPaintAlpha()))
-{
+bool SkGradientShaderBase::setContext(const SkBitmap& device,
+                                 const SkPaint& paint,
+                                 const SkMatrix& matrix) {
+    if (!this->INHERITED::setContext(device, paint, matrix)) {
+        return false;  // base-class setup failed; none of the gradient state below has been touched
+    }
+
     const SkMatrix& inverse = this->getTotalInverse();
 
-    fDstToIndex.setConcat(shader.fPtsToUnit, inverse);
-
+    fDstToIndex.setConcat(fPtsToUnit, inverse);  // fPtsToUnit is supplied by the subclass
     fDstToIndexProc = fDstToIndex.getMapXYProc();
-    fDstToIndexClass = (uint8_t)SkShader::Context::ComputeMatrixClass(fDstToIndex);
+    fDstToIndexClass = (uint8_t)SkShader::ComputeMatrixClass(fDstToIndex);  // stored as a byte; shadeSpan() branches on it
 
     // now convert our colors in to PMColors
     unsigned paintAlpha = this->getPaintAlpha();
 
     fFlags = this->INHERITED::getFlags();
-    if (shader.fColorsAreOpaque && paintAlpha == 0xFF) {
+    if (fColorsAreOpaque && paintAlpha == 0xFF) {
         fFlags |= kOpaqueAlpha_Flag;
     }
     // we can do span16 as long as our individual colors are opaque,
     // regardless of the paint's alpha
-    if (shader.fColorsAreOpaque) {
+    if (fColorsAreOpaque) {
         fFlags |= kHasSpan16_Flag;
     }
+
+    this->setCacheAlpha(paintAlpha);  // clears the cached color tables if the alpha differs from last time
+    return true;
 }
 
-SkGradientShaderBase::GradientShaderCache::GradientShaderCache(
-        U8CPU alpha, const SkGradientShaderBase& shader)
-    : fCacheAlpha(alpha)
-    , fShader(shader)
-    , fCache16Inited(false)
-    , fCache32Inited(false)
-{
-    // Only initialize the cache in getCache16/32.
-    fCache16 = NULL;
-    fCache32 = NULL;
-    fCache16Storage = NULL;
-    fCache32PixelRef = NULL;
-}
-
-SkGradientShaderBase::GradientShaderCache::~GradientShaderCache() {
-    sk_free(fCache16Storage);
-    SkSafeUnref(fCache32PixelRef);
+void SkGradientShaderBase::setCacheAlpha(U8CPU alpha) const {
+    // Records the alpha the color tables are built with. When it differs from the value recorded
+    // last time, both working pointers are cleared so getCache16()/getCache32() rebuild on demand.
+    // On the first call fCacheAlpha still holds the constructors' sentinel (256); the cache ptrs are already NULL then.
+    if (fCacheAlpha != alpha) {
+        fCache16 = NULL;            // invalidate only; fCache16Storage is not freed here
+        fCache32 = NULL;            // invalidate only; fCache32PixelRef is not released here
+        fCacheAlpha = alpha;        // record the new alpha
+        // getCache32() will overwrite these pixels on the next rebuild, so flag the pixel ref as changed
+        if (fCache32PixelRef) {
+            fCache32PixelRef->notifyPixelsChanged();
+        }
+    }
 }
 
 #define Fixed_To_Dot8(x)        (((x) + 0x80) >> 8)
@@ -302,8 +318,8 @@
     build a 16bit table as long as the original colors are opaque, even if the
     paint specifies a non-opaque alpha.
 */
-void SkGradientShaderBase::GradientShaderCache::Build16bitCache(
-        uint16_t cache[], SkColor c0, SkColor c1, int count) {
+void SkGradientShaderBase::Build16bitCache(uint16_t cache[], SkColor c0, SkColor c1,
+                                      int count) {
     SkASSERT(count > 1);
     SkASSERT(SkColorGetA(c0) == 0xFF);
     SkASSERT(SkColorGetA(c1) == 0xFF);
@@ -351,9 +367,8 @@
  */
 typedef uint32_t SkUFixed;
 
-void SkGradientShaderBase::GradientShaderCache::Build32bitCache(
-        SkPMColor cache[], SkColor c0, SkColor c1,
-        int count, U8CPU paintAlpha, uint32_t gradFlags) {
+void SkGradientShaderBase::Build32bitCache(SkPMColor cache[], SkColor c0, SkColor c1,
+                                      int count, U8CPU paintAlpha, uint32_t gradFlags) {
     SkASSERT(count > 1);
 
     // need to apply paintAlpha to our two endpoints
@@ -496,123 +511,99 @@
     return 0;
 }
 
-const uint16_t* SkGradientShaderBase::GradientShaderCache::getCache16() {
-    SkOnce(&fCache16Inited, &fCache16Mutex, SkGradientShaderBase::GradientShaderCache::initCache16,
-           this);
-    SkASSERT(fCache16);
+const uint16_t* SkGradientShaderBase::getCache16() const {  // builds the 16-bit color table on first use
+    if (fCache16 == NULL) {
+        // two tables back to back: kCache16Count entries, then kCache16Count dither entries
+        const int entryCount = kCache16Count * 2;
+        const size_t allocSize = sizeof(uint16_t) * entryCount;
+
+        if (fCache16Storage == NULL) { // allocate once; rebuilds after setCacheAlpha() reuse the buffer
+            fCache16Storage = (uint16_t*)sk_malloc_throw(allocSize);
+        }
+        fCache16 = fCache16Storage;
+        if (fColorCount == 2) {
+            Build16bitCache(fCache16, fOrigColors[0], fOrigColors[1],
+                            kCache16Count);
+        } else {
+            Rec* rec = fRecs;
+            int prevIndex = 0;
+            for (int i = 1; i < fColorCount; i++) {
+                int nextIndex = SkFixedToFFFF(rec[i].fPos) >> kCache16Shift;
+                SkASSERT(nextIndex < kCache16Count);
+
+                if (nextIndex > prevIndex)  // stops that land on the same cache slot produce no segment
+                    Build16bitCache(fCache16 + prevIndex, fOrigColors[i-1], fOrigColors[i], nextIndex - prevIndex + 1);
+                prevIndex = nextIndex;
+            }
+        }
+
+        if (fMapper) {
+            fCache16Storage = (uint16_t*)sk_malloc_throw(allocSize);
+            uint16_t* linear = fCache16;         // table built above (the previous allocation)
+            uint16_t* mapped = fCache16Storage;  // new allocation that receives the remapped table
+            SkUnitMapper* map = fMapper;
+            for (int i = 0; i < kCache16Count; i++) {
+                int index = map->mapUnit16(bitsTo16(i, kCache16Bits)) >> kCache16Shift;
+                mapped[i] = linear[index];
+                mapped[i + kCache16Count] = linear[index + kCache16Count];
+            }
+            sk_free(fCache16);  // release the linear table; the mapped buffer becomes the cache
+            fCache16 = fCache16Storage;
+        }
+    }
     return fCache16;
 }
 
-void SkGradientShaderBase::GradientShaderCache::initCache16(GradientShaderCache* cache) {
-    // double the count for dither entries
-    const int entryCount = kCache16Count * 2;
-    const size_t allocSize = sizeof(uint16_t) * entryCount;
+const SkPMColor* SkGradientShaderBase::getCache32() const {  // builds the 32-bit premul color table on first use
+    if (fCache32 == NULL) {
+        SkImageInfo info;
+        info.fWidth = kCache32Count;
+        info.fHeight = 4;   // 4 dither rows of kCache32Count entries each (see the *0..*3 offsets below)
+        info.fAlphaType = kPremul_SkAlphaType;
+        info.fColorType = kN32_SkColorType;
 
-    SkASSERT(NULL == cache->fCache16Storage);
-    cache->fCache16Storage = (uint16_t*)sk_malloc_throw(allocSize);
-    cache->fCache16 = cache->fCache16Storage;
-    if (cache->fShader.fColorCount == 2) {
-        Build16bitCache(cache->fCache16, cache->fShader.fOrigColors[0],
-                        cache->fShader.fOrigColors[1], kCache16Count);
-    } else {
-        Rec* rec = cache->fShader.fRecs;
-        int prevIndex = 0;
-        for (int i = 1; i < cache->fShader.fColorCount; i++) {
-            int nextIndex = SkFixedToFFFF(rec[i].fPos) >> kCache16Shift;
-            SkASSERT(nextIndex < kCache16Count);
+        if (NULL == fCache32PixelRef) {  // allocate once; rebuilds after setCacheAlpha() write into the same pixels
+            fCache32PixelRef = SkMallocPixelRef::NewAllocate(info, 0, NULL);
+        }
+        fCache32 = (SkPMColor*)fCache32PixelRef->getAddr();
+        if (fColorCount == 2) {
+            Build32bitCache(fCache32, fOrigColors[0], fOrigColors[1],
+                            kCache32Count, fCacheAlpha, fGradFlags);
+        } else {
+            Rec* rec = fRecs;
+            int prevIndex = 0;
+            for (int i = 1; i < fColorCount; i++) {
+                int nextIndex = SkFixedToFFFF(rec[i].fPos) >> kCache32Shift;
+                SkASSERT(nextIndex < kCache32Count);
 
-            if (nextIndex > prevIndex)
-                Build16bitCache(cache->fCache16 + prevIndex, cache->fShader.fOrigColors[i-1],
-                                cache->fShader.fOrigColors[i], nextIndex - prevIndex + 1);
-            prevIndex = nextIndex;
+                if (nextIndex > prevIndex)  // stops that land on the same cache slot produce no segment
+                    Build32bitCache(fCache32 + prevIndex, fOrigColors[i-1],
+                                    fOrigColors[i], nextIndex - prevIndex + 1,
+                                    fCacheAlpha, fGradFlags);
+                prevIndex = nextIndex;
+            }
+        }
+
+        if (fMapper) {
+            SkMallocPixelRef* newPR = SkMallocPixelRef::NewAllocate(info, 0, NULL);
+            SkPMColor* linear = fCache32;           // rows built above, in the current pixel ref
+            SkPMColor* mapped = (SkPMColor*)newPR->getAddr();    // new pixel ref that receives the remapped rows
+            SkUnitMapper* map = fMapper;
+            for (int i = 0; i < kCache32Count; i++) {
+                int index = map->mapUnit16((i << 8) | i) >> 8;
+                mapped[i + kCache32Count*0] = linear[index + kCache32Count*0];
+                mapped[i + kCache32Count*1] = linear[index + kCache32Count*1];
+                mapped[i + kCache32Count*2] = linear[index + kCache32Count*2];
+                mapped[i + kCache32Count*3] = linear[index + kCache32Count*3];
+            }
+            fCache32PixelRef->unref();  // drop our ref on the linear rows; the mapped pixel ref replaces it
+            fCache32PixelRef = newPR;
+            fCache32 = (SkPMColor*)newPR->getAddr();
         }
     }
-
-    if (cache->fShader.fMapper) {
-        cache->fCache16Storage = (uint16_t*)sk_malloc_throw(allocSize);
-        uint16_t* linear = cache->fCache16;         // just computed linear data
-        uint16_t* mapped = cache->fCache16Storage;  // storage for mapped data
-        SkUnitMapper* map = cache->fShader.fMapper;
-        for (int i = 0; i < kCache16Count; i++) {
-            int index = map->mapUnit16(bitsTo16(i, kCache16Bits)) >> kCache16Shift;
-            mapped[i] = linear[index];
-            mapped[i + kCache16Count] = linear[index + kCache16Count];
-        }
-        sk_free(cache->fCache16);
-        cache->fCache16 = cache->fCache16Storage;
-    }
-}
-
-const SkPMColor* SkGradientShaderBase::GradientShaderCache::getCache32() {
-    SkOnce(&fCache32Inited, &fCache32Mutex, SkGradientShaderBase::GradientShaderCache::initCache32,
-           this);
-    SkASSERT(fCache32);
     return fCache32;
 }
 
-void SkGradientShaderBase::GradientShaderCache::initCache32(GradientShaderCache* cache) {
-    SkImageInfo info;
-    info.fWidth = kCache32Count;
-    info.fHeight = 4;   // for our 4 dither rows
-    info.fAlphaType = kPremul_SkAlphaType;
-    info.fColorType = kN32_SkColorType;
-
-    SkASSERT(NULL == cache->fCache32PixelRef);
-    cache->fCache32PixelRef = SkMallocPixelRef::NewAllocate(info, 0, NULL);
-    cache->fCache32 = (SkPMColor*)cache->fCache32PixelRef->getAddr();
-    if (cache->fShader.fColorCount == 2) {
-        Build32bitCache(cache->fCache32, cache->fShader.fOrigColors[0],
-                        cache->fShader.fOrigColors[1], kCache32Count, cache->fCacheAlpha,
-                        cache->fShader.fGradFlags);
-    } else {
-        Rec* rec = cache->fShader.fRecs;
-        int prevIndex = 0;
-        for (int i = 1; i < cache->fShader.fColorCount; i++) {
-            int nextIndex = SkFixedToFFFF(rec[i].fPos) >> kCache32Shift;
-            SkASSERT(nextIndex < kCache32Count);
-
-            if (nextIndex > prevIndex)
-                Build32bitCache(cache->fCache32 + prevIndex, cache->fShader.fOrigColors[i-1],
-                                cache->fShader.fOrigColors[i], nextIndex - prevIndex + 1,
-                                cache->fCacheAlpha, cache->fShader.fGradFlags);
-            prevIndex = nextIndex;
-        }
-    }
-
-    if (cache->fShader.fMapper) {
-        SkMallocPixelRef* newPR = SkMallocPixelRef::NewAllocate(info, 0, NULL);
-        SkPMColor* linear = cache->fCache32;           // just computed linear data
-        SkPMColor* mapped = (SkPMColor*)newPR->getAddr();    // storage for mapped data
-        SkUnitMapper* map = cache->fShader.fMapper;
-        for (int i = 0; i < kCache32Count; i++) {
-            int index = map->mapUnit16((i << 8) | i) >> 8;
-            mapped[i + kCache32Count*0] = linear[index + kCache32Count*0];
-            mapped[i + kCache32Count*1] = linear[index + kCache32Count*1];
-            mapped[i + kCache32Count*2] = linear[index + kCache32Count*2];
-            mapped[i + kCache32Count*3] = linear[index + kCache32Count*3];
-        }
-        cache->fCache32PixelRef->unref();
-        cache->fCache32PixelRef = newPR;
-        cache->fCache32 = (SkPMColor*)newPR->getAddr();
-    }
-}
-
-/*
- *  The gradient holds a cache for the most recent value of alpha. Successive
- *  callers with the same alpha value will share the same cache.
- */
-SkGradientShaderBase::GradientShaderCache* SkGradientShaderBase::refCache(U8CPU alpha) const {
-    SkAutoMutexAcquire ama(fCacheMutex);
-    if (!fCache || fCache->getAlpha() != alpha) {
-        fCache.reset(SkNEW_ARGS(GradientShaderCache, (alpha, *this)));
-    }
-    // Increment the ref counter inside the mutex to ensure the returned pointer is still valid.
-    // Otherwise, the pointer may have been overwritten on a different thread before the object's
-    // ref count was incremented.
-    fCache.get()->ref();
-    return fCache;
-}
-
 /*
  *  Because our caller might rebuild the same (logically the same) gradient
  *  over and over, we'd like to return exactly the same "bitmap" if possible,
@@ -624,14 +615,14 @@
 void SkGradientShaderBase::getGradientTableBitmap(SkBitmap* bitmap) const {
     // our caller assumes no external alpha, so we ensure that our cache is
     // built with 0xFF
-    SkAutoTUnref<GradientShaderCache> cache(this->refCache(0xFF));
+    this->setCacheAlpha(0xFF);
 
     // don't have a way to put the mapper into our cache-key yet
     if (fMapper) {
-        // force our cache32pixelref to be built
-        (void)cache->getCache32();
+        // force our cache32pixelref to be built
+        (void)this->getCache32();
         bitmap->setConfig(SkImageInfo::MakeN32Premul(kCache32Count, 1));
-        bitmap->setPixelRef(cache->getCache32PixelRef());
+        bitmap->setPixelRef(fCache32PixelRef);
         return;
     }
 
@@ -670,9 +661,9 @@
 
     if (!gCache->find(storage.get(), size, bitmap)) {
         // force our cahce32pixelref to be built
-        (void)cache->getCache32();
+        (void)this->getCache32();
         bitmap->setConfig(SkImageInfo::MakeN32Premul(kCache32Count, 1));
-        bitmap->setPixelRef(cache->getCache32PixelRef());
+        bitmap->setPixelRef(fCache32PixelRef);
 
         gCache->add(storage.get(), size, *bitmap);
     }
diff --git a/src/effects/gradients/SkGradientShaderPriv.h b/src/effects/gradients/SkGradientShaderPriv.h
index 5dec665..02bb50b 100644
--- a/src/effects/gradients/SkGradientShaderPriv.h
+++ b/src/effects/gradients/SkGradientShaderPriv.h
@@ -19,7 +19,6 @@
 #include "SkTemplates.h"
 #include "SkBitmapCache.h"
 #include "SkShader.h"
-#include "SkOnce.h"
 
 static inline void sk_memset32_dither(uint32_t dst[], uint32_t v0, uint32_t v1,
                                int count) {
@@ -102,64 +101,8 @@
     SkGradientShaderBase(const Descriptor& desc);
     virtual ~SkGradientShaderBase();
 
-    // The cache is initialized on-demand when getCache16/32 is called.
-    class GradientShaderCache : public SkRefCnt {
-    public:
-        GradientShaderCache(U8CPU alpha, const SkGradientShaderBase& shader);
-        ~GradientShaderCache();
-
-        const uint16_t*     getCache16();
-        const SkPMColor*    getCache32();
-
-        SkMallocPixelRef* getCache32PixelRef() const { return fCache32PixelRef; }
-
-        unsigned getAlpha() const { return fCacheAlpha; }
-
-    private:
-        // Working pointers. If either is NULL, we need to recompute the corresponding cache values.
-        uint16_t*   fCache16;
-        SkPMColor*  fCache32;
-
-        uint16_t*         fCache16Storage;    // Storage for fCache16, allocated on demand.
-        SkMallocPixelRef* fCache32PixelRef;
-        const unsigned    fCacheAlpha;        // The alpha value we used when we computed the cache.
-                                              // Larger than 8bits so we can store uninitialized
-                                              // value.
-
-        const SkGradientShaderBase& fShader;
-
-        // Make sure we only initialize the caches once.
-        bool    fCache16Inited, fCache32Inited;
-        SkMutex fCache16Mutex, fCache32Mutex;
-
-        static void initCache16(GradientShaderCache* cache);
-        static void initCache32(GradientShaderCache* cache);
-
-        static void Build16bitCache(uint16_t[], SkColor c0, SkColor c1, int count);
-        static void Build32bitCache(SkPMColor[], SkColor c0, SkColor c1, int count,
-                                    U8CPU alpha, uint32_t gradFlags);
-    };
-
-    class GradientShaderBaseContext : public SkShader::Context {
-    public:
-        GradientShaderBaseContext(const SkGradientShaderBase& shader, const SkBitmap& device,
-                                  const SkPaint& paint, const SkMatrix& matrix);
-        ~GradientShaderBaseContext() {}
-
-        virtual uint32_t getFlags() const SK_OVERRIDE { return fFlags; }
-
-    protected:
-        SkMatrix    fDstToIndex;
-        SkMatrix::MapXYProc fDstToIndexProc;
-        uint8_t     fDstToIndexClass;
-        uint8_t     fFlags;
-
-        SkAutoTUnref<GradientShaderCache> fCache;
-
-    private:
-        typedef SkShader::Context INHERITED;
-    };
-
+    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
+    virtual uint32_t getFlags() SK_OVERRIDE { return fFlags; }
     virtual bool isOpaque() const SK_OVERRIDE;
 
     void getGradientTableBitmap(SkBitmap*) const;
@@ -205,9 +148,13 @@
 
     SkUnitMapper* fMapper;
     SkMatrix    fPtsToUnit;     // set by subclass
+    SkMatrix    fDstToIndex;
+    SkMatrix::MapXYProc fDstToIndexProc;
     TileMode    fTileMode;
     TileProc    fTileProc;
     int         fColorCount;
+    uint8_t     fDstToIndexClass;
+    uint8_t     fFlags;
     uint8_t     fGradFlags;
 
     struct Rec {
@@ -216,6 +163,9 @@
     };
     Rec*        fRecs;
 
+    const uint16_t*     getCache16() const;
+    const SkPMColor*    getCache32() const;
+
     void commonAsAGradient(GradientInfo*, bool flipGrad = false) const;
 
     /*
@@ -241,13 +191,20 @@
         kStorageSize = kColorStorageCount * (sizeof(SkColor) + sizeof(Rec))
     };
     SkColor     fStorage[(kStorageSize + 3) >> 2];
-    SkColor*    fOrigColors; // original colors, before modulation by paint in context.
+    SkColor*    fOrigColors; // original colors, before modulation by paint in setContext
     bool        fColorsAreOpaque;
 
-    GradientShaderCache* refCache(U8CPU alpha) const;
-    mutable SkMutex                           fCacheMutex;
-    mutable SkAutoTUnref<GradientShaderCache> fCache;
+    mutable uint16_t*   fCache16;   // working ptr. If this is NULL, we need to recompute the cache values
+    mutable SkPMColor*  fCache32;   // working ptr. If this is NULL, we need to recompute the cache values
 
+    mutable uint16_t*   fCache16Storage;    // storage for fCache16, allocated on demand
+    mutable SkMallocPixelRef* fCache32PixelRef;
+    mutable unsigned    fCacheAlpha;        // the alpha value we used when we computed the cache. larger than 8bits so we can store uninitialized value
+
+    static void Build16bitCache(uint16_t[], SkColor c0, SkColor c1, int count);
+    static void Build32bitCache(SkPMColor[], SkColor c0, SkColor c1, int count,
+                                U8CPU alpha, uint32_t gradFlags);
+    void setCacheAlpha(U8CPU alpha) const;
     void initCommon();
 
     typedef SkShader INHERITED;
diff --git a/src/effects/gradients/SkLinearGradient.cpp b/src/effects/gradients/SkLinearGradient.cpp
index e660d7c..b24a634 100644
--- a/src/effects/gradients/SkLinearGradient.cpp
+++ b/src/effects/gradients/SkLinearGradient.cpp
@@ -71,24 +71,12 @@
     buffer.writePoint(fEnd);
 }
 
-size_t SkLinearGradient::contextSize() const {
-    return sizeof(LinearGradientContext);
-}
-
-SkShader::Context* SkLinearGradient::createContext(const SkBitmap& device, const SkPaint& paint,
-                                                   const SkMatrix& matrix, void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
+bool SkLinearGradient::setContext(const SkBitmap& device, const SkPaint& paint,
+                                 const SkMatrix& matrix) {
+    if (!this->INHERITED::setContext(device, paint, matrix)) {
+        return false;
     }
 
-    return SkNEW_PLACEMENT_ARGS(storage, LinearGradientContext, (*this, device, paint, matrix));
-}
-
-SkLinearGradient::LinearGradientContext::LinearGradientContext(
-        const SkLinearGradient& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix)
-{
     unsigned mask = SkMatrix::kTranslate_Mask | SkMatrix::kScale_Mask;
     if ((fDstToIndex.getType() & ~mask) == 0) {
         // when we dither, we are (usually) not const-in-Y
@@ -99,6 +87,7 @@
             fFlags |= SkShader::kConstInY16_Flag;
         }
     }
+    return true;
 }
 
 #define NO_CHECK_ITER               \
@@ -207,16 +196,14 @@
 
 }
 
-void SkLinearGradient::LinearGradientContext::shadeSpan(int x, int y, SkPMColor* SK_RESTRICT dstC,
-                                                        int count) {
+void SkLinearGradient::shadeSpan(int x, int y, SkPMColor* SK_RESTRICT dstC,
+                                int count) {
     SkASSERT(count > 0);
 
-    const SkLinearGradient& linearGradient = static_cast<const SkLinearGradient&>(fShader);
-
     SkPoint             srcPt;
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = linearGradient.fTileProc;
-    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
+    TileProc            proc = fTileProc;
+    const SkPMColor* SK_RESTRICT cache = this->getCache32();
     int                 toggle = init_dither_toggle(x, y);
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -236,12 +223,12 @@
         LinearShadeProc shadeProc = shadeSpan_linear_repeat;
         if (0 == dx) {
             shadeProc = shadeSpan_linear_vertical_lerp;
-        } else if (SkShader::kClamp_TileMode == linearGradient.fTileMode) {
+        } else if (SkShader::kClamp_TileMode == fTileMode) {
             shadeProc = shadeSpan_linear_clamp;
-        } else if (SkShader::kMirror_TileMode == linearGradient.fTileMode) {
+        } else if (SkShader::kMirror_TileMode == fTileMode) {
             shadeProc = shadeSpan_linear_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == linearGradient.fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
         }
         (*shadeProc)(proc, dx, fx, dstC, cache, toggle, count);
     } else {
@@ -394,16 +381,14 @@
     return SkAbs32(x) < (SK_Fixed1 >> 12);
 }
 
-void SkLinearGradient::LinearGradientContext::shadeSpan16(int x, int y,
-                                                          uint16_t* SK_RESTRICT dstC, int count) {
+void SkLinearGradient::shadeSpan16(int x, int y,
+                                  uint16_t* SK_RESTRICT dstC, int count) {
     SkASSERT(count > 0);
 
-    const SkLinearGradient& linearGradient = static_cast<const SkLinearGradient&>(fShader);
-
     SkPoint             srcPt;
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = linearGradient.fTileProc;
-    const uint16_t* SK_RESTRICT cache = fCache->getCache16();
+    TileProc            proc = fTileProc;
+    const uint16_t* SK_RESTRICT cache = this->getCache16();
     int                 toggle = init_dither_toggle16(x, y);
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -423,12 +408,12 @@
         LinearShade16Proc shadeProc = shadeSpan16_linear_repeat;
         if (fixed_nearly_zero(dx)) {
             shadeProc = shadeSpan16_linear_vertical;
-        } else if (SkShader::kClamp_TileMode == linearGradient.fTileMode) {
+        } else if (SkShader::kClamp_TileMode == fTileMode) {
             shadeProc = shadeSpan16_linear_clamp;
-        } else if (SkShader::kMirror_TileMode == linearGradient.fTileMode) {
+        } else if (SkShader::kMirror_TileMode == fTileMode) {
             shadeProc = shadeSpan16_linear_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == linearGradient.fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
         }
         (*shadeProc)(proc, dx, fx, dstC, cache, toggle, count);
     } else {
diff --git a/src/effects/gradients/SkLinearGradient.h b/src/effects/gradients/SkLinearGradient.h
index 8d80667..013c449 100644
--- a/src/effects/gradients/SkLinearGradient.h
+++ b/src/effects/gradients/SkLinearGradient.h
@@ -15,23 +15,9 @@
 public:
     SkLinearGradient(const SkPoint pts[2], const Descriptor&);
 
-    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
-                                             void* storage) const SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
-
-    class LinearGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
-    public:
-        LinearGradientContext(const SkLinearGradient& shader, const SkBitmap& device,
-                              const SkPaint& paint, const SkMatrix& matrix);
-        ~LinearGradientContext() {}
-
-        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
-
-    private:
-        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
-    };
-
+    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
+    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
     virtual BitmapType asABitmap(SkBitmap*, SkMatrix*, TileMode*) const SK_OVERRIDE;
     virtual GradientType asAGradient(GradientInfo* info) const SK_OVERRIDE;
     virtual GrEffectRef* asNewEffect(GrContext* context, const SkPaint&) const SK_OVERRIDE;
diff --git a/src/effects/gradients/SkRadialGradient.cpp b/src/effects/gradients/SkRadialGradient.cpp
index bc2ea3b..1b9e725 100644
--- a/src/effects/gradients/SkRadialGradient.cpp
+++ b/src/effects/gradients/SkRadialGradient.cpp
@@ -157,36 +157,16 @@
     rad_to_unit_matrix(center, radius, &fPtsToUnit);
 }
 
-size_t SkRadialGradient::contextSize() const {
-    return sizeof(RadialGradientContext);
-}
-
-SkShader::Context* SkRadialGradient::createContext(const SkBitmap& device, const SkPaint& paint,
-                                                   const SkMatrix& matrix, void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
-    }
-
-    return SkNEW_PLACEMENT_ARGS(storage, RadialGradientContext, (*this, device, paint, matrix));
-}
-
-SkRadialGradient::RadialGradientContext::RadialGradientContext(
-        const SkRadialGradient& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix) {}
-
-void SkRadialGradient::RadialGradientContext::shadeSpan16(int x, int y, uint16_t* dstCParam,
-                                                          int count) {
+void SkRadialGradient::shadeSpan16(int x, int y, uint16_t* dstCParam,
+                         int count) {
     SkASSERT(count > 0);
 
-    const SkRadialGradient& radialGradient = static_cast<const SkRadialGradient&>(fShader);
-
     uint16_t* SK_RESTRICT dstC = dstCParam;
 
     SkPoint             srcPt;
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = radialGradient.fTileProc;
-    const uint16_t* SK_RESTRICT cache = fCache->getCache16();
+    TileProc            proc = fTileProc;
+    const uint16_t* SK_RESTRICT cache = this->getCache16();
     int                 toggle = init_dither_toggle16(x, y);
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -207,12 +187,12 @@
         }
 
         RadialShade16Proc shadeProc = shadeSpan16_radial_repeat;
-        if (SkShader::kClamp_TileMode == radialGradient.fTileMode) {
+        if (SkShader::kClamp_TileMode == fTileMode) {
             shadeProc = shadeSpan16_radial_clamp;
-        } else if (SkShader::kMirror_TileMode == radialGradient.fTileMode) {
+        } else if (SkShader::kMirror_TileMode == fTileMode) {
             shadeProc = shadeSpan16_radial_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == radialGradient.fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
         }
         (*shadeProc)(srcPt.fX, sdx, srcPt.fY, sdy, dstC,
                      cache, toggle, count);
@@ -409,16 +389,14 @@
 
 }  // namespace
 
-void SkRadialGradient::RadialGradientContext::shadeSpan(int x, int y,
-                                                        SkPMColor* SK_RESTRICT dstC, int count) {
+void SkRadialGradient::shadeSpan(int x, int y,
+                                SkPMColor* SK_RESTRICT dstC, int count) {
     SkASSERT(count > 0);
 
-    const SkRadialGradient& radialGradient = static_cast<const SkRadialGradient&>(fShader);
-
     SkPoint             srcPt;
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = radialGradient.fTileProc;
-    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
+    TileProc            proc = fTileProc;
+    const SkPMColor* SK_RESTRICT cache = this->getCache32();
     int toggle = init_dither_toggle(x, y);
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -438,12 +416,12 @@
         }
 
         RadialShadeProc shadeProc = shadeSpan_radial_repeat;
-        if (SkShader::kClamp_TileMode == radialGradient.fTileMode) {
+        if (SkShader::kClamp_TileMode == fTileMode) {
             shadeProc = shadeSpan_radial_clamp;
-        } else if (SkShader::kMirror_TileMode == radialGradient.fTileMode) {
+        } else if (SkShader::kMirror_TileMode == fTileMode) {
             shadeProc = shadeSpan_radial_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == radialGradient.fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
         }
         (*shadeProc)(srcPt.fX, sdx, srcPt.fY, sdy, dstC, cache, count, toggle);
     } else {    // perspective case
diff --git a/src/effects/gradients/SkRadialGradient.h b/src/effects/gradients/SkRadialGradient.h
index a3d04b1..4a72514 100644
--- a/src/effects/gradients/SkRadialGradient.h
+++ b/src/effects/gradients/SkRadialGradient.h
@@ -14,24 +14,10 @@
 class SkRadialGradient : public SkGradientShaderBase {
 public:
     SkRadialGradient(const SkPoint& center, SkScalar radius, const Descriptor&);
-
-    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
-                                             void* storage) const SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
-
-    class RadialGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
-    public:
-        RadialGradientContext(const SkRadialGradient& shader, const SkBitmap& device,
-                              const SkPaint& paint, const SkMatrix& matrix);
-        ~RadialGradientContext() {}
-
-        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
-
-    private:
-        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
-    };
-
+    virtual void shadeSpan(int x, int y, SkPMColor* dstC, int count)
+        SK_OVERRIDE;
+    virtual void shadeSpan16(int x, int y, uint16_t* dstCParam,
+                             int count) SK_OVERRIDE;
     virtual BitmapType asABitmap(SkBitmap* bitmap,
                                  SkMatrix* matrix,
                                  TileMode* xy) const SK_OVERRIDE;
diff --git a/src/effects/gradients/SkSweepGradient.cpp b/src/effects/gradients/SkSweepGradient.cpp
index 6dff1e7..7024945 100644
--- a/src/effects/gradients/SkSweepGradient.cpp
+++ b/src/effects/gradients/SkSweepGradient.cpp
@@ -52,24 +52,6 @@
     buffer.writePoint(fCenter);
 }
 
-size_t SkSweepGradient::contextSize() const {
-    return sizeof(SweepGradientContext);
-}
-
-SkShader::Context* SkSweepGradient::createContext(const SkBitmap& device, const SkPaint& paint,
-                                                  const SkMatrix& matrix, void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
-    }
-
-    return SkNEW_PLACEMENT_ARGS(storage, SweepGradientContext, (*this, device, paint, matrix));
-}
-
-SkSweepGradient::SweepGradientContext::SweepGradientContext(
-        const SkSweepGradient& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix) {}
-
 //  returns angle in a circle [0..2PI) -> [0..255]
 static unsigned SkATan2_255(float y, float x) {
     //    static const float g255Over2PI = 255 / (2 * SK_ScalarPI);
@@ -87,11 +69,11 @@
     return ir;
 }
 
-void SkSweepGradient::SweepGradientContext::shadeSpan(int x, int y, SkPMColor* SK_RESTRICT dstC,
-                                                      int count) {
+void SkSweepGradient::shadeSpan(int x, int y, SkPMColor* SK_RESTRICT dstC,
+                               int count) {
     SkMatrix::MapXYProc proc = fDstToIndexProc;
     const SkMatrix&     matrix = fDstToIndex;
-    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
+    const SkPMColor* SK_RESTRICT cache = this->getCache32();
     int                 toggle = init_dither_toggle(x, y);
     SkPoint             srcPt;
 
@@ -129,11 +111,11 @@
     }
 }
 
-void SkSweepGradient::SweepGradientContext::shadeSpan16(int x, int y, uint16_t* SK_RESTRICT dstC,
-                                                        int count) {
+void SkSweepGradient::shadeSpan16(int x, int y, uint16_t* SK_RESTRICT dstC,
+                                 int count) {
     SkMatrix::MapXYProc proc = fDstToIndexProc;
     const SkMatrix&     matrix = fDstToIndex;
-    const uint16_t* SK_RESTRICT cache = fCache->getCache16();
+    const uint16_t* SK_RESTRICT cache = this->getCache16();
     int                 toggle = init_dither_toggle16(x, y);
     SkPoint             srcPt;
 
diff --git a/src/effects/gradients/SkSweepGradient.h b/src/effects/gradients/SkSweepGradient.h
index 9998ed1..ca19da2 100644
--- a/src/effects/gradients/SkSweepGradient.h
+++ b/src/effects/gradients/SkSweepGradient.h
@@ -14,23 +14,8 @@
 class SkSweepGradient : public SkGradientShaderBase {
 public:
     SkSweepGradient(SkScalar cx, SkScalar cy, const Descriptor&);
-
-    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
-                                             void* storage) const SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
-
-    class SweepGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
-    public:
-        SweepGradientContext(const SkSweepGradient& shader, const SkBitmap& device,
-                             const SkPaint& paint, const SkMatrix& matrix);
-        ~SweepGradientContext() {}
-
-        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
-
-    private:
-        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
-    };
+    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
 
     virtual BitmapType asABitmap(SkBitmap* bitmap,
                                  SkMatrix* matrix,
@@ -48,9 +33,8 @@
     virtual void flatten(SkWriteBuffer& buffer) const SK_OVERRIDE;
 
 private:
-    const SkPoint fCenter;
-
     typedef SkGradientShaderBase INHERITED;
+    const SkPoint fCenter;
 };
 
 #endif
diff --git a/src/effects/gradients/SkTwoPointConicalGradient.cpp b/src/effects/gradients/SkTwoPointConicalGradient.cpp
index b7aba82..1e6a0d8 100644
--- a/src/effects/gradients/SkTwoPointConicalGradient.cpp
+++ b/src/effects/gradients/SkTwoPointConicalGradient.cpp
@@ -9,18 +9,6 @@
 
 #include "SkTwoPointConicalGradient_gpu.h"
 
-struct TwoPtRadialContext {
-    const TwoPtRadial&  fRec;
-    float               fRelX, fRelY;
-    const float         fIncX, fIncY;
-    float               fB;
-    const float         fDB;
-
-    TwoPtRadialContext(const TwoPtRadial& rec, SkScalar fx, SkScalar fy,
-                       SkScalar dfx, SkScalar dfy);
-    SkFixed nextT();
-};
-
 static int valid_divide(float numer, float denom, float* ratio) {
     SkASSERT(ratio);
     if (0 == denom) {
@@ -95,48 +83,47 @@
     fFlipped = flipped;
 }
 
-TwoPtRadialContext::TwoPtRadialContext(const TwoPtRadial& rec, SkScalar fx, SkScalar fy,
-                                       SkScalar dfx, SkScalar dfy)
-    : fRec(rec)
-    , fRelX(SkScalarToFloat(fx) - rec.fCenterX)
-    , fRelY(SkScalarToFloat(fy) - rec.fCenterY)
-    , fIncX(SkScalarToFloat(dfx))
-    , fIncY(SkScalarToFloat(dfy))
-    , fB(-2 * (rec.fDCenterX * fRelX + rec.fDCenterY * fRelY + rec.fRDR))
-    , fDB(-2 * (rec.fDCenterX * fIncX + rec.fDCenterY * fIncY)) {}
+void TwoPtRadial::setup(SkScalar fx, SkScalar fy, SkScalar dfx, SkScalar dfy) {
+    fRelX = SkScalarToFloat(fx) - fCenterX;
+    fRelY = SkScalarToFloat(fy) - fCenterY;
+    fIncX = SkScalarToFloat(dfx);
+    fIncY = SkScalarToFloat(dfy);
+    fB = -2 * (fDCenterX * fRelX + fDCenterY * fRelY + fRDR);
+    fDB = -2 * (fDCenterX * fIncX + fDCenterY * fIncY);
+}
 
-SkFixed TwoPtRadialContext::nextT() {
+SkFixed TwoPtRadial::nextT() {
     float roots[2];
 
-    float C = sqr(fRelX) + sqr(fRelY) - fRec.fRadius2;
-    int countRoots = find_quad_roots(fRec.fA, fB, C, roots, fRec.fFlipped);
+    float C = sqr(fRelX) + sqr(fRelY) - fRadius2;
+    int countRoots = find_quad_roots(fA, fB, C, roots, fFlipped);
 
     fRelX += fIncX;
     fRelY += fIncY;
     fB += fDB;
 
     if (0 == countRoots) {
-        return TwoPtRadial::kDontDrawT;
+        return kDontDrawT;
     }
 
     // Prefer the bigger t value if both give a radius(t) > 0
     // find_quad_roots returns the values sorted, so we start with the last
     float t = roots[countRoots - 1];
-    float r = lerp(fRec.fRadius, fRec.fDRadius, t);
+    float r = lerp(fRadius, fDRadius, t);
     if (r <= 0) {
         t = roots[0];   // might be the same as roots[countRoots-1]
-        r = lerp(fRec.fRadius, fRec.fDRadius, t);
+        r = lerp(fRadius, fDRadius, t);
         if (r <= 0) {
-            return TwoPtRadial::kDontDrawT;
+            return kDontDrawT;
         }
     }
     return SkFloatToFixed(t);
 }
 
-typedef void (*TwoPointConicalProc)(TwoPtRadialContext* rec, SkPMColor* dstC,
+typedef void (*TwoPointConicalProc)(TwoPtRadial* rec, SkPMColor* dstC,
                                     const SkPMColor* cache, int toggle, int count);
 
-static void twopoint_clamp(TwoPtRadialContext* rec, SkPMColor* SK_RESTRICT dstC,
+static void twopoint_clamp(TwoPtRadial* rec, SkPMColor* SK_RESTRICT dstC,
                            const SkPMColor* SK_RESTRICT cache, int toggle,
                            int count) {
     for (; count > 0; --count) {
@@ -153,7 +140,7 @@
     }
 }
 
-static void twopoint_repeat(TwoPtRadialContext* rec, SkPMColor* SK_RESTRICT dstC,
+static void twopoint_repeat(TwoPtRadial* rec, SkPMColor* SK_RESTRICT dstC,
                             const SkPMColor* SK_RESTRICT cache, int toggle,
                             int count) {
     for (; count > 0; --count) {
@@ -170,7 +157,7 @@
     }
 }
 
-static void twopoint_mirror(TwoPtRadialContext* rec, SkPMColor* SK_RESTRICT dstC,
+static void twopoint_mirror(TwoPtRadial* rec, SkPMColor* SK_RESTRICT dstC,
                             const SkPMColor* SK_RESTRICT cache, int toggle,
                             int count) {
     for (; count > 0; --count) {
@@ -216,39 +203,8 @@
     return false;
 }
 
-size_t SkTwoPointConicalGradient::contextSize() const {
-    return sizeof(TwoPointConicalGradientContext);
-}
-
-SkShader::Context* SkTwoPointConicalGradient::createContext(
-        const SkBitmap& device, const SkPaint& paint,
-        const SkMatrix& matrix, void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
-    }
-
-    return SkNEW_PLACEMENT_ARGS(storage, TwoPointConicalGradientContext,
-                                (*this, device, paint, matrix));
-}
-
-SkTwoPointConicalGradient::TwoPointConicalGradientContext::TwoPointConicalGradientContext(
-        const SkTwoPointConicalGradient& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix)
-{
-    // we don't have a span16 proc
-    fFlags &= ~kHasSpan16_Flag;
-
-    // in general, we might discard based on computed-radius, so clear
-    // this flag (todo: sometimes we can detect that we never discard...)
-    fFlags &= ~kOpaqueAlpha_Flag;
-}
-
-void SkTwoPointConicalGradient::TwoPointConicalGradientContext::shadeSpan(
-        int x, int y, SkPMColor* dstCParam, int count) {
-    const SkTwoPointConicalGradient& twoPointConicalGradient =
-            static_cast<const SkTwoPointConicalGradient&>(fShader);
-
+void SkTwoPointConicalGradient::shadeSpan(int x, int y, SkPMColor* dstCParam,
+                                          int count) {
     int toggle = init_dither_toggle(x, y);
 
     SkASSERT(count > 0);
@@ -257,15 +213,15 @@
 
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
 
-    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
+    const SkPMColor* SK_RESTRICT cache = this->getCache32();
 
     TwoPointConicalProc shadeProc = twopoint_repeat;
-    if (SkShader::kClamp_TileMode == twoPointConicalGradient.fTileMode) {
+    if (SkShader::kClamp_TileMode == fTileMode) {
         shadeProc = twopoint_clamp;
-    } else if (SkShader::kMirror_TileMode == twoPointConicalGradient.fTileMode) {
+    } else if (SkShader::kMirror_TileMode == fTileMode) {
         shadeProc = twopoint_mirror;
     } else {
-        SkASSERT(SkShader::kRepeat_TileMode == twoPointConicalGradient.fTileMode);
+        SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
     }
 
     if (fDstToIndexClass != kPerspective_MatrixClass) {
@@ -286,16 +242,16 @@
             dy = fDstToIndex.getSkewY();
         }
 
-        TwoPtRadialContext rec(twoPointConicalGradient.fRec, fx, fy, dx, dy);
-        (*shadeProc)(&rec, dstC, cache, toggle, count);
+        fRec.setup(fx, fy, dx, dy);
+        (*shadeProc)(&fRec, dstC, cache, toggle, count);
     } else {    // perspective case
         SkScalar dstX = SkIntToScalar(x) + SK_ScalarHalf;
         SkScalar dstY = SkIntToScalar(y) + SK_ScalarHalf;
         for (; count > 0; --count) {
             SkPoint srcPt;
             dstProc(fDstToIndex, dstX, dstY, &srcPt);
-            TwoPtRadialContext rec(twoPointConicalGradient.fRec, srcPt.fX, srcPt.fY, 0, 0);
-            (*shadeProc)(&rec, dstC, cache, toggle, 1);
+            fRec.setup(srcPt.fX, srcPt.fY, 0, 0);
+            (*shadeProc)(&fRec, dstC, cache, toggle, 1);
 
             dstX += SK_Scalar1;
             toggle = next_dither_toggle(toggle);
@@ -304,6 +260,23 @@
     }
 }
 
+bool SkTwoPointConicalGradient::setContext(const SkBitmap& device,
+                                           const SkPaint& paint,
+                                           const SkMatrix& matrix) {
+    if (!this->INHERITED::setContext(device, paint, matrix)) {
+        return false;
+    }
+
+    // we don't have a span16 proc
+    fFlags &= ~kHasSpan16_Flag;
+
+    // in general, we might discard based on computed-radius, so clear
+    // this flag (todo: sometimes we can detect that we never discard...)
+    fFlags &= ~kOpaqueAlpha_Flag;
+
+    return true;
+}
+
 SkShader::BitmapType SkTwoPointConicalGradient::asABitmap(
     SkBitmap* bitmap, SkMatrix* matrix, SkShader::TileMode* xy) const {
     SkPoint diff = fCenter2 - fCenter1;
diff --git a/src/effects/gradients/SkTwoPointConicalGradient.h b/src/effects/gradients/SkTwoPointConicalGradient.h
index 80aa6fa..b2e258e 100644
--- a/src/effects/gradients/SkTwoPointConicalGradient.h
+++ b/src/effects/gradients/SkTwoPointConicalGradient.h
@@ -11,8 +11,6 @@
 
 #include "SkGradientShaderPriv.h"
 
-// TODO(dominikg): Worth making it truly immutable (i.e. set values in constructor)?
-// Should only be initialized once via init(). Immutable afterwards.
 struct TwoPtRadial {
     enum {
         kDontDrawT  = 0x80000000
@@ -31,6 +29,13 @@
               const SkPoint& center1, SkScalar rad1,
               bool flipped);
 
+    // per-span iteration state: assigned by setup(), then read and advanced by each nextT() call
+    float   fRelX, fRelY, fIncX, fIncY;
+    float   fB, fDB;
+
+    void setup(SkScalar fx, SkScalar fy, SkScalar dfx, SkScalar dfy);
+    SkFixed nextT();
+
     static bool DontDrawT(SkFixed t) {
         return kDontDrawT == (uint32_t)t;
     }
@@ -46,24 +51,11 @@
                               const SkPoint& end, SkScalar endRadius,
                               bool flippedGrad, const Descriptor&);
 
-
-    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
-                                             void* storage) const SK_OVERRIDE;
-    virtual size_t contextSize() const SK_OVERRIDE;
-
-    class TwoPointConicalGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
-    public:
-        TwoPointConicalGradientContext(const SkTwoPointConicalGradient& shader,
-                                       const SkBitmap& device,
-                                       const SkPaint& paint,
-                                       const SkMatrix& matrix);
-        ~TwoPointConicalGradientContext() {}
-
-        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-
-    private:
-        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
-    };
+    virtual void shadeSpan(int x, int y, SkPMColor* dstCParam,
+                           int count) SK_OVERRIDE;
+    virtual bool setContext(const SkBitmap& device,
+                            const SkPaint& paint,
+                            const SkMatrix& matrix) SK_OVERRIDE;
 
     virtual BitmapType asABitmap(SkBitmap* bitmap,
                                  SkMatrix* matrix,
diff --git a/src/effects/gradients/SkTwoPointRadialGradient.cpp b/src/effects/gradients/SkTwoPointRadialGradient.cpp
index a598c6e..e1359b1 100644
--- a/src/effects/gradients/SkTwoPointRadialGradient.cpp
+++ b/src/effects/gradients/SkTwoPointRadialGradient.cpp
@@ -220,60 +220,23 @@
     return kRadial2_GradientType;
 }
 
-size_t SkTwoPointRadialGradient::contextSize() const {
-    return sizeof(TwoPointRadialGradientContext);
-}
-
-bool SkTwoPointRadialGradient::validContext(const SkBitmap& device, const SkPaint& paint,
-                                            const SkMatrix& matrix, SkMatrix* totalInverse) const {
-    // For now, we might have divided by zero, so detect that.
-    if (0 == fDiffRadius) {
-        return false;
-    }
-
-    return this->INHERITED::validContext(device, paint, matrix, totalInverse);
-}
-
-SkShader::Context* SkTwoPointRadialGradient::createContext(
-        const SkBitmap& device, const SkPaint& paint,
-        const SkMatrix& matrix, void* storage) const {
-    if (!this->validContext(device, paint, matrix)) {
-        return NULL;
-    }
-
-    return SkNEW_PLACEMENT_ARGS(storage, TwoPointRadialGradientContext,
-                                (*this, device, paint, matrix));
-}
-
-SkTwoPointRadialGradient::TwoPointRadialGradientContext::TwoPointRadialGradientContext(
-        const SkTwoPointRadialGradient& shader, const SkBitmap& device,
-        const SkPaint& paint, const SkMatrix& matrix)
-    : INHERITED(shader, device, paint, matrix)
-{
-    // we don't have a span16 proc
-    fFlags &= ~kHasSpan16_Flag;
-}
-
-void SkTwoPointRadialGradient::TwoPointRadialGradientContext::shadeSpan(
-        int x, int y, SkPMColor* dstCParam, int count) {
+void SkTwoPointRadialGradient::shadeSpan(int x, int y, SkPMColor* dstCParam,
+                                         int count) {
     SkASSERT(count > 0);
 
-    const SkTwoPointRadialGradient& twoPointRadialGradient =
-            static_cast<const SkTwoPointRadialGradient&>(fShader);
-
     SkPMColor* SK_RESTRICT dstC = dstCParam;
 
     // Zero difference between radii:  fill with transparent black.
-    if (twoPointRadialGradient.fDiffRadius == 0) {
+    if (fDiffRadius == 0) {
       sk_bzero(dstC, count * sizeof(*dstC));
       return;
     }
     SkMatrix::MapXYProc dstProc = fDstToIndexProc;
-    TileProc            proc = twoPointRadialGradient.fTileProc;
-    const SkPMColor* SK_RESTRICT cache = fCache->getCache32();
+    TileProc            proc = fTileProc;
+    const SkPMColor* SK_RESTRICT cache = this->getCache32();
 
-    SkScalar foura = twoPointRadialGradient.fA * 4;
-    bool posRoot = twoPointRadialGradient.fDiffRadius < 0;
+    SkScalar foura = fA * 4;
+    bool posRoot = fDiffRadius < 0;
     if (fDstToIndexClass != kPerspective_MatrixClass) {
         SkPoint srcPt;
         dstProc(fDstToIndex, SkIntToScalar(x) + SK_ScalarHalf,
@@ -291,23 +254,21 @@
             dx = fDstToIndex.getScaleX();
             dy = fDstToIndex.getSkewY();
         }
-        SkScalar b = (SkScalarMul(twoPointRadialGradient.fDiff.fX, fx) +
-                     SkScalarMul(twoPointRadialGradient.fDiff.fY, fy) -
-                     twoPointRadialGradient.fStartRadius) * 2;
-        SkScalar db = (SkScalarMul(twoPointRadialGradient.fDiff.fX, dx) +
-                      SkScalarMul(twoPointRadialGradient.fDiff.fY, dy)) * 2;
+        SkScalar b = (SkScalarMul(fDiff.fX, fx) +
+                     SkScalarMul(fDiff.fY, fy) - fStartRadius) * 2;
+        SkScalar db = (SkScalarMul(fDiff.fX, dx) +
+                      SkScalarMul(fDiff.fY, dy)) * 2;
 
         TwoPointRadialShadeProc shadeProc = shadeSpan_twopoint_repeat;
-        if (SkShader::kClamp_TileMode == twoPointRadialGradient.fTileMode) {
+        if (SkShader::kClamp_TileMode == fTileMode) {
             shadeProc = shadeSpan_twopoint_clamp;
-        } else if (SkShader::kMirror_TileMode == twoPointRadialGradient.fTileMode) {
+        } else if (SkShader::kMirror_TileMode == fTileMode) {
             shadeProc = shadeSpan_twopoint_mirror;
         } else {
-            SkASSERT(SkShader::kRepeat_TileMode == twoPointRadialGradient.fTileMode);
+            SkASSERT(SkShader::kRepeat_TileMode == fTileMode);
         }
         (*shadeProc)(fx, dx, fy, dy, b, db,
-                     twoPointRadialGradient.fSr2D2, foura,
-                     twoPointRadialGradient.fOneOverTwoA, posRoot,
+                     fSr2D2, foura, fOneOverTwoA, posRoot,
                      dstC, cache, count);
     } else {    // perspective case
         SkScalar dstX = SkIntToScalar(x);
@@ -317,11 +278,10 @@
             dstProc(fDstToIndex, dstX, dstY, &srcPt);
             SkScalar fx = srcPt.fX;
             SkScalar fy = srcPt.fY;
-            SkScalar b = (SkScalarMul(twoPointRadialGradient.fDiff.fX, fx) +
-                         SkScalarMul(twoPointRadialGradient.fDiff.fY, fy) -
-                         twoPointRadialGradient.fStartRadius) * 2;
-            SkFixed t = two_point_radial(b, fx, fy, twoPointRadialGradient.fSr2D2, foura,
-                                         twoPointRadialGradient.fOneOverTwoA, posRoot);
+            SkScalar b = (SkScalarMul(fDiff.fX, fx) +
+                         SkScalarMul(fDiff.fY, fy) - fStartRadius) * 2;
+            SkFixed t = two_point_radial(b, fx, fy, fSr2D2, foura,
+                                         fOneOverTwoA, posRoot);
             SkFixed index = proc(t);
             SkASSERT(index <= 0xFFFF);
             *dstC++ = cache[index >> SkGradientShaderBase::kCache32Shift];
@@ -330,6 +290,23 @@
     }
 }
 
+bool SkTwoPointRadialGradient::setContext( const SkBitmap& device,
+                                          const SkPaint& paint,
+                                          const SkMatrix& matrix){
+    // For now, we might have divided by zero, so detect that
+    if (0 == fDiffRadius) {
+        return false;
+    }
+
+    if (!this->INHERITED::setContext(device, paint, matrix)) {
+        return false;
+    }
+
+    // we don't have a span16 proc
+    fFlags &= ~kHasSpan16_Flag;
+    return true;
+}
+
 #ifndef SK_IGNORE_TO_STRING
 void SkTwoPointRadialGradient::toString(SkString* str) const {
     str->append("SkTwoPointRadialGradient: (");
diff --git a/src/effects/gradients/SkTwoPointRadialGradient.h b/src/effects/gradients/SkTwoPointRadialGradient.h
index 9ba89f2..ee1b49e 100644
--- a/src/effects/gradients/SkTwoPointRadialGradient.h
+++ b/src/effects/gradients/SkTwoPointRadialGradient.h
@@ -23,26 +23,11 @@
     virtual GradientType asAGradient(GradientInfo* info) const SK_OVERRIDE;
     virtual GrEffectRef* asNewEffect(GrContext* context, const SkPaint&) const SK_OVERRIDE;
 
-
-    virtual size_t contextSize() const SK_OVERRIDE;
-    virtual bool validContext(const SkBitmap&, const SkPaint&,
-                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
-    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&, const SkMatrix&,
-                                             void* storage) const SK_OVERRIDE;
-
-    class TwoPointRadialGradientContext : public SkGradientShaderBase::GradientShaderBaseContext {
-    public:
-        TwoPointRadialGradientContext(const SkTwoPointRadialGradient& shader,
-                                      const SkBitmap& device,
-                                      const SkPaint& paint,
-                                      const SkMatrix& matrix);
-        ~TwoPointRadialGradientContext() {}
-
-        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-
-    private:
-        typedef SkGradientShaderBase::GradientShaderBaseContext INHERITED;
-    };
+    virtual void shadeSpan(int x, int y, SkPMColor* dstCParam,
+                           int count) SK_OVERRIDE;
+    virtual bool setContext(const SkBitmap& device,  // false if fDiffRadius is 0 or base fails
+                            const SkPaint& paint,
+                            const SkMatrix& matrix) SK_OVERRIDE;
 
     SkScalar getCenterX1() const { return fDiff.length(); }
     SkScalar getStartRadius() const { return fStartRadius; }
@@ -56,6 +41,7 @@
     virtual void flatten(SkWriteBuffer& buffer) const SK_OVERRIDE;
 
 private:
+    typedef SkGradientShaderBase INHERITED;
     const SkPoint fCenter1;
     const SkPoint fCenter2;
     const SkScalar fRadius1;
@@ -64,8 +50,6 @@
     SkScalar fStartRadius, fDiffRadius, fSr2D2, fA, fOneOverTwoA;
 
     void init();
-
-    typedef SkGradientShaderBase INHERITED;
 };
 
 #endif