Revert of Revert of Extract most of the mutable state of SkShader into a separate Context object. (https://codereview.chromium.org/249643002/)

Reason for revert:
The Chromium-side change landed alongside the DEPS roll that includes r14323.

Original issue's description:
> Revert of Extract most of the mutable state of SkShader into a separate Context object. (https://codereview.chromium.org/207683004/)
>
> Reason for revert:
> This is blocking the DEPS roll into Chromium. Failures can be seen here:
>
> http://build.chromium.org/p/tryserver.chromium/builders/android_dbg/builds/174333
>
> Original issue's description:
> > Extract most of the mutable state of SkShader into a separate Context object.
> >
> > SkShader currently stores some state during draw calls via setContext(...).
> > Move that mutable state into a separate SkShader::Context class that is
> > constructed on demand for the duration of the draw.
> >
> > Calls to setContext() are replaced with createContext() which returns a context
> > corresponding to the shader object or NULL if the parameters to createContext
> > are invalid.
> >
> > TEST=out/Debug/dm
> > BUG=skia:1976
> >
> > Committed: http://code.google.com/p/skia/source/detail?r=14216
> >
> > Committed: http://code.google.com/p/skia/source/detail?r=14323
>
> TBR=scroggo@google.com,skyostil@chromium.org,tomhudson@chromium.org,senorblanco@chromium.org,reed@google.com,bungeman@google.com,dominikg@chromium.org
> NOTREECHECKS=true
> NOTRY=true
> BUG=skia:1976
>
> Committed: http://code.google.com/p/skia/source/detail?r=14326

R=scroggo@google.com, skyostil@chromium.org, tomhudson@chromium.org, senorblanco@chromium.org, reed@google.com, bungeman@google.com, dominikg@chromium.org
TBR=bungeman@google.com, dominikg@chromium.org, reed@google.com, scroggo@google.com, senorblanco@chromium.org, skyostil@chromium.org, tomhudson@chromium.org
NOTREECHECKS=true
NOTRY=true
BUG=skia:1976

Author: bsalomon@google.com

Review URL: https://codereview.chromium.org/246403013

git-svn-id: http://skia.googlecode.com/svn/trunk@14328 2bbb7eff-a529-9590-31e7-b0007b416f81
diff --git a/src/core/SkBitmapProcShader.cpp b/src/core/SkBitmapProcShader.cpp
index a397b78..5f5eb18 100644
--- a/src/core/SkBitmapProcShader.cpp
+++ b/src/core/SkBitmapProcShader.cpp
@@ -34,18 +34,16 @@
 SkBitmapProcShader::SkBitmapProcShader(const SkBitmap& src,
                                        TileMode tmx, TileMode tmy) {
     fRawBitmap = src;
-    fState.fTileModeX = (uint8_t)tmx;
-    fState.fTileModeY = (uint8_t)tmy;
-    fFlags = 0; // computed in setContext
+    fTileModeX = (uint8_t)tmx;
+    fTileModeY = (uint8_t)tmy;
 }
 
 SkBitmapProcShader::SkBitmapProcShader(SkReadBuffer& buffer)
         : INHERITED(buffer) {
     buffer.readBitmap(&fRawBitmap);
     fRawBitmap.setImmutable();
-    fState.fTileModeX = buffer.readUInt();
-    fState.fTileModeY = buffer.readUInt();
-    fFlags = 0; // computed in setContext
+    fTileModeX = buffer.readUInt();
+    fTileModeY = buffer.readUInt();
 }
 
 SkShader::BitmapType SkBitmapProcShader::asABitmap(SkBitmap* texture,
@@ -58,8 +56,8 @@
         texM->reset();
     }
     if (xy) {
-        xy[0] = (TileMode)fState.fTileModeX;
-        xy[1] = (TileMode)fState.fTileModeY;
+        xy[0] = (TileMode)fTileModeX;
+        xy[1] = (TileMode)fTileModeY;
     }
     return kDefault_BitmapType;
 }
@@ -68,8 +66,8 @@
     this->INHERITED::flatten(buffer);
 
     buffer.writeBitmap(fRawBitmap);
-    buffer.writeUInt(fState.fTileModeX);
-    buffer.writeUInt(fState.fTileModeY);
+    buffer.writeUInt(fTileModeX);
+    buffer.writeUInt(fTileModeY);
 }
 
 static bool only_scale_and_translate(const SkMatrix& matrix) {
@@ -98,25 +96,67 @@
     return true;
 }
 
-bool SkBitmapProcShader::setContext(const SkBitmap& device,
-                                    const SkPaint& paint,
-                                    const SkMatrix& matrix) {
+bool SkBitmapProcShader::validInternal(const SkBitmap& device,
+                                       const SkPaint& paint,
+                                       const SkMatrix& matrix,
+                                       SkMatrix* totalInverse,
+                                       SkBitmapProcState* state) const {
     if (!fRawBitmap.getTexture() && !valid_for_drawing(fRawBitmap)) {
         return false;
     }
 
-    // do this first, so we have a correct inverse matrix
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
+    // Make sure we can use totalInverse as a cache.
+    SkMatrix totalInverseLocal;
+    if (NULL == totalInverse) {
+        totalInverse = &totalInverseLocal;
+    }
+
+    // Do this first, so we know the matrix can be inverted.
+    if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
         return false;
     }
 
-    fState.fOrigBitmap = fRawBitmap;
-    if (!fState.chooseProcs(this->getTotalInverse(), paint)) {
-        this->INHERITED::endContext();
-        return false;
+    SkASSERT(state);
+    state->fTileModeX = fTileModeX;
+    state->fTileModeY = fTileModeY;
+    state->fOrigBitmap = fRawBitmap;
+    return state->chooseProcs(*totalInverse, paint);
+}
+
+bool SkBitmapProcShader::validContext(const SkBitmap& device,
+                                      const SkPaint& paint,
+                                      const SkMatrix& matrix,
+                                      SkMatrix* totalInverse) const {
+    SkBitmapProcState state;
+    return this->validInternal(device, paint, matrix, totalInverse, &state);
+}
+
+SkShader::Context* SkBitmapProcShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                     const SkMatrix& matrix, void* storage) const {
+    void* stateStorage = (char*)storage + sizeof(BitmapProcShaderContext);
+    SkBitmapProcState* state = SkNEW_PLACEMENT(stateStorage, SkBitmapProcState);
+    if (!this->validInternal(device, paint, matrix, NULL, state)) {
+        state->~SkBitmapProcState();
+        return NULL;
     }
 
-    const SkBitmap& bitmap = *fState.fBitmap;
+    return SkNEW_PLACEMENT_ARGS(storage, BitmapProcShaderContext,
+                                (*this, device, paint, matrix, state));
+}
+
+size_t SkBitmapProcShader::contextSize() const {
+    // The SkBitmapProcState is stored outside of the context object, with the context holding
+    // a pointer to it.
+    return sizeof(BitmapProcShaderContext) + sizeof(SkBitmapProcState);
+}
+
+SkBitmapProcShader::BitmapProcShaderContext::BitmapProcShaderContext(
+        const SkBitmapProcShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix, SkBitmapProcState* state)
+    : INHERITED(shader, device, paint, matrix)
+    , fState(state)
+{
+    const SkBitmap& bitmap = *fState->fBitmap;
     bool bitmapIsOpaque = bitmap.isOpaque();
 
     // update fFlags
@@ -157,12 +197,12 @@
     }
 
     fFlags = flags;
-    return true;
 }
 
-void SkBitmapProcShader::endContext() {
-    fState.endContext();
-    this->INHERITED::endContext();
+SkBitmapProcShader::BitmapProcShaderContext::~BitmapProcShaderContext() {
+    // The bitmap proc state has been created outside of the context on memory that will be freed
+    // elsewhere. Only call the destructor but leave the freeing of the memory to the caller.
+    fState->~SkBitmapProcState();
 }
 
 #define BUF_MAX     128
@@ -176,8 +216,9 @@
     #define TEST_BUFFER_EXTRA   0
 #endif
 
-void SkBitmapProcShader::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
-    const SkBitmapProcState& state = fState;
+void SkBitmapProcShader::BitmapProcShaderContext::shadeSpan(int x, int y, SkPMColor dstC[],
+                                                            int count) {
+    const SkBitmapProcState& state = *fState;
     if (state.getShaderProc32()) {
         state.getShaderProc32()(state, x, y, dstC, count);
         return;
@@ -186,7 +227,7 @@
     uint32_t buffer[BUF_MAX + TEST_BUFFER_EXTRA];
     SkBitmapProcState::MatrixProc   mproc = state.getMatrixProc();
     SkBitmapProcState::SampleProc32 sproc = state.getSampleProc32();
-    int max = fState.maxCountForBufferSize(sizeof(buffer[0]) * BUF_MAX);
+    int max = state.maxCountForBufferSize(sizeof(buffer[0]) * BUF_MAX);
 
     SkASSERT(state.fBitmap->getPixels());
     SkASSERT(state.fBitmap->pixelRef() == NULL ||
@@ -220,16 +261,17 @@
     }
 }
 
-SkShader::ShadeProc SkBitmapProcShader::asAShadeProc(void** ctx) {
-    if (fState.getShaderProc32()) {
-        *ctx = &fState;
-        return (ShadeProc)fState.getShaderProc32();
+SkShader::Context::ShadeProc SkBitmapProcShader::BitmapProcShaderContext::asAShadeProc(void** ctx) {
+    if (fState->getShaderProc32()) {
+        *ctx = fState;
+        return (ShadeProc)fState->getShaderProc32();
     }
     return NULL;
 }
 
-void SkBitmapProcShader::shadeSpan16(int x, int y, uint16_t dstC[], int count) {
-    const SkBitmapProcState& state = fState;
+void SkBitmapProcShader::BitmapProcShaderContext::shadeSpan16(int x, int y, uint16_t dstC[],
+                                                              int count) {
+    const SkBitmapProcState& state = *fState;
     if (state.getShaderProc16()) {
         state.getShaderProc16()(state, x, y, dstC, count);
         return;
@@ -238,7 +280,7 @@
     uint32_t buffer[BUF_MAX];
     SkBitmapProcState::MatrixProc   mproc = state.getMatrixProc();
     SkBitmapProcState::SampleProc16 sproc = state.getSampleProc16();
-    int max = fState.maxCountForBufferSize(sizeof(buffer));
+    int max = state.maxCountForBufferSize(sizeof(buffer));
 
     SkASSERT(state.fBitmap->getPixels());
     SkASSERT(state.fBitmap->pixelRef() == NULL ||
@@ -342,8 +384,8 @@
     str->append("BitmapShader: (");
 
     str->appendf("(%s, %s)",
-                 gTileModeName[fState.fTileModeX],
-                 gTileModeName[fState.fTileModeY]);
+                 gTileModeName[fTileModeX],
+                 gTileModeName[fTileModeY]);
 
     str->append(" ");
     fRawBitmap.toString(str);
@@ -384,8 +426,8 @@
     matrix.preConcat(lmInverse);
 
     SkShader::TileMode tm[] = {
-        (TileMode)fState.fTileModeX,
-        (TileMode)fState.fTileModeY,
+        (TileMode)fTileModeX,
+        (TileMode)fTileModeY,
     };
 
     // Must set wrap and filter on the sampler before requesting a texture. In two places below
diff --git a/src/core/SkBitmapProcShader.h b/src/core/SkBitmapProcShader.h
index 8e225a5..e0c78b8 100644
--- a/src/core/SkBitmapProcShader.h
+++ b/src/core/SkBitmapProcShader.h
@@ -20,14 +20,16 @@
 
     // overrides from SkShader
     virtual bool isOpaque() const SK_OVERRIDE;
-    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
-    virtual void endContext() SK_OVERRIDE;
-    virtual uint32_t getFlags() SK_OVERRIDE { return fFlags; }
-    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-    virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
     virtual BitmapType asABitmap(SkBitmap*, SkMatrix*, TileMode*) const SK_OVERRIDE;
 
+    virtual bool validContext(const SkBitmap& device,
+                              const SkPaint& paint,
+                              const SkMatrix& matrix,
+                              SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
+                                             const SkMatrix&, void* storage) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
     static bool CanDo(const SkBitmap&, TileMode tx, TileMode ty);
 
     SK_TO_STRING_OVERRIDE()
@@ -37,22 +39,54 @@
     GrEffectRef* asNewEffect(GrContext*, const SkPaint&) const SK_OVERRIDE;
 #endif
 
+    class BitmapProcShaderContext : public SkShader::Context {
+    public:
+        // The context takes ownership of the state. It will call its destructor
+        // but will NOT free the memory.
+        BitmapProcShaderContext(const SkBitmapProcShader& shader,
+                                const SkBitmap& device,
+                                const SkPaint& paint,
+                                const SkMatrix& matrix,
+                                SkBitmapProcState* state);
+        virtual ~BitmapProcShaderContext();
+
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+        virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+
+        virtual uint32_t getFlags() const SK_OVERRIDE { return fFlags; }
+
+    private:
+        SkBitmapProcState*  fState;
+        uint32_t            fFlags;
+
+        typedef SkShader::Context INHERITED;
+    };
+
 protected:
     SkBitmapProcShader(SkReadBuffer& );
     virtual void flatten(SkWriteBuffer&) const SK_OVERRIDE;
 
-    SkBitmap          fRawBitmap;   // experimental for RLE encoding
-    SkBitmapProcState fState;
-    uint32_t          fFlags;
+    SkBitmap    fRawBitmap;   // experimental for RLE encoding
+    uint8_t     fTileModeX, fTileModeY;
 
 private:
+    bool validInternal(const SkBitmap& device, const SkPaint& paint,
+                       const SkMatrix& matrix, SkMatrix* totalInverse,
+                       SkBitmapProcState* state) const;
+
     typedef SkShader INHERITED;
 };
 
-// Commonly used allocator. It currently is only used to allocate up to 2 objects. The total
-// bytes requested is calculated using one of our large shaders plus the size of an Sk3DBlitter
-// in SkDraw.cpp
-typedef SkSmallAllocator<2, sizeof(SkBitmapProcShader) + sizeof(void*) * 2> SkTBlitterAllocator;
+// Commonly used allocator. It currently is only used to allocate up to 3 objects. The total
+// bytes requested is calculated using one of our large shaders, its context size plus the size of
+// an Sk3DBlitter in SkDraw.cpp
+// Note that some contexts may contain other contexts (e.g. for compose shaders), but we've not
+// yet found a situation where the size below isn't big enough.
+typedef SkSmallAllocator<3, sizeof(SkBitmapProcShader) +
+                            sizeof(SkBitmapProcShader::BitmapProcShaderContext) +
+                            sizeof(SkBitmapProcState) +
+                            sizeof(void*) * 2> SkTBlitterAllocator;
 
 // If alloc is non-NULL, it will be used to allocate the returned SkShader, and MUST outlive
 // the SkShader.
diff --git a/src/core/SkBitmapProcState.cpp b/src/core/SkBitmapProcState.cpp
index be87d83..eecfbbc 100644
--- a/src/core/SkBitmapProcState.cpp
+++ b/src/core/SkBitmapProcState.cpp
@@ -360,17 +360,6 @@
     return true;
 }
 
-void SkBitmapProcState::endContext() {
-    SkDELETE(fBitmapFilter);
-    fBitmapFilter = NULL;
-    fScaledBitmap.reset();
-
-    if (fScaledCacheID) {
-        SkScaledImageCache::Unlock(fScaledCacheID);
-        fScaledCacheID = NULL;
-    }
-}
-
 SkBitmapProcState::~SkBitmapProcState() {
     if (fScaledCacheID) {
         SkScaledImageCache::Unlock(fScaledCacheID);
@@ -399,6 +388,7 @@
     }
     // The above logic should have always assigned fBitmap, but in case it
     // didn't, we check for that now...
+    // TODO(dominikg): Ask humper@ if we can just use an SkASSERT(fBitmap)?
     if (NULL == fBitmap) {
         return false;
     }
@@ -487,6 +477,7 @@
     // shader will perform.
 
     fMatrixProc = this->chooseMatrixProc(trivialMatrix);
+    // TODO(dominikg): SkASSERT(fMatrixProc) instead? chooseMatrixProc never returns NULL.
     if (NULL == fMatrixProc) {
         return false;
     }
@@ -528,6 +519,7 @@
                 fPaintPMColor = SkPreMultiplyColor(paint.getColor());
                 break;
             default:
+                // TODO(dominikg): Should we ever get here? SkASSERT(false) instead?
                 return false;
         }
 
diff --git a/src/core/SkBitmapProcState.h b/src/core/SkBitmapProcState.h
index d5a354e..663bcb8 100644
--- a/src/core/SkBitmapProcState.h
+++ b/src/core/SkBitmapProcState.h
@@ -89,12 +89,6 @@
     uint8_t             fTileModeY;         // CONSTRUCTOR
     uint8_t             fFilterLevel;       // chooseProcs
 
-    /** The shader will let us know when we can release some of our resources
-      * like scaled bitmaps.
-      */
-
-    void endContext();
-
     /** Platforms implement this, and can optionally overwrite only the
         following fields:
 
diff --git a/src/core/SkBlitter.cpp b/src/core/SkBlitter.cpp
index 52a05ed..7243f52 100644
--- a/src/core/SkBlitter.cpp
+++ b/src/core/SkBlitter.cpp
@@ -26,6 +26,15 @@
 
 bool SkBlitter::isNullBlitter() const { return false; }
 
+bool SkBlitter::resetShaderContext(const SkBitmap& device, const SkPaint& paint,
+                                   const SkMatrix& matrix) {
+    return true;
+}
+
+SkShader::Context* SkBlitter::getShaderContext() const {
+    return NULL;
+}
+
 const SkBitmap* SkBlitter::justAnOpaqueColor(uint32_t* value) {
     return NULL;
 }
@@ -568,102 +577,149 @@
 public:
     Sk3DShader(SkShader* proxy) : fProxy(proxy) {
         SkSafeRef(proxy);
-        fMask = NULL;
     }
 
     virtual ~Sk3DShader() {
         SkSafeUnref(fProxy);
     }
 
-    void setMask(const SkMask* mask) { fMask = mask; }
+    virtual size_t contextSize() const SK_OVERRIDE {
+        size_t size = sizeof(Sk3DShaderContext);
+        if (fProxy) {
+            size += fProxy->contextSize();
+        }
+        return size;
+    }
 
-    virtual bool setContext(const SkBitmap& device, const SkPaint& paint,
-                            const SkMatrix& matrix) SK_OVERRIDE {
-        if (!this->INHERITED::setContext(device, paint, matrix)) {
+    virtual bool validContext(const SkBitmap& device, const SkPaint& paint,
+                              const SkMatrix& matrix, SkMatrix* totalInverse = NULL) const
+            SK_OVERRIDE
+    {
+        if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
             return false;
         }
         if (fProxy) {
-            if (!fProxy->setContext(device, paint, matrix)) {
-                // must keep our set/end context calls balanced
-                this->INHERITED::endContext();
-                return false;
-            }
-        } else {
-            fPMColor = SkPreMultiplyColor(paint.getColor());
+            return fProxy->validContext(device, paint, matrix);
         }
         return true;
     }
 
-    virtual void endContext() SK_OVERRIDE {
-        if (fProxy) {
-            fProxy->endContext();
+    virtual SkShader::Context* createContext(const SkBitmap& device,
+                                             const SkPaint& paint,
+                                             const SkMatrix& matrix,
+                                             void* storage) const SK_OVERRIDE
+    {
+        if (!this->validContext(device, paint, matrix)) {
+            return NULL;
         }
-        this->INHERITED::endContext();
+
+        SkShader::Context* proxyContext;
+        if (fProxy) {
+            char* proxyContextStorage = (char*) storage + sizeof(Sk3DShaderContext);
+            proxyContext = fProxy->createContext(device, paint, matrix, proxyContextStorage);
+            SkASSERT(proxyContext);
+        } else {
+            proxyContext = NULL;
+        }
+        return SkNEW_PLACEMENT_ARGS(storage, Sk3DShaderContext, (*this, device, paint, matrix,
+                                                                 proxyContext));
     }
 
-    virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE {
-        if (fProxy) {
-            fProxy->shadeSpan(x, y, span, count);
-        }
-
-        if (fMask == NULL) {
-            if (fProxy == NULL) {
-                sk_memset32(span, fPMColor, count);
+    class Sk3DShaderContext : public SkShader::Context {
+    public:
+        // Calls proxyContext's destructor but will NOT free its memory.
+        Sk3DShaderContext(const Sk3DShader& shader, const SkBitmap& device, const SkPaint& paint,
+                          const SkMatrix& matrix, SkShader::Context* proxyContext)
+            : INHERITED(shader, device, paint, matrix)
+            , fMask(NULL)
+            , fProxyContext(proxyContext)
+        {
+            if (!fProxyContext) {
+                fPMColor = SkPreMultiplyColor(paint.getColor());
             }
-            return;
         }
 
-        SkASSERT(fMask->fBounds.contains(x, y));
-        SkASSERT(fMask->fBounds.contains(x + count - 1, y));
+        virtual ~Sk3DShaderContext() {
+            if (fProxyContext) {
+                fProxyContext->~Context();
+            }
+        }
 
-        size_t          size = fMask->computeImageSize();
-        const uint8_t*  alpha = fMask->getAddr8(x, y);
-        const uint8_t*  mulp = alpha + size;
-        const uint8_t*  addp = mulp + size;
+        void setMask(const SkMask* mask) { fMask = mask; }
 
-        if (fProxy) {
-            for (int i = 0; i < count; i++) {
-                if (alpha[i]) {
-                    SkPMColor c = span[i];
-                    if (c) {
-                        unsigned a = SkGetPackedA32(c);
-                        unsigned r = SkGetPackedR32(c);
-                        unsigned g = SkGetPackedG32(c);
-                        unsigned b = SkGetPackedB32(c);
+        virtual void shadeSpan(int x, int y, SkPMColor span[], int count) SK_OVERRIDE {
+            if (fProxyContext) {
+                fProxyContext->shadeSpan(x, y, span, count);
+            }
 
+            if (fMask == NULL) {
+                if (fProxyContext == NULL) {
+                    sk_memset32(span, fPMColor, count);
+                }
+                return;
+            }
+
+            SkASSERT(fMask->fBounds.contains(x, y));
+            SkASSERT(fMask->fBounds.contains(x + count - 1, y));
+
+            size_t          size = fMask->computeImageSize();
+            const uint8_t*  alpha = fMask->getAddr8(x, y);
+            const uint8_t*  mulp = alpha + size;
+            const uint8_t*  addp = mulp + size;
+
+            if (fProxyContext) {
+                for (int i = 0; i < count; i++) {
+                    if (alpha[i]) {
+                        SkPMColor c = span[i];
+                        if (c) {
+                            unsigned a = SkGetPackedA32(c);
+                            unsigned r = SkGetPackedR32(c);
+                            unsigned g = SkGetPackedG32(c);
+                            unsigned b = SkGetPackedB32(c);
+
+                            unsigned mul = SkAlpha255To256(mulp[i]);
+                            unsigned add = addp[i];
+
+                            r = SkFastMin32(SkAlphaMul(r, mul) + add, a);
+                            g = SkFastMin32(SkAlphaMul(g, mul) + add, a);
+                            b = SkFastMin32(SkAlphaMul(b, mul) + add, a);
+
+                            span[i] = SkPackARGB32(a, r, g, b);
+                        }
+                    } else {
+                        span[i] = 0;
+                    }
+                }
+            } else {    // color
+                unsigned a = SkGetPackedA32(fPMColor);
+                unsigned r = SkGetPackedR32(fPMColor);
+                unsigned g = SkGetPackedG32(fPMColor);
+                unsigned b = SkGetPackedB32(fPMColor);
+                for (int i = 0; i < count; i++) {
+                    if (alpha[i]) {
                         unsigned mul = SkAlpha255To256(mulp[i]);
                         unsigned add = addp[i];
 
-                        r = SkFastMin32(SkAlphaMul(r, mul) + add, a);
-                        g = SkFastMin32(SkAlphaMul(g, mul) + add, a);
-                        b = SkFastMin32(SkAlphaMul(b, mul) + add, a);
-
-                        span[i] = SkPackARGB32(a, r, g, b);
+                        span[i] = SkPackARGB32( a,
+                                        SkFastMin32(SkAlphaMul(r, mul) + add, a),
+                                        SkFastMin32(SkAlphaMul(g, mul) + add, a),
+                                        SkFastMin32(SkAlphaMul(b, mul) + add, a));
+                    } else {
+                        span[i] = 0;
                     }
-                } else {
-                    span[i] = 0;
-                }
-            }
-        } else {    // color
-            unsigned a = SkGetPackedA32(fPMColor);
-            unsigned r = SkGetPackedR32(fPMColor);
-            unsigned g = SkGetPackedG32(fPMColor);
-            unsigned b = SkGetPackedB32(fPMColor);
-            for (int i = 0; i < count; i++) {
-                if (alpha[i]) {
-                    unsigned mul = SkAlpha255To256(mulp[i]);
-                    unsigned add = addp[i];
-
-                    span[i] = SkPackARGB32( a,
-                                    SkFastMin32(SkAlphaMul(r, mul) + add, a),
-                                    SkFastMin32(SkAlphaMul(g, mul) + add, a),
-                                    SkFastMin32(SkAlphaMul(b, mul) + add, a));
-                } else {
-                    span[i] = 0;
                 }
             }
         }
-    }
+
+    private:
+        // Unowned.
+        const SkMask*       fMask;
+        // Memory is unowned, but we need to call the destructor.
+        SkShader::Context*  fProxyContext;
+        SkPMColor           fPMColor;
+
+        typedef SkShader::Context INHERITED;
+    };
 
 #ifndef SK_IGNORE_TO_STRING
     virtual void toString(SkString* str) const SK_OVERRIDE {
@@ -685,29 +741,30 @@
 protected:
     Sk3DShader(SkReadBuffer& buffer) : INHERITED(buffer) {
         fProxy = buffer.readShader();
-        fPMColor = buffer.readColor();
-        fMask = NULL;
+        // Leaving this here until we bump the picture version, though this
+        // shader should never be recorded.
+        buffer.readColor();
     }
 
     virtual void flatten(SkWriteBuffer& buffer) const SK_OVERRIDE {
         this->INHERITED::flatten(buffer);
         buffer.writeFlattenable(fProxy);
-        buffer.writeColor(fPMColor);
+        // Leaving this here until we bump the picture version, though this
+        // shader should never be recorded.
+        buffer.writeColor(SkColor());
     }
 
 private:
     SkShader*       fProxy;
-    SkPMColor       fPMColor;
-    const SkMask*   fMask;
 
     typedef SkShader INHERITED;
 };
 
 class Sk3DBlitter : public SkBlitter {
 public:
-    Sk3DBlitter(SkBlitter* proxy, Sk3DShader* shader)
+    Sk3DBlitter(SkBlitter* proxy, Sk3DShader::Sk3DShaderContext* shaderContext)
         : fProxy(proxy)
-        , f3DShader(SkRef(shader))
+        , f3DShaderContext(shaderContext)
     {}
 
     virtual void blitH(int x, int y, int width) {
@@ -729,22 +786,22 @@
 
     virtual void blitMask(const SkMask& mask, const SkIRect& clip) {
         if (mask.fFormat == SkMask::k3D_Format) {
-            f3DShader->setMask(&mask);
+            f3DShaderContext->setMask(&mask);
 
             ((SkMask*)&mask)->fFormat = SkMask::kA8_Format;
             fProxy->blitMask(mask, clip);
             ((SkMask*)&mask)->fFormat = SkMask::k3D_Format;
 
-            f3DShader->setMask(NULL);
+            f3DShaderContext->setMask(NULL);
         } else {
             fProxy->blitMask(mask, clip);
         }
     }
 
 private:
-    // fProxy is unowned. It will be deleted by SkSmallAllocator.
-    SkBlitter*               fProxy;
-    SkAutoTUnref<Sk3DShader> f3DShader;
+    // Both pointers are unowned. They will be deleted by SkSmallAllocator.
+    SkBlitter*                     fProxy;
+    Sk3DShader::Sk3DShaderContext* f3DShaderContext;
 };
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -754,8 +811,7 @@
 static bool just_solid_color(const SkPaint& paint) {
     if (paint.getAlpha() == 0xFF && paint.getColorFilter() == NULL) {
         SkShader* shader = paint.getShader();
-        if (NULL == shader ||
-            (shader->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
+        if (NULL == shader) {
             return true;
         }
     }
@@ -893,16 +949,22 @@
     }
 
     /*
-     *  We need to have balanced calls to the shader:
-     *      setContext
-     *      endContext
-     *  We make the first call here, in case it fails we can abort the draw.
-     *  The endContext() call is made by the blitter (assuming setContext did
-     *  not fail) in its destructor.
+     *  We create a SkShader::Context object, and store it on the blitter.
      */
-    if (shader && !shader->setContext(device, *paint, matrix)) {
-        blitter = allocator->createT<SkNullBlitter>();
-        return blitter;
+    SkShader::Context* shaderContext;
+    if (shader) {
+        // Try to create the ShaderContext
+        void* storage = allocator->reserveT<SkShader::Context>(shader->contextSize());
+        shaderContext = shader->createContext(device, *paint, matrix, storage);
+        if (!shaderContext) {
+            allocator->freeLast();
+            blitter = allocator->createT<SkNullBlitter>();
+            return blitter;
+        }
+        SkASSERT(shaderContext);
+        SkASSERT((void*) shaderContext == storage);
+    } else {
+        shaderContext = NULL;
     }
 
 
@@ -913,19 +975,20 @@
                 SkASSERT(NULL == paint->getXfermode());
                 blitter = allocator->createT<SkA8_Coverage_Blitter>(device, *paint);
             } else if (shader) {
-                blitter = allocator->createT<SkA8_Shader_Blitter>(device, *paint);
+                blitter = allocator->createT<SkA8_Shader_Blitter>(device, *paint, shaderContext);
             } else {
                 blitter = allocator->createT<SkA8_Blitter>(device, *paint);
             }
             break;
 
         case kRGB_565_SkColorType:
-            blitter = SkBlitter_ChooseD565(device, *paint, allocator);
+            blitter = SkBlitter_ChooseD565(device, *paint, shaderContext, allocator);
             break;
 
         case kN32_SkColorType:
             if (shader) {
-                blitter = allocator->createT<SkARGB32_Shader_Blitter>(device, *paint);
+                blitter = allocator->createT<SkARGB32_Shader_Blitter>(
+                        device, *paint, shaderContext);
             } else if (paint->getColor() == SK_ColorBLACK) {
                 blitter = allocator->createT<SkARGB32_Black_Blitter>(device, *paint);
             } else if (paint->getAlpha() == 0xFF) {
@@ -944,7 +1007,9 @@
     if (shader3D) {
         SkBlitter* innerBlitter = blitter;
         // innerBlitter was allocated by allocator, which will delete it.
-        blitter = allocator->createT<Sk3DBlitter>(innerBlitter, shader3D);
+        // We know shaderContext is of type Sk3DShaderContext because it belongs to shader3D.
+        blitter = allocator->createT<Sk3DBlitter>(innerBlitter,
+                static_cast<Sk3DShader::Sk3DShaderContext*>(shaderContext));
     }
     return blitter;
 }
@@ -956,18 +1021,36 @@
 
 ///////////////////////////////////////////////////////////////////////////////
 
-SkShaderBlitter::SkShaderBlitter(const SkBitmap& device, const SkPaint& paint)
-        : INHERITED(device) {
-    fShader = paint.getShader();
+SkShaderBlitter::SkShaderBlitter(const SkBitmap& device, const SkPaint& paint,
+                                 SkShader::Context* shaderContext)
+        : INHERITED(device)
+        , fShader(paint.getShader())
+        , fShaderContext(shaderContext) {
     SkASSERT(fShader);
-    SkASSERT(fShader->setContextHasBeenCalled());
+    SkASSERT(fShaderContext);
 
     fShader->ref();
-    fShaderFlags = fShader->getFlags();
+    fShaderFlags = fShaderContext->getFlags();
 }
 
 SkShaderBlitter::~SkShaderBlitter() {
-    SkASSERT(fShader->setContextHasBeenCalled());
-    fShader->endContext();
     fShader->unref();
 }
+
+bool SkShaderBlitter::resetShaderContext(const SkBitmap& device, const SkPaint& paint,
+                                         const SkMatrix& matrix) {
+    if (!fShader->validContext(device, paint, matrix)) {
+        return false;
+    }
+
+    // Only destroy the old context if we have a new one. We need to ensure to have a
+    // live context in fShaderContext because the storage is owned by an SkSmallAllocator
+    // outside of this class.
+    // The new context will be of the same size as the old one because we use the same
+    // shader to create it. It is therefore safe to re-use the storage.
+    fShaderContext->~Context();
+    fShaderContext = fShader->createContext(device, paint, matrix, (void*)fShaderContext);
+    SkASSERT(fShaderContext);
+
+    return true;
+}
diff --git a/src/core/SkBlitter.h b/src/core/SkBlitter.h
index d19a34b..f76839e 100644
--- a/src/core/SkBlitter.h
+++ b/src/core/SkBlitter.h
@@ -61,6 +61,13 @@
      */
     virtual bool isNullBlitter() const;
 
+    /**
+     *  Special methods for SkShaderBlitter. On all other classes these are no-ops.
+     */
+    virtual bool resetShaderContext(const SkBitmap& device, const SkPaint& paint,
+                                    const SkMatrix& matrix);
+    virtual SkShader::Context* getShaderContext() const;
+
     ///@name non-virtual helpers
     void blitMaskRegion(const SkMask& mask, const SkRegion& clip);
     void blitRectRegion(const SkIRect& rect, const SkRegion& clip);
diff --git a/src/core/SkBlitter_A8.cpp b/src/core/SkBlitter_A8.cpp
index 983a226..11f4259 100644
--- a/src/core/SkBlitter_A8.cpp
+++ b/src/core/SkBlitter_A8.cpp
@@ -228,11 +228,12 @@
 
 ///////////////////////////////////////////////////////////////////////
 
-SkA8_Shader_Blitter::SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint)
-    : INHERITED(device, paint) {
+SkA8_Shader_Blitter::SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
+                                         SkShader::Context* shaderContext)
+    : INHERITED(device, paint, shaderContext) {
     if ((fXfermode = paint.getXfermode()) != NULL) {
         fXfermode->ref();
-        SkASSERT(fShader);
+        SkASSERT(fShaderContext);
     }
 
     int width = device.width();
@@ -250,13 +251,14 @@
              (unsigned)(x + width) <= (unsigned)fDevice.width());
 
     uint8_t* device = fDevice.getAddr8(x, y);
+    SkShader::Context* shaderContext = fShaderContext;
 
-    if ((fShader->getFlags() & SkShader::kOpaqueAlpha_Flag) && !fXfermode) {
+    if ((shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag) && !fXfermode) {
         memset(device, 0xFF, width);
     } else {
         SkPMColor*  span = fBuffer;
 
-        fShader->shadeSpan(x, y, span, width);
+        shaderContext->shadeSpan(x, y, span, width);
         if (fXfermode) {
             fXfermode->xferA8(device, span, width, NULL);
         } else {
@@ -282,12 +284,12 @@
 
 void SkA8_Shader_Blitter::blitAntiH(int x, int y, const SkAlpha antialias[],
                                     const int16_t runs[]) {
-    SkShader*   shader = fShader;
-    SkXfermode* mode = fXfermode;
-    uint8_t*    aaExpand = fAAExpand;
-    SkPMColor*  span = fBuffer;
-    uint8_t*    device = fDevice.getAddr8(x, y);
-    int         opaque = fShader->getFlags() & SkShader::kOpaqueAlpha_Flag;
+    SkShader::Context* shaderContext = fShaderContext;
+    SkXfermode*        mode = fXfermode;
+    uint8_t*           aaExpand = fAAExpand;
+    SkPMColor*         span = fBuffer;
+    uint8_t*           device = fDevice.getAddr8(x, y);
+    int                opaque = shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag;
 
     for (;;) {
         int count = *runs;
@@ -299,7 +301,7 @@
             if (opaque && aa == 255 && mode == NULL) {
                 memset(device, 0xFF, count);
             } else {
-                shader->shadeSpan(x, y, span, count);
+                shaderContext->shadeSpan(x, y, span, count);
                 if (mode) {
                     memset(aaExpand, aa, count);
                     mode->xferA8(device, span, count, aaExpand);
@@ -329,11 +331,12 @@
     int height = clip.height();
     uint8_t* device = fDevice.getAddr8(x, y);
     const uint8_t* alpha = mask.getAddr8(x, y);
+    SkShader::Context* shaderContext = fShaderContext;
 
     SkPMColor*  span = fBuffer;
 
     while (--height >= 0) {
-        fShader->shadeSpan(x, y, span, width);
+        shaderContext->shadeSpan(x, y, span, width);
         if (fXfermode) {
             fXfermode->xferA8(device, span, width, alpha);
         } else {
diff --git a/src/core/SkBlitter_ARGB32.cpp b/src/core/SkBlitter_ARGB32.cpp
index d4bec1b..118a1d1 100644
--- a/src/core/SkBlitter_ARGB32.cpp
+++ b/src/core/SkBlitter_ARGB32.cpp
@@ -275,14 +275,16 @@
 }
 
 SkARGB32_Shader_Blitter::SkARGB32_Shader_Blitter(const SkBitmap& device,
-                            const SkPaint& paint) : INHERITED(device, paint) {
+        const SkPaint& paint, SkShader::Context* shaderContext)
+    : INHERITED(device, paint, shaderContext)
+{
     fBuffer = (SkPMColor*)sk_malloc_throw(device.width() * (sizeof(SkPMColor)));
 
     fXfermode = paint.getXfermode();
     SkSafeRef(fXfermode);
 
     int flags = 0;
-    if (!(fShader->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
+    if (!(shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
         flags |= SkBlitRow::kSrcPixelAlpha_Flag32;
     }
     // we call this on the output from the shader
@@ -292,7 +294,7 @@
 
     fShadeDirectlyIntoDevice = false;
     if (fXfermode == NULL) {
-        if (fShader->getFlags() & SkShader::kOpaqueAlpha_Flag) {
+        if (shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag) {
             fShadeDirectlyIntoDevice = true;
         }
     } else {
@@ -305,7 +307,7 @@
         }
     }
 
-    fConstInY = SkToBool(fShader->getFlags() & SkShader::kConstInY32_Flag);
+    fConstInY = SkToBool(shaderContext->getFlags() & SkShader::kConstInY32_Flag);
 }
 
 SkARGB32_Shader_Blitter::~SkARGB32_Shader_Blitter() {
@@ -319,10 +321,10 @@
     uint32_t*   device = fDevice.getAddr32(x, y);
 
     if (fShadeDirectlyIntoDevice) {
-        fShader->shadeSpan(x, y, device, width);
+        fShaderContext->shadeSpan(x, y, device, width);
     } else {
         SkPMColor*  span = fBuffer;
-        fShader->shadeSpan(x, y, span, width);
+        fShaderContext->shadeSpan(x, y, span, width);
         if (fXfermode) {
             fXfermode->xfer32(device, span, width, NULL);
         } else {
@@ -335,22 +337,22 @@
     SkASSERT(x >= 0 && y >= 0 &&
              x + width <= fDevice.width() && y + height <= fDevice.height());
 
-    uint32_t*   device = fDevice.getAddr32(x, y);
-    size_t      deviceRB = fDevice.rowBytes();
-    SkShader*   shader = fShader;
-    SkPMColor*  span = fBuffer;
+    uint32_t*          device = fDevice.getAddr32(x, y);
+    size_t             deviceRB = fDevice.rowBytes();
+    SkShader::Context* shaderContext = fShaderContext;
+    SkPMColor*         span = fBuffer;
 
     if (fConstInY) {
         if (fShadeDirectlyIntoDevice) {
             // shade the first row directly into the device
-            fShader->shadeSpan(x, y, device, width);
+            shaderContext->shadeSpan(x, y, device, width);
             span = device;
             while (--height > 0) {
                 device = (uint32_t*)((char*)device + deviceRB);
                 memcpy(device, span, width << 2);
             }
         } else {
-            fShader->shadeSpan(x, y, span, width);
+            shaderContext->shadeSpan(x, y, span, width);
             SkXfermode* xfer = fXfermode;
             if (xfer) {
                 do {
@@ -372,7 +374,7 @@
 
     if (fShadeDirectlyIntoDevice) {
         void* ctx;
-        SkShader::ShadeProc shadeProc = fShader->asAShadeProc(&ctx);
+        SkShader::Context::ShadeProc shadeProc = shaderContext->asAShadeProc(&ctx);
         if (shadeProc) {
             do {
                 shadeProc(ctx, x, y, device, width);
@@ -381,7 +383,7 @@
             } while (--height > 0);
         } else {
             do {
-                shader->shadeSpan(x, y, device, width);
+                shaderContext->shadeSpan(x, y, device, width);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
             } while (--height > 0);
@@ -390,7 +392,7 @@
         SkXfermode* xfer = fXfermode;
         if (xfer) {
             do {
-                shader->shadeSpan(x, y, span, width);
+                shaderContext->shadeSpan(x, y, span, width);
                 xfer->xfer32(device, span, width, NULL);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
@@ -398,7 +400,7 @@
         } else {
             SkBlitRow::Proc32 proc = fProc32;
             do {
-                shader->shadeSpan(x, y, span, width);
+                shaderContext->shadeSpan(x, y, span, width);
                 proc(device, span, width, 255);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
@@ -409,9 +411,9 @@
 
 void SkARGB32_Shader_Blitter::blitAntiH(int x, int y, const SkAlpha antialias[],
                                         const int16_t runs[]) {
-    SkPMColor*  span = fBuffer;
-    uint32_t*   device = fDevice.getAddr32(x, y);
-    SkShader*   shader = fShader;
+    SkPMColor*         span = fBuffer;
+    uint32_t*          device = fDevice.getAddr32(x, y);
+    SkShader::Context* shaderContext = fShaderContext;
 
     if (fXfermode && !fShadeDirectlyIntoDevice) {
         for (;;) {
@@ -422,7 +424,7 @@
                 break;
             int aa = *antialias;
             if (aa) {
-                shader->shadeSpan(x, y, span, count);
+                shaderContext->shadeSpan(x, y, span, count);
                 if (aa == 255) {
                     xfer->xfer32(device, span, count, NULL);
                 } else {
@@ -438,7 +440,7 @@
             x += count;
         }
     } else if (fShadeDirectlyIntoDevice ||
-               (fShader->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
+               (shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag)) {
         for (;;) {
             int count = *runs;
             if (count <= 0) {
@@ -448,9 +450,9 @@
             if (aa) {
                 if (aa == 255) {
                     // cool, have the shader draw right into the device
-                    shader->shadeSpan(x, y, device, count);
+                    shaderContext->shadeSpan(x, y, device, count);
                 } else {
-                    shader->shadeSpan(x, y, span, count);
+                    shaderContext->shadeSpan(x, y, span, count);
                     fProc32Blend(device, span, count, aa);
                 }
             }
@@ -467,7 +469,7 @@
             }
             int aa = *antialias;
             if (aa) {
-                fShader->shadeSpan(x, y, span, count);
+                shaderContext->shadeSpan(x, y, span, count);
                 if (aa == 255) {
                     fProc32(device, span, count, 255);
                 } else {
@@ -491,10 +493,11 @@
 
     SkASSERT(mask.fBounds.contains(clip));
 
+    SkShader::Context*  shaderContext = fShaderContext;
     SkBlitMask::RowProc proc = NULL;
     if (!fXfermode) {
         unsigned flags = 0;
-        if (fShader->getFlags() & SkShader::kOpaqueAlpha_Flag) {
+        if (shaderContext->getFlags() & SkShader::kOpaqueAlpha_Flag) {
             flags |= SkBlitMask::kSrcIsOpaque_RowFlag;
         }
         proc = SkBlitMask::RowFactory(SkBitmap::kARGB_8888_Config, mask.fFormat,
@@ -515,14 +518,13 @@
     const uint8_t* maskRow = (const uint8_t*)mask.getAddr(x, y);
     const size_t maskRB = mask.fRowBytes;
 
-    SkShader* shader = fShader;
     SkPMColor* span = fBuffer;
 
     if (fXfermode) {
         SkASSERT(SkMask::kA8_Format == mask.fFormat);
         SkXfermode* xfer = fXfermode;
         do {
-            shader->shadeSpan(x, y, span, width);
+            shaderContext->shadeSpan(x, y, span, width);
             xfer->xfer32((SkPMColor*)dstRow, span, width, maskRow);
             dstRow += dstRB;
             maskRow += maskRB;
@@ -530,7 +532,7 @@
         } while (--height > 0);
     } else {
         do {
-            shader->shadeSpan(x, y, span, width);
+            shaderContext->shadeSpan(x, y, span, width);
             proc(dstRow, maskRow, span, width);
             dstRow += dstRB;
             maskRow += maskRB;
@@ -542,13 +544,13 @@
 void SkARGB32_Shader_Blitter::blitV(int x, int y, int height, SkAlpha alpha) {
     SkASSERT(x >= 0 && y >= 0 && y + height <= fDevice.height());
 
-    uint32_t*   device = fDevice.getAddr32(x, y);
-    size_t      deviceRB = fDevice.rowBytes();
-    SkShader*   shader = fShader;
+    uint32_t*          device = fDevice.getAddr32(x, y);
+    size_t             deviceRB = fDevice.rowBytes();
+    SkShader::Context* shaderContext = fShaderContext;
 
     if (fConstInY) {
         SkPMColor c;
-        fShader->shadeSpan(x, y, &c, 1);
+        shaderContext->shadeSpan(x, y, &c, 1);
 
         if (fShadeDirectlyIntoDevice) {
             if (255 == alpha) {
@@ -582,7 +584,7 @@
 
     if (fShadeDirectlyIntoDevice) {
         void* ctx;
-        SkShader::ShadeProc shadeProc = fShader->asAShadeProc(&ctx);
+        SkShader::Context::ShadeProc shadeProc = shaderContext->asAShadeProc(&ctx);
         if (255 == alpha) {
             if (shadeProc) {
                 do {
@@ -592,7 +594,7 @@
                 } while (--height > 0);
             } else {
                 do {
-                    shader->shadeSpan(x, y, device, 1);
+                    shaderContext->shadeSpan(x, y, device, 1);
                     y += 1;
                     device = (uint32_t*)((char*)device + deviceRB);
                 } while (--height > 0);
@@ -608,7 +610,7 @@
                 } while (--height > 0);
             } else {
                 do {
-                    shader->shadeSpan(x, y, &c, 1);
+                    shaderContext->shadeSpan(x, y, &c, 1);
                     *device = SkFourByteInterp(c, *device, alpha);
                     y += 1;
                     device = (uint32_t*)((char*)device + deviceRB);
@@ -620,7 +622,7 @@
         SkXfermode* xfer = fXfermode;
         if (xfer) {
             do {
-                shader->shadeSpan(x, y, span, 1);
+                shaderContext->shadeSpan(x, y, span, 1);
                 xfer->xfer32(device, span, 1, &alpha);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
@@ -628,7 +630,7 @@
         } else {
             SkBlitRow::Proc32 proc = (255 == alpha) ? fProc32 : fProc32Blend;
             do {
-                shader->shadeSpan(x, y, span, 1);
+                shaderContext->shadeSpan(x, y, span, 1);
                 proc(device, span, 1, alpha);
                 y += 1;
                 device = (uint32_t*)((char*)device + deviceRB);
diff --git a/src/core/SkBlitter_RGB16.cpp b/src/core/SkBlitter_RGB16.cpp
index 21b5a16..e22aac4 100644
--- a/src/core/SkBlitter_RGB16.cpp
+++ b/src/core/SkBlitter_RGB16.cpp
@@ -107,7 +107,8 @@
 
 class SkRGB16_Shader_Blitter : public SkShaderBlitter {
 public:
-    SkRGB16_Shader_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkRGB16_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
+                           SkShader::Context* shaderContext);
     virtual ~SkRGB16_Shader_Blitter();
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha* antialias,
@@ -129,7 +130,8 @@
 // used only if the shader can perform shadSpan16
 class SkRGB16_Shader16_Blitter : public SkRGB16_Shader_Blitter {
 public:
-    SkRGB16_Shader16_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkRGB16_Shader16_Blitter(const SkBitmap& device, const SkPaint& paint,
+                             SkShader::Context* shaderContext);
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha* antialias,
                            const int16_t* runs);
@@ -141,7 +143,8 @@
 
 class SkRGB16_Shader_Xfermode_Blitter : public SkShaderBlitter {
 public:
-    SkRGB16_Shader_Xfermode_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkRGB16_Shader_Xfermode_Blitter(const SkBitmap& device, const SkPaint& paint,
+                                    SkShader::Context* shaderContext);
     virtual ~SkRGB16_Shader_Xfermode_Blitter();
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha* antialias,
@@ -679,8 +682,9 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 SkRGB16_Shader16_Blitter::SkRGB16_Shader16_Blitter(const SkBitmap& device,
-                                                   const SkPaint& paint)
-    : SkRGB16_Shader_Blitter(device, paint) {
+                                                   const SkPaint& paint,
+                                                   SkShader::Context* shaderContext)
+    : SkRGB16_Shader_Blitter(device, paint, shaderContext) {
     SkASSERT(SkShader::CanCallShadeSpan16(fShaderFlags));
 }
 
@@ -688,28 +692,28 @@
     SkASSERT(x + width <= fDevice.width());
 
     uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
-    SkShader*   shader = fShader;
+    SkShader::Context*    shaderContext = fShaderContext;
 
-    int alpha = shader->getSpan16Alpha();
+    int alpha = shaderContext->getSpan16Alpha();
     if (0xFF == alpha) {
-        shader->shadeSpan16(x, y, device, width);
+        shaderContext->shadeSpan16(x, y, device, width);
     } else {
         uint16_t* span16 = (uint16_t*)fBuffer;
-        shader->shadeSpan16(x, y, span16, width);
+        shaderContext->shadeSpan16(x, y, span16, width);
         SkBlendRGB16(span16, device, SkAlpha255To256(alpha), width);
     }
 }
 
 void SkRGB16_Shader16_Blitter::blitRect(int x, int y, int width, int height) {
-    SkShader*   shader = fShader;
-    uint16_t*   dst = fDevice.getAddr16(x, y);
-    size_t      dstRB = fDevice.rowBytes();
-    int         alpha = shader->getSpan16Alpha();
+    SkShader::Context* shaderContext = fShaderContext;
+    uint16_t*          dst = fDevice.getAddr16(x, y);
+    size_t             dstRB = fDevice.rowBytes();
+    int                alpha = shaderContext->getSpan16Alpha();
 
     if (0xFF == alpha) {
         if (fShaderFlags & SkShader::kConstInY16_Flag) {
             // have the shader blit directly into the device the first time
-            shader->shadeSpan16(x, y, dst, width);
+            shaderContext->shadeSpan16(x, y, dst, width);
             // and now just memcpy that line on the subsequent lines
             if (--height > 0) {
                 const uint16_t* orig = dst;
@@ -720,7 +724,7 @@
             }
         } else {    // need to call shadeSpan16 for every line
             do {
-                shader->shadeSpan16(x, y, dst, width);
+                shaderContext->shadeSpan16(x, y, dst, width);
                 y += 1;
                 dst = (uint16_t*)((char*)dst + dstRB);
             } while (--height);
@@ -729,14 +733,14 @@
         int scale = SkAlpha255To256(alpha);
         uint16_t* span16 = (uint16_t*)fBuffer;
         if (fShaderFlags & SkShader::kConstInY16_Flag) {
-            shader->shadeSpan16(x, y, span16, width);
+            shaderContext->shadeSpan16(x, y, span16, width);
             do {
                 SkBlendRGB16(span16, dst, scale, width);
                 dst = (uint16_t*)((char*)dst + dstRB);
             } while (--height);
         } else {
             do {
-                shader->shadeSpan16(x, y, span16, width);
+                shaderContext->shadeSpan16(x, y, span16, width);
                 SkBlendRGB16(span16, dst, scale, width);
                 y += 1;
                 dst = (uint16_t*)((char*)dst + dstRB);
@@ -748,11 +752,11 @@
 void SkRGB16_Shader16_Blitter::blitAntiH(int x, int y,
                                          const SkAlpha* SK_RESTRICT antialias,
                                          const int16_t* SK_RESTRICT runs) {
-    SkShader*   shader = fShader;
+    SkShader::Context*     shaderContext = fShaderContext;
     SkPMColor* SK_RESTRICT span = fBuffer;
-    uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
+    uint16_t* SK_RESTRICT  device = fDevice.getAddr16(x, y);
 
-    int alpha = shader->getSpan16Alpha();
+    int alpha = shaderContext->getSpan16Alpha();
     uint16_t* span16 = (uint16_t*)span;
 
     if (0xFF == alpha) {
@@ -766,9 +770,9 @@
             int aa = *antialias;
             if (aa == 255) {
                 // go direct to the device!
-                shader->shadeSpan16(x, y, device, count);
+                shaderContext->shadeSpan16(x, y, device, count);
             } else if (aa) {
-                shader->shadeSpan16(x, y, span16, count);
+                shaderContext->shadeSpan16(x, y, span16, count);
                 SkBlendRGB16(span16, device, SkAlpha255To256(aa), count);
             }
             device += count;
@@ -787,7 +791,7 @@
 
             int aa = SkAlphaMul(*antialias, alpha);
             if (aa) {
-                shader->shadeSpan16(x, y, span16, count);
+                shaderContext->shadeSpan16(x, y, span16, count);
                 SkBlendRGB16(span16, device, SkAlpha255To256(aa), count);
             }
 
@@ -802,8 +806,9 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 SkRGB16_Shader_Blitter::SkRGB16_Shader_Blitter(const SkBitmap& device,
-                                               const SkPaint& paint)
-: INHERITED(device, paint) {
+                                               const SkPaint& paint,
+                                               SkShader::Context* shaderContext)
+: INHERITED(device, paint, shaderContext) {
     SkASSERT(paint.getXfermode() == NULL);
 
     fBuffer = (SkPMColor*)sk_malloc_throw(device.width() * sizeof(SkPMColor));
@@ -834,20 +839,20 @@
 void SkRGB16_Shader_Blitter::blitH(int x, int y, int width) {
     SkASSERT(x + width <= fDevice.width());
 
-    fShader->shadeSpan(x, y, fBuffer, width);
+    fShaderContext->shadeSpan(x, y, fBuffer, width);
     // shaders take care of global alpha, so we pass 0xFF (should be ignored)
     fOpaqueProc(fDevice.getAddr16(x, y), fBuffer, width, 0xFF, x, y);
 }
 
 void SkRGB16_Shader_Blitter::blitRect(int x, int y, int width, int height) {
-    SkShader*       shader = fShader;
-    SkBlitRow::Proc proc = fOpaqueProc;
-    SkPMColor*      buffer = fBuffer;
-    uint16_t*       dst = fDevice.getAddr16(x, y);
-    size_t          dstRB = fDevice.rowBytes();
+    SkShader::Context* shaderContext = fShaderContext;
+    SkBlitRow::Proc    proc = fOpaqueProc;
+    SkPMColor*         buffer = fBuffer;
+    uint16_t*          dst = fDevice.getAddr16(x, y);
+    size_t             dstRB = fDevice.rowBytes();
 
     if (fShaderFlags & SkShader::kConstInY32_Flag) {
-        shader->shadeSpan(x, y, buffer, width);
+        shaderContext->shadeSpan(x, y, buffer, width);
         do {
             proc(dst, buffer, width, 0xFF, x, y);
             y += 1;
@@ -855,7 +860,7 @@
         } while (--height);
     } else {
         do {
-            shader->shadeSpan(x, y, buffer, width);
+            shaderContext->shadeSpan(x, y, buffer, width);
             proc(dst, buffer, width, 0xFF, x, y);
             y += 1;
             dst = (uint16_t*)((char*)dst + dstRB);
@@ -880,9 +885,9 @@
 void SkRGB16_Shader_Blitter::blitAntiH(int x, int y,
                                        const SkAlpha* SK_RESTRICT antialias,
                                        const int16_t* SK_RESTRICT runs) {
-    SkShader*   shader = fShader;
+    SkShader::Context*     shaderContext = fShaderContext;
     SkPMColor* SK_RESTRICT span = fBuffer;
-    uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
+    uint16_t* SK_RESTRICT  device = fDevice.getAddr16(x, y);
 
     for (;;) {
         int count = *runs;
@@ -901,7 +906,7 @@
         int nonZeroCount = count + count_nonzero_span(runs + count, antialias + count);
 
         SkASSERT(nonZeroCount <= fDevice.width()); // don't overrun fBuffer
-        shader->shadeSpan(x, y, span, nonZeroCount);
+        shaderContext->shadeSpan(x, y, span, nonZeroCount);
 
         SkPMColor* localSpan = span;
         for (;;) {
@@ -928,8 +933,9 @@
 ///////////////////////////////////////////////////////////////////////
 
 SkRGB16_Shader_Xfermode_Blitter::SkRGB16_Shader_Xfermode_Blitter(
-                                const SkBitmap& device, const SkPaint& paint)
-: INHERITED(device, paint) {
+                                const SkBitmap& device, const SkPaint& paint,
+                                SkShader::Context* shaderContext)
+: INHERITED(device, paint, shaderContext) {
     fXfermode = paint.getXfermode();
     SkASSERT(fXfermode);
     fXfermode->ref();
@@ -950,18 +956,18 @@
     uint16_t*   device = fDevice.getAddr16(x, y);
     SkPMColor*  span = fBuffer;
 
-    fShader->shadeSpan(x, y, span, width);
+    fShaderContext->shadeSpan(x, y, span, width);
     fXfermode->xfer16(device, span, width, NULL);
 }
 
 void SkRGB16_Shader_Xfermode_Blitter::blitAntiH(int x, int y,
                                 const SkAlpha* SK_RESTRICT antialias,
                                 const int16_t* SK_RESTRICT runs) {
-    SkShader*   shader = fShader;
-    SkXfermode* mode = fXfermode;
+    SkShader::Context*     shaderContext = fShaderContext;
+    SkXfermode*            mode = fXfermode;
     SkPMColor* SK_RESTRICT span = fBuffer;
-    uint8_t* SK_RESTRICT aaExpand = fAAExpand;
-    uint16_t* SK_RESTRICT device = fDevice.getAddr16(x, y);
+    uint8_t* SK_RESTRICT   aaExpand = fAAExpand;
+    uint16_t* SK_RESTRICT  device = fDevice.getAddr16(x, y);
 
     for (;;) {
         int count = *runs;
@@ -981,7 +987,7 @@
                                                       antialias + count);
 
         SkASSERT(nonZeroCount <= fDevice.width()); // don't overrun fBuffer
-        shader->shadeSpan(x, y, span, nonZeroCount);
+        shaderContext->shadeSpan(x, y, span, nonZeroCount);
 
         x += nonZeroCount;
         SkPMColor* localSpan = span;
@@ -1012,6 +1018,7 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 SkBlitter* SkBlitter_ChooseD565(const SkBitmap& device, const SkPaint& paint,
+        SkShader::Context* shaderContext,
         SkTBlitterAllocator* allocator) {
     SkASSERT(allocator != NULL);
 
@@ -1023,12 +1030,14 @@
     SkASSERT(NULL == mode || NULL != shader);
 
     if (shader) {
+        SkASSERT(shaderContext != NULL);
         if (mode) {
-            blitter = allocator->createT<SkRGB16_Shader_Xfermode_Blitter>(device, paint);
-        } else if (shader->canCallShadeSpan16()) {
-            blitter = allocator->createT<SkRGB16_Shader16_Blitter>(device, paint);
+            blitter = allocator->createT<SkRGB16_Shader_Xfermode_Blitter>(device, paint,
+                                                                          shaderContext);
+        } else if (shaderContext->canCallShadeSpan16()) {
+            blitter = allocator->createT<SkRGB16_Shader16_Blitter>(device, paint, shaderContext);
         } else {
-            blitter = allocator->createT<SkRGB16_Shader_Blitter>(device, paint);
+            blitter = allocator->createT<SkRGB16_Shader_Blitter>(device, paint, shaderContext);
         }
     } else {
         // no shader, no xfermode, (and we always ignore colorfilter)
diff --git a/src/core/SkCanvas.cpp b/src/core/SkCanvas.cpp
index d839971..e3451cd 100644
--- a/src/core/SkCanvas.cpp
+++ b/src/core/SkCanvas.cpp
@@ -91,32 +91,10 @@
 };
 #endif
 
-class AutoCheckNoSetContext {
-public:
-    AutoCheckNoSetContext(const SkPaint& paint) : fPaint(paint) {
-        this->assertNoSetContext(fPaint);
-    }
-    ~AutoCheckNoSetContext() {
-        this->assertNoSetContext(fPaint);
-    }
-
-private:
-    const SkPaint& fPaint;
-
-    void assertNoSetContext(const SkPaint& paint) {
-        SkShader* s = paint.getShader();
-        if (s) {
-            SkASSERT(!s->setContextHasBeenCalled());
-        }
-    }
-};
-
 #define CHECK_LOCKCOUNT_BALANCE(bitmap)  AutoCheckLockCountBalance clcb(bitmap)
-#define CHECK_SHADER_NOSETCONTEXT(paint) AutoCheckNoSetContext     cshsc(paint)
 
 #else
     #define CHECK_LOCKCOUNT_BALANCE(bitmap)
-    #define CHECK_SHADER_NOSETCONTEXT(paint)
 #endif
 
 typedef SkTLazy<SkPaint> SkLazyPaint;
@@ -1940,8 +1918,6 @@
 }
 
 void SkCanvas::internalDrawPaint(const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     LOOPER_BEGIN(paint, SkDrawFilter::kPaint_Type, NULL)
 
     while (iter.next()) {
@@ -1957,8 +1933,6 @@
         return;
     }
 
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect r, storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -1986,8 +1960,6 @@
 }
 
 void SkCanvas::drawRect(const SkRect& r, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2007,8 +1979,6 @@
 }
 
 void SkCanvas::drawOval(const SkRect& oval, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2028,8 +1998,6 @@
 }
 
 void SkCanvas::drawRRect(const SkRRect& rrect, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2060,8 +2028,6 @@
 
 void SkCanvas::onDrawDRRect(const SkRRect& outer, const SkRRect& inner,
                             const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     SkRect storage;
     const SkRect* bounds = NULL;
     if (paint.canComputeFastBounds()) {
@@ -2081,8 +2047,6 @@
 }
 
 void SkCanvas::drawPath(const SkPath& path, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     if (!path.isFinite()) {
         return;
     }
@@ -2358,8 +2322,6 @@
 
 void SkCanvas::onDrawText(const void* text, size_t byteLength, SkScalar x, SkScalar y,
                           const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
 
     while (iter.next()) {
@@ -2374,10 +2336,8 @@
 
 void SkCanvas::onDrawPosText(const void* text, size_t byteLength, const SkPoint pos[],
                              const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-    
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
-    
+
     while (iter.next()) {
         SkDeviceFilteredPaint dfp(iter.fDevice, looper.paint());
         iter.fDevice->drawPosText(iter, text, byteLength, &pos->fX, 0, 2,
@@ -2389,10 +2349,8 @@
 
 void SkCanvas::onDrawPosTextH(const void* text, size_t byteLength, const SkScalar xpos[],
                               SkScalar constY, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-    
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
-    
+
     while (iter.next()) {
         SkDeviceFilteredPaint dfp(iter.fDevice, looper.paint());
         iter.fDevice->drawPosText(iter, text, byteLength, xpos, constY, 1,
@@ -2404,10 +2362,8 @@
 
 void SkCanvas::onDrawTextOnPath(const void* text, size_t byteLength, const SkPath& path,
                                 const SkMatrix* matrix, const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-    
     LOOPER_BEGIN(paint, SkDrawFilter::kText_Type, NULL)
-    
+
     while (iter.next()) {
         iter.fDevice->drawTextOnPath(iter, text, byteLength, path,
                                      matrix, looper.paint());
@@ -2439,8 +2395,6 @@
                             const SkColor colors[], SkXfermode* xmode,
                             const uint16_t indices[], int indexCount,
                             const SkPaint& paint) {
-    CHECK_SHADER_NOSETCONTEXT(paint);
-
     LOOPER_BEGIN(paint, SkDrawFilter::kPath_Type, NULL)
 
     while (iter.next()) {
diff --git a/src/core/SkComposeShader.cpp b/src/core/SkComposeShader.cpp
index f53eedf..77bc46f 100644
--- a/src/core/SkComposeShader.cpp
+++ b/src/core/SkComposeShader.cpp
@@ -45,6 +45,10 @@
     fShaderA->unref();
 }
 
+size_t SkComposeShader::contextSize() const {  // own context plus both sub-shader contexts
+    return sizeof(ComposeShaderContext) + fShaderA->contextSize() + fShaderB->contextSize();
+}
+
 class SkAutoAlphaRestore {
 public:
     SkAutoAlphaRestore(SkPaint* paint, uint8_t newAlpha) {
@@ -69,17 +73,16 @@
     buffer.writeFlattenable(fMode);
 }
 
-/*  We call setContext on our two worker shaders. However, we
-    always let them see opaque alpha, and if the paint really
-    is translucent, then we apply that after the fact.
+/*  We call validContext/createContext on our two worker shaders.
+    However, we always let them see opaque alpha, and if the paint
+    really is translucent, then we apply that after the fact.
 
-    We need to keep the calls to setContext/endContext balanced, since if we
-    return false, our endContext() will not be called.
  */
-bool SkComposeShader::setContext(const SkBitmap& device,
-                                 const SkPaint& paint,
-                                 const SkMatrix& matrix) {
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
+bool SkComposeShader::validContext(const SkBitmap& device,
+                                   const SkPaint& paint,
+                                   const SkMatrix& matrix,
+                                   SkMatrix* totalInverse) const {
+    if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
         return false;
     }
 
@@ -90,38 +93,62 @@
 
     tmpM.setConcat(matrix, this->getLocalMatrix());
 
+    return fShaderA->validContext(device, paint, tmpM) &&
+           fShaderB->validContext(device, paint, tmpM);
+}
+
+SkShader::Context* SkComposeShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                  const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
+    }
+
+    // we preconcat our localMatrix (if any) with the device matrix
+    // before calling our sub-shaders
+
+    SkMatrix tmpM;
+
+    tmpM.setConcat(matrix, this->getLocalMatrix());
+
     SkAutoAlphaRestore  restore(const_cast<SkPaint*>(&paint), 0xFF);
 
-    bool setContextA = fShaderA->setContext(device, paint, tmpM);
-    bool setContextB = fShaderB->setContext(device, paint, tmpM);
-    if (!setContextA || !setContextB) {
-        if (setContextB) {
-            fShaderB->endContext();
-        }
-        else if (setContextA) {
-            fShaderA->endContext();
-        }
-        this->INHERITED::endContext();
-        return false;
-    }
-    return true;
+    char* aStorage = (char*) storage + sizeof(ComposeShaderContext);
+    char* bStorage = aStorage + fShaderA->contextSize();
+
+    SkShader::Context* contextA = fShaderA->createContext(device, paint, tmpM, aStorage);
+    SkShader::Context* contextB = fShaderB->createContext(device, paint, tmpM, bStorage);
+
+    // Both functions must succeed; otherwise validContext should have returned
+    // false.
+    SkASSERT(contextA);
+    SkASSERT(contextB);
+
+    return SkNEW_PLACEMENT_ARGS(storage, ComposeShaderContext,
+                                (*this, device, paint, matrix, contextA, contextB));
 }
 
-void SkComposeShader::endContext() {
-    fShaderB->endContext();
-    fShaderA->endContext();
-    this->INHERITED::endContext();
+SkComposeShader::ComposeShaderContext::ComposeShaderContext(
+        const SkComposeShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix,
+        SkShader::Context* contextA, SkShader::Context* contextB)
+    : INHERITED(shader, device, paint, matrix)
+    , fShaderContextA(contextA)
+    , fShaderContextB(contextB) {}
+
+SkComposeShader::ComposeShaderContext::~ComposeShaderContext() {
+    fShaderContextA->~Context();
+    fShaderContextB->~Context();
 }
 
 // larger is better (fewer times we have to loop), but we shouldn't
 // take up too much stack-space (each element is 4 bytes)
 #define TMP_COLOR_COUNT     64
 
-void SkComposeShader::shadeSpan(int x, int y, SkPMColor result[], int count) {
-    SkShader*   shaderA = fShaderA;
-    SkShader*   shaderB = fShaderB;
-    SkXfermode* mode = fMode;
-    unsigned    scale = SkAlpha255To256(this->getPaintAlpha());
+void SkComposeShader::ComposeShaderContext::shadeSpan(int x, int y, SkPMColor result[], int count) {
+    SkShader::Context* shaderContextA = fShaderContextA;
+    SkShader::Context* shaderContextB = fShaderContextB;
+    SkXfermode*        mode = static_cast<const SkComposeShader&>(fShader).fMode;
+    unsigned           scale = SkAlpha255To256(this->getPaintAlpha());
 
     SkPMColor   tmp[TMP_COLOR_COUNT];
 
@@ -134,8 +161,8 @@
                 n = TMP_COLOR_COUNT;
             }
 
-            shaderA->shadeSpan(x, y, result, n);
-            shaderB->shadeSpan(x, y, tmp, n);
+            shaderContextA->shadeSpan(x, y, result, n);
+            shaderContextB->shadeSpan(x, y, tmp, n);
 
             if (256 == scale) {
                 for (int i = 0; i < n; i++) {
@@ -159,8 +186,8 @@
                 n = TMP_COLOR_COUNT;
             }
 
-            shaderA->shadeSpan(x, y, result, n);
-            shaderB->shadeSpan(x, y, tmp, n);
+            shaderContextA->shadeSpan(x, y, result, n);
+            shaderContextB->shadeSpan(x, y, tmp, n);
             mode->xfer32(result, tmp, n, NULL);
 
             if (256 == scale) {
diff --git a/src/core/SkCoreBlitters.h b/src/core/SkCoreBlitters.h
index 2851840..2d22d38 100644
--- a/src/core/SkCoreBlitters.h
+++ b/src/core/SkCoreBlitters.h
@@ -27,12 +27,29 @@
 
 class SkShaderBlitter : public SkRasterBlitter {
 public:
-    SkShaderBlitter(const SkBitmap& device, const SkPaint& paint);
+    /**
+      *  The storage for shaderContext is owned by the caller, but the object itself is not.
+      *  The blitter only ensures that the storage always holds a live object, but it may
+      *  exchange that object.
+      */
+    SkShaderBlitter(const SkBitmap& device, const SkPaint& paint,
+                    SkShader::Context* shaderContext);
     virtual ~SkShaderBlitter();
 
+    /**
+      *  Creates a new shader context and uses it instead of the old one if successful.
+      *  Will create the context at the same location as the old one (this is safe
+      *  because the shader itself is unchanged).
+      */
+    virtual bool resetShaderContext(const SkBitmap& device, const SkPaint& paint,
+                                    const SkMatrix& matrix) SK_OVERRIDE;
+
+    virtual SkShader::Context* getShaderContext() const SK_OVERRIDE { return fShaderContext; }
+
 protected:
-    uint32_t    fShaderFlags;
-    SkShader*   fShader;
+    uint32_t            fShaderFlags;
+    const SkShader*     fShader;
+    SkShader::Context*  fShaderContext;
 
 private:
     // illegal
@@ -75,7 +92,8 @@
 
 class SkA8_Shader_Blitter : public SkShaderBlitter {
 public:
-    SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkA8_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
+                        SkShader::Context* shaderContext);
     virtual ~SkA8_Shader_Blitter();
     virtual void blitH(int x, int y, int width);
     virtual void blitAntiH(int x, int y, const SkAlpha antialias[], const int16_t runs[]);
@@ -141,7 +159,8 @@
 
 class SkARGB32_Shader_Blitter : public SkShaderBlitter {
 public:
-    SkARGB32_Shader_Blitter(const SkBitmap& device, const SkPaint& paint);
+    SkARGB32_Shader_Blitter(const SkBitmap& device, const SkPaint& paint,
+                            SkShader::Context* shaderContext);
     virtual ~SkARGB32_Shader_Blitter();
     virtual void blitH(int x, int y, int width) SK_OVERRIDE;
     virtual void blitV(int x, int y, int height, SkAlpha alpha) SK_OVERRIDE;
@@ -179,6 +198,7 @@
  */
 
 SkBlitter* SkBlitter_ChooseD565(const SkBitmap& device, const SkPaint& paint,
+                                SkShader::Context* shaderContext,
                                 SkTBlitterAllocator* allocator);
 
 #endif
diff --git a/src/core/SkDraw.cpp b/src/core/SkDraw.cpp
index 7eb0be6..6ddd0d2 100644
--- a/src/core/SkDraw.cpp
+++ b/src/core/SkDraw.cpp
@@ -2354,9 +2354,26 @@
 public:
     SkTriColorShader() {}
 
-    bool setup(const SkPoint pts[], const SkColor colors[], int, int, int);
+    virtual SkShader::Context* createContext(
+            const SkBitmap&, const SkPaint&, const SkMatrix&, void*) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
 
-    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+    class TriColorShaderContext : public SkShader::Context {
+    public:
+        TriColorShaderContext(const SkTriColorShader& shader, const SkBitmap& device,
+                              const SkPaint& paint, const SkMatrix& matrix);
+        virtual ~TriColorShaderContext();
+
+        bool setup(const SkPoint pts[], const SkColor colors[], int, int, int);
+
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+
+    private:
+        SkMatrix    fDstToUnit;
+        SkPMColor   fColors[3];
+
+        typedef SkShader::Context INHERITED;
+    };
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkTriColorShader)
@@ -2365,14 +2382,20 @@
     SkTriColorShader(SkReadBuffer& buffer) : SkShader(buffer) {}
 
 private:
-    SkMatrix    fDstToUnit;
-    SkPMColor   fColors[3];
-
     typedef SkShader INHERITED;
 };
 
-bool SkTriColorShader::setup(const SkPoint pts[], const SkColor colors[],
-                             int index0, int index1, int index2) {
+SkShader::Context* SkTriColorShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                   const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
+    }
+
+    return SkNEW_PLACEMENT_ARGS(storage, TriColorShaderContext, (*this, device, paint, matrix));
+}
+
+bool SkTriColorShader::TriColorShaderContext::setup(const SkPoint pts[], const SkColor colors[],
+                                                    int index0, int index1, int index2) {
 
     fColors[0] = SkPreMultiplyColor(colors[index0]);
     fColors[1] = SkPreMultiplyColor(colors[index1]);
@@ -2407,7 +2430,18 @@
     return SkAlpha255To256(scale);
 }
 
-void SkTriColorShader::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
+
+SkTriColorShader::TriColorShaderContext::TriColorShaderContext(
+        const SkTriColorShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix) {}
+
+SkTriColorShader::TriColorShaderContext::~TriColorShaderContext() {}
+
+size_t SkTriColorShader::contextSize() const {
+    return sizeof(TriColorShaderContext);
+}
+void SkTriColorShader::TriColorShaderContext::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
     SkPoint src;
 
     for (int i = 0; i < count; i++) {
@@ -2492,6 +2526,7 @@
     }
 
     // setup the custom shader (if needed)
+    SkAutoTUnref<SkComposeShader> composeShader;
     if (NULL != colors) {
         if (NULL == textures) {
             // just colors (no texture)
@@ -2504,9 +2539,8 @@
                 xmode = SkXfermode::Create(SkXfermode::kModulate_Mode);
                 releaseMode = true;
             }
-            SkShader* compose = SkNEW_ARGS(SkComposeShader,
-                                           (&triShader, shader, xmode));
-            p.setShader(compose)->unref();
+            composeShader.reset(SkNEW_ARGS(SkComposeShader, (&triShader, shader, xmode)));
+            p.setShader(composeShader);
             if (releaseMode) {
                 xmode->unref();
             }
@@ -2514,9 +2548,7 @@
     }
 
     SkAutoBlitterChoose blitter(*fBitmap, *fMatrix, p);
-    // important that we abort early, as below we may manipulate the shader
-    // and that is only valid if the shader returned true from setContext.
-    // If it returned false, then our blitter will be the NullBlitter.
+    // Abort early if we failed to create a shader context.
     if (blitter->isNullBlitter()) {
         return;
     }
@@ -2532,30 +2564,38 @@
             savedLocalM = shader->getLocalMatrix();
         }
 
-        // setContext has already been called and verified to return true
-        // by the constructor of SkAutoBlitterChoose
-        bool prevContextSuccess = true;
         while (vertProc(&state)) {
             if (NULL != textures) {
                 if (texture_to_matrix(state, vertices, textures, &tempM)) {
                     tempM.postConcat(savedLocalM);
                     shader->setLocalMatrix(tempM);
-                    // Need to recall setContext since we changed the local matrix.
-                    // However, we also need to balance the calls this with a
-                    // call to endContext which requires tracking the result of
-                    // the previous call to setContext.
-                    if (prevContextSuccess) {
-                        shader->endContext();
-                    }
-                    prevContextSuccess = shader->setContext(*fBitmap, p, *fMatrix);
-                    if (!prevContextSuccess) {
+                    if (!blitter->resetShaderContext(*fBitmap, p, *fMatrix)) {
                         continue;
                     }
                 }
             }
             if (NULL != colors) {
-                if (!triShader.setup(vertices, colors,
-                                     state.f0, state.f1, state.f2)) {
+                // Find the context for triShader.
+                SkTriColorShader::TriColorShaderContext* triColorShaderContext;
+
+                SkShader::Context* shaderContext = blitter->getShaderContext();
+                SkASSERT(shaderContext);
+                if (p.getShader() == &triShader) {
+                    triColorShaderContext =
+                            static_cast<SkTriColorShader::TriColorShaderContext*>(shaderContext);
+                } else {
+                    // The shader is a compose shader and triShader is its first shader.
+                    SkASSERT(p.getShader() == composeShader);
+                    SkASSERT(composeShader->getShaderA() == &triShader);
+                    SkComposeShader::ComposeShaderContext* composeShaderContext =
+                            static_cast<SkComposeShader::ComposeShaderContext*>(shaderContext);
+                    SkShader::Context* shaderContextA = composeShaderContext->getShaderContextA();
+                    triColorShaderContext =
+                            static_cast<SkTriColorShader::TriColorShaderContext*>(shaderContextA);
+                }
+
+                if (!triColorShaderContext->setup(vertices, colors,
+                                                  state.f0, state.f1, state.f2)) {
                     continue;
                 }
             }
@@ -2570,13 +2610,6 @@
         if (NULL != shader) {
             shader->setLocalMatrix(savedLocalM);
         }
-
-        // If the final call to setContext fails we must make it suceed so that the
-        // call to endContext in the destructor for SkAutoBlitterChoose is balanced.
-        if (!prevContextSuccess) {
-            prevContextSuccess = shader->setContext(*fBitmap, paint, SkMatrix::I());
-            SkASSERT(prevContextSuccess);
-        }
     } else {
         // no colors[] and no texture
         HairProc hairProc = ChooseHairProc(paint.isAntiAlias());
diff --git a/src/core/SkFilterShader.cpp b/src/core/SkFilterShader.cpp
index 5896191..5c5e8f3 100644
--- a/src/core/SkFilterShader.cpp
+++ b/src/core/SkFilterShader.cpp
@@ -38,9 +38,11 @@
     buffer.writeFlattenable(fFilter);
 }
 
-uint32_t SkFilterShader::getFlags() {
-    uint32_t shaderF = fShader->getFlags();
-    uint32_t filterF = fFilter->getFlags();
+uint32_t SkFilterShader::FilterShaderContext::getFlags() const {
+    const SkFilterShader& filterShader = static_cast<const SkFilterShader&>(fShader);
+
+    uint32_t shaderF = fShaderContext->getFlags();
+    uint32_t filterF = filterShader.fFilter->getFlags();
 
     // if the filter doesn't support 16bit, clear the matching bit in the shader
     if (!(filterF & SkColorFilter::kHasFilter16_Flag)) {
@@ -53,38 +55,62 @@
     return shaderF;
 }
 
-bool SkFilterShader::setContext(const SkBitmap& device,
-                                const SkPaint& paint,
-                                const SkMatrix& matrix) {
-    // we need to keep the setContext/endContext calls balanced. If we return
-    // false, our endContext() will not be called.
-
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
+SkShader::Context* SkFilterShader::createContext(const SkBitmap& device,
+                                                 const SkPaint& paint,
+                                                 const SkMatrix& matrix,
+                                                 void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
     }
-    if (!fShader->setContext(device, paint, matrix)) {
-        this->INHERITED::endContext();
-        return false;
-    }
-    return true;
+
+    char* shaderContextStorage = (char*)storage + sizeof(FilterShaderContext);
+    SkShader::Context* shaderContext = fShader->createContext(device, paint, matrix,
+                                                              shaderContextStorage);
+    SkASSERT(shaderContext);
+
+    return SkNEW_PLACEMENT_ARGS(storage, FilterShaderContext,
+                                (*this, shaderContext, device, paint, matrix));
 }
 
-void SkFilterShader::endContext() {
-    fShader->endContext();
-    this->INHERITED::endContext();
+size_t SkFilterShader::contextSize() const {
+    return sizeof(FilterShaderContext) + fShader->contextSize();
 }
 
-void SkFilterShader::shadeSpan(int x, int y, SkPMColor result[], int count) {
-    fShader->shadeSpan(x, y, result, count);
-    fFilter->filterSpan(result, count, result);
+bool SkFilterShader::validContext(const SkBitmap& device,
+                                  const SkPaint& paint,
+                                  const SkMatrix& matrix,
+                                  SkMatrix* totalInverse) const {
+    return this->INHERITED::validContext(device, paint, matrix, totalInverse) &&
+           fShader->validContext(device, paint, matrix);
 }
 
-void SkFilterShader::shadeSpan16(int x, int y, uint16_t result[], int count) {
-    SkASSERT(fShader->getFlags() & SkShader::kHasSpan16_Flag);
-    SkASSERT(fFilter->getFlags() & SkColorFilter::kHasFilter16_Flag);
+SkFilterShader::FilterShaderContext::FilterShaderContext(const SkFilterShader& filterShader,
+                                                         SkShader::Context* shaderContext,
+                                                         const SkBitmap& device,
+                                                         const SkPaint& paint,
+                                                         const SkMatrix& matrix)
+    : INHERITED(filterShader, device, paint, matrix)
+    , fShaderContext(shaderContext) {}
 
-    fShader->shadeSpan16(x, y, result, count);
-    fFilter->filterSpan16(result, count, result);
+SkFilterShader::FilterShaderContext::~FilterShaderContext() {
+    fShaderContext->~Context();
+}
+
+void SkFilterShader::FilterShaderContext::shadeSpan(int x, int y, SkPMColor result[], int count) {
+    const SkFilterShader& filterShader = static_cast<const SkFilterShader&>(fShader);
+
+    fShaderContext->shadeSpan(x, y, result, count);
+    filterShader.fFilter->filterSpan(result, count, result);
+}
+
+void SkFilterShader::FilterShaderContext::shadeSpan16(int x, int y, uint16_t result[], int count) {
+    const SkFilterShader& filterShader = static_cast<const SkFilterShader&>(fShader);
+
+    SkASSERT(fShaderContext->getFlags() & SkShader::kHasSpan16_Flag);
+    SkASSERT(filterShader.fFilter->getFlags() & SkColorFilter::kHasFilter16_Flag);
+
+    fShaderContext->shadeSpan16(x, y, result, count);
+    filterShader.fFilter->filterSpan16(result, count, result);
 }
 
 #ifndef SK_IGNORE_TO_STRING
diff --git a/src/core/SkFilterShader.h b/src/core/SkFilterShader.h
index 11add0c..4ef4577 100644
--- a/src/core/SkFilterShader.h
+++ b/src/core/SkFilterShader.h
@@ -17,12 +17,29 @@
     SkFilterShader(SkShader* shader, SkColorFilter* filter);
     virtual ~SkFilterShader();
 
-    virtual uint32_t getFlags() SK_OVERRIDE;
-    virtual bool setContext(const SkBitmap&, const SkPaint&,
-                            const SkMatrix&) SK_OVERRIDE;
-    virtual void endContext() SK_OVERRIDE;
-    virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
+    virtual bool validContext(const SkBitmap&, const SkPaint&,
+                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap&, const SkPaint&,
+                                             const SkMatrix&, void* storage) const SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
+
+    class FilterShaderContext : public SkShader::Context {
+    public:
+        // Takes ownership of shaderContext and calls its destructor.
+        FilterShaderContext(const SkFilterShader& filterShader, SkShader::Context* shaderContext,
+                            const SkBitmap& device, const SkPaint& paint, const SkMatrix& matrix);
+        virtual ~FilterShaderContext();
+
+        virtual uint32_t getFlags() const SK_OVERRIDE;
+
+        virtual void shadeSpan(int x, int y, SkPMColor[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t[], int count) SK_OVERRIDE;
+
+    private:
+        SkShader::Context* fShaderContext;
+
+        typedef SkShader::Context INHERITED;
+    };
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkFilterShader)
diff --git a/src/core/SkPictureShader.cpp b/src/core/SkPictureShader.cpp
index bf31285..dc5c90b 100644
--- a/src/core/SkPictureShader.cpp
+++ b/src/core/SkPictureShader.cpp
@@ -49,7 +49,7 @@
     fPicture->flatten(buffer);
 }
 
-bool SkPictureShader::buildBitmapShader(const SkMatrix& matrix) const {
+SkShader* SkPictureShader::refBitmapShader(const SkMatrix& matrix) const {
     SkASSERT(fPicture && fPicture->width() > 0 && fPicture->height() > 0);
 
     SkMatrix m;
@@ -70,17 +70,20 @@
 
     SkISize tileSize = scaledSize.toRound();
     if (tileSize.isEmpty()) {
-        return false;
+        return NULL;
     }
 
     // The actual scale, compensating for rounding.
     SkSize tileScale = SkSize::Make(SkIntToScalar(tileSize.width()) / fPicture->width(),
                                     SkIntToScalar(tileSize.height()) / fPicture->height());
 
-    if (!fCachedShader || tileScale != fCachedTileScale) {
+    SkAutoMutexAcquire ama(fCachedBitmapShaderMutex);
+
+    if (!fCachedBitmapShader || tileScale != fCachedTileScale ||
+        this->getLocalMatrix() != fCachedLocalMatrix) {
         SkBitmap bm;
         if (!bm.allocN32Pixels(tileSize.width(), tileSize.height())) {
-            return false;
+            return NULL;
         }
         bm.eraseColor(SK_ColorTRANSPARENT);
 
@@ -88,66 +91,91 @@
         canvas.scale(tileScale.width(), tileScale.height());
         canvas.drawPicture(*fPicture);
 
-        fCachedShader.reset(CreateBitmapShader(bm, fTmx, fTmy));
+        fCachedBitmapShader.reset(CreateBitmapShader(bm, fTmx, fTmy));
         fCachedTileScale = tileScale;
+        fCachedLocalMatrix = this->getLocalMatrix();
+
+        SkMatrix shaderMatrix = this->getLocalMatrix();
+        shaderMatrix.preScale(1 / tileScale.width(), 1 / tileScale.height());
+        fCachedBitmapShader->setLocalMatrix(shaderMatrix);
     }
 
-    SkMatrix shaderMatrix = this->getLocalMatrix();
-    shaderMatrix.preScale(1 / tileScale.width(), 1 / tileScale.height());
-    fCachedShader->setLocalMatrix(shaderMatrix);
-
-    return true;
+    // Increment the ref counter inside the mutex to ensure the returned pointer is still valid.
+    // Otherwise, the pointer may have been overwritten on a different thread before the object's
+    // ref count was incremented.
+    fCachedBitmapShader.get()->ref();
+    return fCachedBitmapShader;
 }
 
-bool SkPictureShader::setContext(const SkBitmap& device,
-                                 const SkPaint& paint,
-                                 const SkMatrix& matrix) {
-    if (!this->buildBitmapShader(matrix)) {
-        return false;
+SkShader* SkPictureShader::validInternal(const SkBitmap& device, const SkPaint& paint,
+                                         const SkMatrix& matrix, SkMatrix* totalInverse) const {
+    if (!this->INHERITED::validContext(device, paint, matrix, totalInverse)) {
+        return NULL;
     }
 
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
+    SkAutoTUnref<SkShader> bitmapShader(this->refBitmapShader(matrix));
+    if (!bitmapShader || !bitmapShader->validContext(device, paint, matrix)) {
+        return NULL;
     }
 
-    SkASSERT(fCachedShader);
-    if (!fCachedShader->setContext(device, paint, matrix)) {
-        this->INHERITED::endContext();
-        return false;
+    return bitmapShader.detach();
+}
+
+bool SkPictureShader::validContext(const SkBitmap& device, const SkPaint& paint,
+                                   const SkMatrix& matrix, SkMatrix* totalInverse) const {
+    SkAutoTUnref<SkShader> shader(this->validInternal(device, paint, matrix, totalInverse));
+    return shader != NULL;
+}
+
+SkShader::Context* SkPictureShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                  const SkMatrix& matrix, void* storage) const {
+    SkAutoTUnref<SkShader> bitmapShader(this->validInternal(device, paint, matrix, NULL));
+    if (!bitmapShader) {
+        return NULL;
     }
 
-    return true;
+    return SkNEW_PLACEMENT_ARGS(storage, PictureShaderContext,
+                                (*this, device, paint, matrix, bitmapShader.detach()));
 }
 
-void SkPictureShader::endContext() {
-    SkASSERT(fCachedShader);
-    fCachedShader->endContext();
-
-    this->INHERITED::endContext();
+size_t SkPictureShader::contextSize() const {
+    return sizeof(PictureShaderContext);
 }
 
-uint32_t SkPictureShader::getFlags() {
-    if (NULL != fCachedShader) {
-        return fCachedShader->getFlags();
-    }
-    return 0;
+SkPictureShader::PictureShaderContext::PictureShaderContext(
+        const SkPictureShader& shader, const SkBitmap& device,
+        const SkPaint& paint, const SkMatrix& matrix, SkShader* bitmapShader)
+    : INHERITED(shader, device, paint, matrix)
+    , fBitmapShader(bitmapShader)
+{
+    SkASSERT(fBitmapShader);
+    fBitmapShaderContextStorage = sk_malloc_throw(fBitmapShader->contextSize());
+    fBitmapShaderContext = fBitmapShader->createContext(
+            device, paint, matrix, fBitmapShaderContextStorage);
+    SkASSERT(fBitmapShaderContext);
 }
 
-SkShader::ShadeProc SkPictureShader::asAShadeProc(void** ctx) {
-    if (fCachedShader) {
-        return fCachedShader->asAShadeProc(ctx);
-    }
-    return NULL;
+SkPictureShader::PictureShaderContext::~PictureShaderContext() {
+    fBitmapShaderContext->~Context();
+    sk_free(fBitmapShaderContextStorage);
 }
 
-void SkPictureShader::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
-    SkASSERT(fCachedShader);
-    fCachedShader->shadeSpan(x, y, dstC, count);
+uint32_t SkPictureShader::PictureShaderContext::getFlags() const {
+    return fBitmapShaderContext->getFlags();
 }
 
-void SkPictureShader::shadeSpan16(int x, int y, uint16_t dstC[], int count) {
-    SkASSERT(fCachedShader);
-    fCachedShader->shadeSpan16(x, y, dstC, count);
+SkShader::Context::ShadeProc SkPictureShader::PictureShaderContext::asAShadeProc(void** ctx) {
+    return fBitmapShaderContext->asAShadeProc(ctx);
+}
+
+void SkPictureShader::PictureShaderContext::shadeSpan(int x, int y, SkPMColor dstC[], int count) {
+    SkASSERT(fBitmapShaderContext);
+    fBitmapShaderContext->shadeSpan(x, y, dstC, count);
+}
+
+void SkPictureShader::PictureShaderContext::shadeSpan16(int x, int y, uint16_t dstC[], int count) {
+    SkASSERT(fBitmapShaderContext);
+    fBitmapShaderContext->shadeSpan16(x, y, dstC, count);
 }
 
 #ifndef SK_IGNORE_TO_STRING
@@ -168,10 +196,10 @@
 
 #if SK_SUPPORT_GPU
 GrEffectRef* SkPictureShader::asNewEffect(GrContext* context, const SkPaint& paint) const {
-    if (!this->buildBitmapShader(context->getMatrix())) {
+    SkAutoTUnref<SkShader> bitmapShader(this->refBitmapShader(context->getMatrix()));
+    if (!bitmapShader) {
         return NULL;
     }
-    SkASSERT(fCachedShader);
-    return fCachedShader->asNewEffect(context, paint);
+    return bitmapShader->asNewEffect(context, paint);
 }
 #endif
diff --git a/src/core/SkPictureShader.h b/src/core/SkPictureShader.h
index ea74b56..d1be059 100644
--- a/src/core/SkPictureShader.h
+++ b/src/core/SkPictureShader.h
@@ -24,13 +24,33 @@
     static SkPictureShader* Create(SkPicture*, TileMode, TileMode);
     virtual ~SkPictureShader();
 
-    virtual bool setContext(const SkBitmap&, const SkPaint&, const SkMatrix&) SK_OVERRIDE;
-    virtual void endContext() SK_OVERRIDE;
-    virtual uint32_t getFlags() SK_OVERRIDE;
+    virtual bool validContext(const SkBitmap&, const SkPaint&,
+                              const SkMatrix&, SkMatrix* totalInverse = NULL) const SK_OVERRIDE;
+    virtual SkShader::Context* createContext(const SkBitmap& device, const SkPaint& paint,
+                                             const SkMatrix& matrix, void* storage) const
+            SK_OVERRIDE;
+    virtual size_t contextSize() const SK_OVERRIDE;
 
-    virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
-    virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
-    virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+    class PictureShaderContext : public SkShader::Context {
+    public:
+        PictureShaderContext(const SkPictureShader& shader, const SkBitmap& device,
+                             const SkPaint& paint, const SkMatrix& matrix,
+                             SkShader* bitmapShader);
+        virtual ~PictureShaderContext();
+
+        virtual uint32_t getFlags() const SK_OVERRIDE;
+
+        virtual ShadeProc asAShadeProc(void** ctx) SK_OVERRIDE;
+        virtual void shadeSpan(int x, int y, SkPMColor dstC[], int count) SK_OVERRIDE;
+        virtual void shadeSpan16(int x, int y, uint16_t dstC[], int count) SK_OVERRIDE;
+
+    private:
+        SkAutoTUnref<SkShader>  fBitmapShader;
+        SkShader::Context*      fBitmapShaderContext;
+        void*                   fBitmapShaderContextStorage;
+
+        typedef SkShader::Context INHERITED;
+    };
 
     SK_TO_STRING_OVERRIDE()
     SK_DECLARE_PUBLIC_FLATTENABLE_DESERIALIZATION_PROCS(SkPictureShader)
@@ -46,13 +66,18 @@
 private:
     SkPictureShader(SkPicture*, TileMode, TileMode);
 
-    bool buildBitmapShader(const SkMatrix&) const;
+    SkShader* validInternal(const SkBitmap& device, const SkPaint& paint,
+                            const SkMatrix& matrix, SkMatrix* totalInverse) const;
+
+    SkShader* refBitmapShader(const SkMatrix&) const;
 
     SkPicture*  fPicture;
     TileMode    fTmx, fTmy;
 
-    mutable SkAutoTUnref<SkShader>  fCachedShader;
+    mutable SkMutex                 fCachedBitmapShaderMutex;
+    mutable SkAutoTUnref<SkShader>  fCachedBitmapShader;
     mutable SkSize                  fCachedTileScale;
+    mutable SkMatrix                fCachedLocalMatrix;
 
     typedef SkShader INHERITED;
 };
diff --git a/src/core/SkShader.cpp b/src/core/SkShader.cpp
index e337b7d..40e52a0 100644
--- a/src/core/SkShader.cpp
+++ b/src/core/SkShader.cpp
@@ -17,7 +17,6 @@
 
 SkShader::SkShader() {
     fLocalMatrix.reset();
-    SkDEBUGCODE(fInSetContext = false;)
 }
 
 SkShader::SkShader(SkReadBuffer& buffer)
@@ -27,12 +26,9 @@
     } else {
         fLocalMatrix.reset();
     }
-
-    SkDEBUGCODE(fInSetContext = false;)
 }
 
 SkShader::~SkShader() {
-    SkASSERT(!fInSetContext);
 }
 
 void SkShader::flatten(SkWriteBuffer& buffer) const {
@@ -44,39 +40,48 @@
     }
 }
 
-bool SkShader::setContext(const SkBitmap& device,
-                          const SkPaint& paint,
-                          const SkMatrix& matrix) {
-    SkASSERT(!this->setContextHasBeenCalled());
-
+bool SkShader::computeTotalInverse(const SkMatrix& matrix, SkMatrix* totalInverse) const {
     const SkMatrix* m = &matrix;
     SkMatrix        total;
 
-    fPaintAlpha = paint.getAlpha();
     if (this->hasLocalMatrix()) {
         total.setConcat(matrix, this->getLocalMatrix());
         m = &total;
     }
-    if (m->invert(&fTotalInverse)) {
-        fTotalInverseClass = (uint8_t)ComputeMatrixClass(fTotalInverse);
-        SkDEBUGCODE(fInSetContext = true;)
-        return true;
-    }
-    return false;
+
+    return m->invert(totalInverse);
 }
 
-void SkShader::endContext() {
-    SkASSERT(fInSetContext);
-    SkDEBUGCODE(fInSetContext = false;)
+bool SkShader::validContext(const SkBitmap& device,
+                            const SkPaint& paint,
+                            const SkMatrix& matrix,
+                            SkMatrix* totalInverse) const {
+    return this->computeTotalInverse(matrix, totalInverse);
 }
 
-SkShader::ShadeProc SkShader::asAShadeProc(void** ctx) {
+SkShader::Context::Context(const SkShader& shader, const SkBitmap& device,
+                           const SkPaint& paint, const SkMatrix& matrix)
+    : fShader(shader)
+{
+    SkASSERT(fShader.validContext(device, paint, matrix));
+
+    // Because the context parameters must be valid at this point, we know that the matrix is
+    // invertible.
+    SkAssertResult(fShader.computeTotalInverse(matrix, &fTotalInverse));
+    fTotalInverseClass = (uint8_t)ComputeMatrixClass(fTotalInverse);
+
+    fPaintAlpha = paint.getAlpha();
+}
+
+SkShader::Context::~Context() {}
+
+SkShader::Context::ShadeProc SkShader::Context::asAShadeProc(void** ctx) {
     return NULL;
 }
 
 #include "SkColorPriv.h"
 
-void SkShader::shadeSpan16(int x, int y, uint16_t span16[], int count) {
+void SkShader::Context::shadeSpan16(int x, int y, uint16_t span16[], int count) {
     SkASSERT(span16);
     SkASSERT(count > 0);
     SkASSERT(this->canCallShadeSpan16());
@@ -94,7 +99,7 @@
     #define SkU32BitShiftToByteOffset(shift)    ((shift) >> 3)
 #endif
 
-void SkShader::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
+void SkShader::Context::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
     SkASSERT(count > 0);
 
     SkPMColor   colors[kTempColorCount];
@@ -148,7 +153,7 @@
 #endif
 }
 
-SkShader::MatrixClass SkShader::ComputeMatrixClass(const SkMatrix& mat) {
+SkShader::Context::MatrixClass SkShader::Context::ComputeMatrixClass(const SkMatrix& mat) {
     MatrixClass mc = kLinear_MatrixClass;
 
     if (mat.hasPerspective()) {
@@ -163,8 +168,7 @@
 
 //////////////////////////////////////////////////////////////////////////////
 
-SkShader::BitmapType SkShader::asABitmap(SkBitmap*, SkMatrix*,
-                                         TileMode*) const {
+SkShader::BitmapType SkShader::asABitmap(SkBitmap*, SkMatrix*, TileMode*) const {
     return kNone_BitmapType;
 }
 
@@ -199,19 +203,16 @@
 #include "SkColorShader.h"
 #include "SkUtils.h"
 
-SkColorShader::SkColorShader() {
-    fFlags = 0;
-    fInheritColor = true;
+SkColorShader::SkColorShader()
+    : fColor()
+    , fInheritColor(true) {
 }
 
-SkColorShader::SkColorShader(SkColor c) {
-    fFlags = 0;
-    fColor = c;
-    fInheritColor = false;
+SkColorShader::SkColorShader(SkColor c)
+    : fColor(c)
+    , fInheritColor(false) {
 }
 
-SkColorShader::~SkColorShader() {}
-
 bool SkColorShader::isOpaque() const {
     if (fInheritColor) {
         return true; // using paint's alpha
@@ -220,8 +221,6 @@
 }
 
 SkColorShader::SkColorShader(SkReadBuffer& b) : INHERITED(b) {
-    fFlags = 0; // computed in setContext
-
     fInheritColor = b.readBool();
     if (fInheritColor) {
         return;
@@ -238,32 +237,43 @@
     buffer.writeColor(fColor);
 }
 
-uint32_t SkColorShader::getFlags() {
+uint32_t SkColorShader::ColorShaderContext::getFlags() const {
     return fFlags;
 }
 
-uint8_t SkColorShader::getSpan16Alpha() const {
+uint8_t SkColorShader::ColorShaderContext::getSpan16Alpha() const {
     return SkGetPackedA32(fPMColor);
 }
 
-bool SkColorShader::setContext(const SkBitmap& device, const SkPaint& paint,
-                               const SkMatrix& matrix) {
-    if (!this->INHERITED::setContext(device, paint, matrix)) {
-        return false;
+SkShader::Context* SkColorShader::createContext(const SkBitmap& device, const SkPaint& paint,
+                                                const SkMatrix& matrix, void* storage) const {
+    if (!this->validContext(device, paint, matrix)) {
+        return NULL;
     }
 
+    return SkNEW_PLACEMENT_ARGS(storage, ColorShaderContext, (*this, device, paint, matrix));
+}
+
+SkColorShader::ColorShaderContext::ColorShaderContext(const SkColorShader& shader,
+                                                      const SkBitmap& device,
+                                                      const SkPaint& paint,
+                                                      const SkMatrix& matrix)
+    : INHERITED(shader, device, paint, matrix)
+{
     unsigned a;
 
-    if (fInheritColor) {
-        fColor = paint.getColor();
-        a = SkColorGetA(fColor);
+    SkColor color;
+    if (shader.fInheritColor) {
+        color = paint.getColor();
+        a = SkColorGetA(color);
     } else {
-        a = SkAlphaMul(SkColorGetA(fColor), SkAlpha255To256(paint.getAlpha()));
+        color = shader.fColor;
+        a = SkAlphaMul(SkColorGetA(color), SkAlpha255To256(paint.getAlpha()));
     }
 
-    unsigned r = SkColorGetR(fColor);
-    unsigned g = SkColorGetG(fColor);
-    unsigned b = SkColorGetB(fColor);
+    unsigned r = SkColorGetR(color);
+    unsigned g = SkColorGetG(color);
+    unsigned b = SkColorGetB(color);
 
     // we want this before we apply any alpha
     fColor16 = SkPack888ToRGB16(r, g, b);
@@ -282,19 +292,17 @@
             fFlags |= kHasSpan16_Flag;
         }
     }
-
-    return true;
 }
 
-void SkColorShader::shadeSpan(int x, int y, SkPMColor span[], int count) {
+void SkColorShader::ColorShaderContext::shadeSpan(int x, int y, SkPMColor span[], int count) {
     sk_memset32(span, fPMColor, count);
 }
 
-void SkColorShader::shadeSpan16(int x, int y, uint16_t span[], int count) {
+void SkColorShader::ColorShaderContext::shadeSpan16(int x, int y, uint16_t span[], int count) {
     sk_memset16(span, fColor16, count);
 }
 
-void SkColorShader::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
+void SkColorShader::ColorShaderContext::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
     memset(alpha, SkGetPackedA32(fPMColor), count);
 }
 
@@ -334,27 +342,9 @@
 
 ///////////////////////////////////////////////////////////////////////////////
 
+#ifndef SK_IGNORE_TO_STRING
 #include "SkEmptyShader.h"
 
-uint32_t SkEmptyShader::getFlags() { return 0; }
-uint8_t SkEmptyShader::getSpan16Alpha() const { return 0; }
-
-bool SkEmptyShader::setContext(const SkBitmap&, const SkPaint&,
-                               const SkMatrix&) { return false; }
-
-void SkEmptyShader::shadeSpan(int x, int y, SkPMColor span[], int count) {
-    SkDEBUGFAIL("should never get called, since setContext() returned false");
-}
-
-void SkEmptyShader::shadeSpan16(int x, int y, uint16_t span[], int count) {
-    SkDEBUGFAIL("should never get called, since setContext() returned false");
-}
-
-void SkEmptyShader::shadeSpanAlpha(int x, int y, uint8_t alpha[], int count) {
-    SkDEBUGFAIL("should never get called, since setContext() returned false");
-}
-
-#ifndef SK_IGNORE_TO_STRING
 void SkEmptyShader::toString(SkString* str) const {
     str->append("SkEmptyShader: (");
 
diff --git a/src/core/SkSmallAllocator.h b/src/core/SkSmallAllocator.h
index 655008b..8d4b53a 100644
--- a/src/core/SkSmallAllocator.h
+++ b/src/core/SkSmallAllocator.h
@@ -117,10 +117,12 @@
             // but we're not sure we can catch all callers, so handle it but
             // assert false in debug mode.
             SkASSERT(false);
+            rec->fStorageSize = 0;
             rec->fHeapStorage = sk_malloc_throw(storageRequired);
             rec->fObj = static_cast<void*>(rec->fHeapStorage);
         } else {
             // There is space in fStorage.
+            rec->fStorageSize = storageRequired;
             rec->fHeapStorage = NULL;
             SkASSERT(SkIsAlign4(fStorageUsed));
             rec->fObj = static_cast<void*>(fStorage + (fStorageUsed / 4));
@@ -131,11 +133,26 @@
         return rec->fObj;
     }
 
+    /*
+     *  Free the most recently reserved memory without calling the destructor.
+     *  Can be used in a nested way, i.e. after reserving A and then B, calling
+     *  freeLast() once will free B and calling it again will free A.
+     */
+    void freeLast() {
+        SkASSERT(fNumObjects > 0);
+        Rec* rec = &fRecs[fNumObjects - 1];
+        sk_free(rec->fHeapStorage);
+        fStorageUsed -= rec->fStorageSize;
+
+        fNumObjects--;
+    }
+
 private:
     struct Rec {
-        void* fObj;
-        void* fHeapStorage;
-        void  (*fKillProc)(void*);
+        size_t fStorageSize;  // Bytes taken from fStorage; 0 if allocated on heap
+        void*  fObj;
+        void*  fHeapStorage;
+        void   (*fKillProc)(void*);
     };
 
     // Number of bytes used so far.