Refactor Skia shaders handling.

With this change, Skia shaders can easily be applied to any mesh. This change also
supports ComposeShader. For instance, this can be used to blend a gradient and a
bitmap togehter and paint a string of text with the result.

Change-Id: I701c2f9cf7f89b2ff58005e8a1d0d80ccf4a4aea
diff --git a/core/java/android/view/GLES20Canvas.java b/core/java/android/view/GLES20Canvas.java
index 78648ff..21e6793 100644
--- a/core/java/android/view/GLES20Canvas.java
+++ b/core/java/android/view/GLES20Canvas.java
@@ -17,21 +17,17 @@
 package android.view;
 
 import android.graphics.Bitmap;
-import android.graphics.BitmapShader;
 import android.graphics.Canvas;
 import android.graphics.DrawFilter;
-import android.graphics.LinearGradient;
 import android.graphics.Matrix;
 import android.graphics.Paint;
 import android.graphics.Path;
 import android.graphics.Picture;
 import android.graphics.PorterDuff;
-import android.graphics.RadialGradient;
 import android.graphics.Rect;
 import android.graphics.RectF;
 import android.graphics.Region;
 import android.graphics.Shader;
-import android.graphics.SweepGradient;
 import android.graphics.TemporaryBuffer;
 import android.text.GraphicsOperations;
 import android.text.SpannableString;
@@ -593,7 +589,9 @@
         if ((index | count | (index + count) | (text.length - index - count)) < 0) {
             throw new IndexOutOfBoundsException();
         }
+        boolean hasShader = setupShader(paint);
         nDrawText(mRenderer, text, index, count, x, y, paint.mBidiFlags, paint.mNativePaint);
+        if (hasShader) nResetShader(mRenderer);
     }
     
     private native void nDrawText(int renderer, char[] text, int index, int count, float x, float y,
@@ -601,6 +599,7 @@
 
     @Override
     public void drawText(CharSequence text, int start, int end, float x, float y, Paint paint) {
+        boolean hasShader = setupShader(paint);
         if (text instanceof String || text instanceof SpannedString ||
                 text instanceof SpannableString) {
             nDrawText(mRenderer, text.toString(), start, end, x, y, paint.mBidiFlags,
@@ -614,6 +613,7 @@
             nDrawText(mRenderer, buf, 0, end - start, x, y, paint.mBidiFlags, paint.mNativePaint);
             TemporaryBuffer.recycle(buf);
         }
+        if (hasShader) nResetShader(mRenderer);
     }
 
     @Override
@@ -621,7 +621,9 @@
         if ((start | end | (end - start) | (text.length() - end)) < 0) {
             throw new IndexOutOfBoundsException();
         }
+        boolean hasShader = setupShader(paint);
         nDrawText(mRenderer, text, start, end, x, y, paint.mBidiFlags, paint.mNativePaint);
+        if (hasShader) nResetShader(mRenderer);
     }
 
     private native void nDrawText(int renderer, String text, int start, int end, float x, float y,
@@ -629,7 +631,9 @@
 
     @Override
     public void drawText(String text, float x, float y, Paint paint) {
+        boolean hasShader = setupShader(paint);
         nDrawText(mRenderer, text, 0, text.length(), x, y, paint.mBidiFlags, paint.mNativePaint);
+        if (hasShader) nResetShader(mRenderer);
     }
 
     @Override
@@ -665,28 +669,12 @@
     private boolean setupShader(Paint paint) {
         final Shader shader = paint.getShader();
         if (shader != null) {
-            if (shader instanceof BitmapShader) {
-                final BitmapShader bs = (BitmapShader) shader;
-                nSetupBitmapShader(mRenderer, bs.native_instance, bs.mBitmap.mNativeBitmap,
-                        bs.mTileX, bs.mTileY, bs.mLocalMatrix);
-                return true;
-            } else if (shader instanceof LinearGradient) {
-                final LinearGradient ls = (LinearGradient) shader;
-                nSetupLinearShader(mRenderer, ls.native_instance, ls.bounds, ls.colors,
-                        ls.positions, ls.count, ls.tileMode, ls.mLocalMatrix);
-                return true;
-            } else if (shader instanceof RadialGradient) {
-                // TODO: Implement
-            } else if (shader instanceof SweepGradient) {
-                // TODO: Implement
-            }
+            nSetupShader(mRenderer, shader.native_shader);
+            return true;
         }
         return false;
     }
 
-    private native void nSetupLinearShader(int renderer, int shader, int bounds,
-            int colors, int positions, int count, int tileMode, int localMatrix);
-    private native void nSetupBitmapShader(int renderer, int shader, int bitmap,
-            int tileX, int tileY, int matrix);
+    private native void nSetupShader(int renderer, int shader);
     private native void nResetShader(int renderer);
 }
diff --git a/core/jni/android/graphics/Shader.cpp b/core/jni/android/graphics/Shader.cpp
index eb600c4..34b4ab5 100644
--- a/core/jni/android/graphics/Shader.cpp
+++ b/core/jni/android/graphics/Shader.cpp
@@ -8,12 +8,14 @@
 #include "SkTemplates.h"
 #include "SkXfermode.h"
 
+#include <SkiaShader.h>
+
+using namespace android::uirenderer;
+
 static struct {
     jclass clazz;
-    jfieldID bounds;
-    jfieldID colors;
-    jfieldID positions;
-} gLinearGradientClassInfo;
+    jfieldID shader;
+} gShaderClassInfo;
 
 static void ThrowIAE_IfNull(JNIEnv* env, void* ptr) {
     if (NULL == ptr) {
@@ -48,8 +50,9 @@
 
 ///////////////////////////////////////////////////////////////////////////////////////////////
 
-static void Shader_destructor(JNIEnv* env, jobject, SkShader* shader)
+static void Shader_destructor(JNIEnv* env, jobject o, SkShader* shader, SkiaShader* skiaShader)
 {
+    delete skiaShader;
     shader->safeUnref();
 }
 
@@ -58,7 +61,8 @@
     return shader ? shader->getLocalMatrix(matrix) : false;
 }
  
-static void Shader_setLocalMatrix(JNIEnv* env, jobject, SkShader* shader, const SkMatrix* matrix)
+static void Shader_setLocalMatrix(JNIEnv* env, jobject o, SkShader* shader, SkiaShader* skiaShader,
+        const SkMatrix* matrix)
 {
     if (shader) {
         if (NULL == matrix) {
@@ -67,30 +71,33 @@
         else {
             shader->setLocalMatrix(*matrix);
         }
+        skiaShader->setMatrix(const_cast<SkMatrix*>(matrix));
     }
 }
 
 ///////////////////////////////////////////////////////////////////////////////////////////////
 
-static SkShader* BitmapShader_constructor(JNIEnv* env, jobject, const SkBitmap* bitmap,
+static SkShader* BitmapShader_constructor(JNIEnv* env, jobject o, const SkBitmap* bitmap,
                                           int tileModeX, int tileModeY)
 {
     SkShader* s = SkShader::CreateBitmapShader(*bitmap,
                                         (SkShader::TileMode)tileModeX,
                                         (SkShader::TileMode)tileModeY);
+
     ThrowIAE_IfNull(env, s);
     return s;
 }
+
+static SkiaShader* BitmapShader_postConstructor(JNIEnv* env, jobject o, SkShader* shader,
+        SkBitmap* bitmap, int tileModeX, int tileModeY) {
+    SkiaShader* skiaShader = new SkiaBitmapShader(bitmap, shader,
+            static_cast<SkShader::TileMode>(tileModeX), static_cast<SkShader::TileMode>(tileModeY),
+            NULL, (shader->getFlags() & SkShader::kOpaqueAlpha_Flag) == 0);
+    return skiaShader;
+}
     
 ///////////////////////////////////////////////////////////////////////////////////////////////
 
-static void LinearGradient_destructor(JNIEnv* env, jobject o, SkShader* shader)
-{
-    delete reinterpret_cast<jfloat*>(env->GetIntField(o, gLinearGradientClassInfo.bounds));
-    delete reinterpret_cast<jint*>(env->GetIntField(o, gLinearGradientClassInfo.colors));
-    delete reinterpret_cast<jfloat*>(env->GetIntField(o, gLinearGradientClassInfo.positions));
-}
-
 static SkShader* LinearGradient_create1(JNIEnv* env, jobject o,
                                         float x0, float y0, float x1, float y1,
                                         jintArray colorArray, jfloatArray posArray, int tileMode)
@@ -99,44 +106,83 @@
     pts[0].set(SkFloatToScalar(x0), SkFloatToScalar(y0));
     pts[1].set(SkFloatToScalar(x1), SkFloatToScalar(y1));
 
-    size_t      count = env->GetArrayLength(colorArray);
+    size_t count = env->GetArrayLength(colorArray);
     const jint* colorValues = env->GetIntArrayElements(colorArray, NULL);
 
     SkAutoSTMalloc<8, SkScalar> storage(posArray ? count : 0);
     SkScalar*                   pos = NULL;
 
-    jfloat* storedBounds = new jfloat[4];
-    storedBounds[0] = x0; storedBounds[1] = y0;
-    storedBounds[2] = x1; storedBounds[3] = y1;
-    jfloat* storedPositions = new jfloat[count];
-    jint* storedColors = new jint[count];
-    memcpy(storedColors, colorValues, count);
-
     if (posArray) {
         AutoJavaFloatArray autoPos(env, posArray, count);
         const float* posValues = autoPos.ptr();
         pos = (SkScalar*)storage.get();
         for (size_t i = 0; i < count; i++) {
             pos[i] = SkFloatToScalar(posValues[i]);
+        }
+    }
+    
+    SkShader* shader = SkGradientShader::CreateLinear(pts,
+                                reinterpret_cast<const SkColor*>(colorValues),
+                                pos, count,
+                                static_cast<SkShader::TileMode>(tileMode));
+
+    env->ReleaseIntArrayElements(colorArray, const_cast<jint*>(colorValues), JNI_ABORT);
+    ThrowIAE_IfNull(env, shader);
+    return shader;
+}
+
+static SkiaShader* LinearGradient_postCreate1(JNIEnv* env, jobject o, SkShader* shader,
+        float x0, float y0, float x1, float y1, jintArray colorArray,
+        jfloatArray posArray, int tileMode) {
+
+    size_t count = env->GetArrayLength(colorArray);
+    const jint* colorValues = env->GetIntArrayElements(colorArray, NULL);
+
+    jfloat* storedBounds = new jfloat[4];
+    storedBounds[0] = x0; storedBounds[1] = y0;
+    storedBounds[2] = x1; storedBounds[3] = y1;
+    jfloat* storedPositions = new jfloat[count];
+    uint32_t* storedColors = new uint32_t[count];
+    memcpy(storedColors, colorValues, count);
+
+    if (posArray) {
+        AutoJavaFloatArray autoPos(env, posArray, count);
+        const float* posValues = autoPos.ptr();
+        for (size_t i = 0; i < count; i++) {
             storedPositions[i] = posValues[i];
         }
     } else {
         storedPositions[0] = 0.0f;
         storedPositions[1] = 1.0f;
     }
-    
-    env->SetIntField(o, gLinearGradientClassInfo.bounds, reinterpret_cast<jint>(storedBounds));
-    env->SetIntField(o, gLinearGradientClassInfo.colors, reinterpret_cast<jint>(storedColors));
-    env->SetIntField(o, gLinearGradientClassInfo.positions, reinterpret_cast<jint>(storedPositions));
-    
-    SkShader* shader = SkGradientShader::CreateLinear(pts,
-                                reinterpret_cast<const SkColor*>(colorValues),
-                                pos, count,
-                                static_cast<SkShader::TileMode>(tileMode));
-    env->ReleaseIntArrayElements(colorArray, const_cast<jint*>(colorValues),
-                                 JNI_ABORT);
-    ThrowIAE_IfNull(env, shader);
-    return shader;
+
+    SkiaShader* skiaShader = new SkiaLinearGradientShader(storedBounds, storedColors,
+            storedPositions, count, shader, static_cast<SkShader::TileMode>(tileMode), NULL,
+            (shader->getFlags() & SkShader::kOpaqueAlpha_Flag) == 0);
+
+    env->ReleaseIntArrayElements(colorArray, const_cast<jint*>(colorValues), JNI_ABORT);
+    return skiaShader;
+}
+
+static SkiaShader* LinearGradient_postCreate2(JNIEnv* env, jobject o, SkShader* shader,
+        float x0, float y0, float x1, float y1, int color0, int color1, int tileMode) {
+    float* storedBounds = new float[4];
+    storedBounds[0] = x0; storedBounds[1] = y0;
+    storedBounds[2] = x1; storedBounds[3] = y1;
+
+    float* storedPositions = new float[2];
+    storedPositions[0] = 0.0f;
+    storedPositions[1] = 1.0f;
+
+    uint32_t* storedColors = new uint32_t[2];
+    storedColors[0] = color0;
+    storedColors[1] = color1;
+
+    SkiaShader* skiaShader = new SkiaLinearGradientShader(storedBounds, storedColors,
+            storedPositions, 2, shader, static_cast<SkShader::TileMode>(tileMode), NULL,
+            (shader->getFlags() & SkShader::kOpaqueAlpha_Flag) == 0);
+
+    return skiaShader;
 }
 
 static SkShader* LinearGradient_create2(JNIEnv* env, jobject o,
@@ -151,23 +197,8 @@
     colors[0] = color0;
     colors[1] = color1;
     
-    float* storedBounds = new float[4];
-    storedBounds[0] = x0; storedBounds[1] = y0;
-    storedBounds[2] = x1; storedBounds[3] = y1;
-    
-    float* storedPositions = new float[2];
-    storedPositions[0] = 0.0f;
-    storedPositions[1] = 1.0f;
-    
-    uint32_t* storedColors = new uint32_t[2];
-    storedColors[0] = color0;
-    storedColors[1] = color1;
-    
-    env->SetIntField(o, gLinearGradientClassInfo.bounds, reinterpret_cast<jint>(storedBounds));
-    env->SetIntField(o, gLinearGradientClassInfo.colors, reinterpret_cast<jint>(storedColors));
-    env->SetIntField(o, gLinearGradientClassInfo.positions, reinterpret_cast<jint>(storedPositions));
-
     SkShader* s = SkGradientShader::CreateLinear(pts, colors, NULL, 2, (SkShader::TileMode)tileMode);
+
     ThrowIAE_IfNull(env, s);
     return s;
 }
@@ -268,18 +299,38 @@
 
 ///////////////////////////////////////////////////////////////////////////////////////////////
 
-static SkShader* ComposeShader_create1(JNIEnv* env, jobject,
-                                       SkShader* shaderA, SkShader* shaderB, SkXfermode* mode)
+static SkShader* ComposeShader_create1(JNIEnv* env, jobject o,
+        SkShader* shaderA, SkShader* shaderB, SkXfermode* mode)
 {
     return new SkComposeShader(shaderA, shaderB, mode);
 }
 
-static SkShader* ComposeShader_create2(JNIEnv* env, jobject,
-                                       SkShader* shaderA, SkShader* shaderB, SkPorterDuff::Mode mode)
+static SkShader* ComposeShader_create2(JNIEnv* env, jobject o,
+        SkShader* shaderA, SkShader* shaderB, SkPorterDuff::Mode porterDuffMode)
 {
-    SkAutoUnref au(SkPorterDuff::CreateXfermode(mode));
+    SkAutoUnref au(SkPorterDuff::CreateXfermode(porterDuffMode));
+    SkXfermode* mode = (SkXfermode*) au.get();
+    return new SkComposeShader(shaderA, shaderB, mode);
+}
 
-    return new SkComposeShader(shaderA, shaderB, (SkXfermode*)au.get());
+static SkiaShader* ComposeShader_postCreate2(JNIEnv* env, jobject o, SkShader* shader,
+        SkiaShader* shaderA, SkiaShader* shaderB, SkPorterDuff::Mode porterDuffMode) {
+    SkAutoUnref au(SkPorterDuff::CreateXfermode(porterDuffMode));
+    SkXfermode* mode = (SkXfermode*) au.get();
+    SkXfermode::Mode skiaMode;
+    if (!SkXfermode::IsMode(mode, &skiaMode)) {
+        skiaMode = SkXfermode::kSrcOver_Mode;
+    }
+    return new SkiaComposeShader(shaderA, shaderB, skiaMode, shader);
+}
+
+static SkiaShader* ComposeShader_postCreate1(JNIEnv* env, jobject o, SkShader* shader,
+        SkiaShader* shaderA, SkiaShader* shaderB, SkXfermode* mode) {
+    SkXfermode::Mode skiaMode;
+    if (!SkXfermode::IsMode(mode, &skiaMode)) {
+        skiaMode = SkXfermode::kSrcOver_Mode;
+    }
+    return new SkiaComposeShader(shaderA, shaderB, skiaMode, shader);
 }
 
 ///////////////////////////////////////////////////////////////////////////////////////////////
@@ -290,19 +341,21 @@
 };
 
 static JNINativeMethod gShaderMethods[] = {
-    { "nativeDestructor",        "(I)V",     (void*)Shader_destructor        },
+    { "nativeDestructor",        "(II)V",    (void*)Shader_destructor        },
     { "nativeGetLocalMatrix",    "(II)Z",    (void*)Shader_getLocalMatrix    },
-    { "nativeSetLocalMatrix",    "(II)V",    (void*)Shader_setLocalMatrix    }
+    { "nativeSetLocalMatrix",    "(III)V",   (void*)Shader_setLocalMatrix    }
 };
 
 static JNINativeMethod gBitmapShaderMethods[] = {
-    { "nativeCreate",   "(III)I",  (void*)BitmapShader_constructor }
+    { "nativeCreate",     "(III)I",  (void*)BitmapShader_constructor },
+    { "nativePostCreate", "(IIII)I", (void*)BitmapShader_postConstructor }
 };
 
 static JNINativeMethod gLinearGradientMethods[] = {
-    { "nativeDestructor", "(I)V",         (void*)LinearGradient_destructor },
-    { "nativeCreate1",    "(FFFF[I[FI)I", (void*)LinearGradient_create1    },
-    { "nativeCreate2",    "(FFFFIII)I",   (void*)LinearGradient_create2    }
+    { "nativeCreate1",     "(FFFF[I[FI)I",  (void*)LinearGradient_create1     },
+    { "nativeCreate2",     "(FFFFIII)I",    (void*)LinearGradient_create2     },
+    { "nativePostCreate1", "(IFFFF[I[FI)I", (void*)LinearGradient_postCreate1 },
+    { "nativePostCreate2", "(IFFFFIII)I",   (void*)LinearGradient_postCreate2 }
 };
 
 static JNINativeMethod gRadialGradientMethods[] = {
@@ -316,8 +369,10 @@
 };
 
 static JNINativeMethod gComposeShaderMethods[] = {
-    {"nativeCreate1",  "(III)I",    (void*)ComposeShader_create1 },
-    {"nativeCreate2",  "(III)I",    (void*)ComposeShader_create2 }
+    {"nativeCreate1",      "(III)I",   (void*)ComposeShader_create1     },
+    {"nativeCreate2",      "(III)I",   (void*)ComposeShader_create2     },
+    {"nativePostCreate1",  "(IIII)I",  (void*)ComposeShader_postCreate1 },
+    {"nativePostCreate2",  "(IIII)I",  (void*)ComposeShader_postCreate2 }
 };
 
 #include <android_runtime/AndroidRuntime.h>
@@ -325,15 +380,6 @@
 #define REG(env, name, array)                                                                       \
     result = android::AndroidRuntime::registerNativeMethods(env, name, array, SK_ARRAY_COUNT(array));  \
     if (result < 0) return result
-    
-#define FIND_CLASS(var, className) \
-        var = env->FindClass(className); \
-        LOG_FATAL_IF(! var, "Unable to find class " className); \
-        var = jclass(env->NewGlobalRef(var));
-
-#define GET_FIELD_ID(var, clazz, fieldName, fieldType) \
-        var = env->GetFieldID(clazz, fieldName, fieldType); \
-        LOG_FATAL_IF(! var, "Unable to find field " fieldName);
 
 int register_android_graphics_Shader(JNIEnv* env);
 int register_android_graphics_Shader(JNIEnv* env)
@@ -348,11 +394,6 @@
     REG(env, "android/graphics/SweepGradient", gSweepGradientMethods);
     REG(env, "android/graphics/ComposeShader", gComposeShaderMethods);
     
-    FIND_CLASS(gLinearGradientClassInfo.clazz, "android/graphics/LinearGradient");
-    GET_FIELD_ID(gLinearGradientClassInfo.bounds, gLinearGradientClassInfo.clazz, "bounds", "I");
-    GET_FIELD_ID(gLinearGradientClassInfo.colors, gLinearGradientClassInfo.clazz, "colors", "I");
-    GET_FIELD_ID(gLinearGradientClassInfo.positions, gLinearGradientClassInfo.clazz, "positions", "I");
-    
     return result;
 }
 
diff --git a/core/jni/android_view_GLES20Canvas.cpp b/core/jni/android_view_GLES20Canvas.cpp
index 1e4a66d..ece9636 100644
--- a/core/jni/android_view_GLES20Canvas.cpp
+++ b/core/jni/android_view_GLES20Canvas.cpp
@@ -30,6 +30,7 @@
 #include <SkXfermode.h>
 
 #include <OpenGLRenderer.h>
+#include <SkiaShader.h>
 #include <Rect.h>
 #include <ui/Rect.h>
 
@@ -235,18 +236,9 @@
     renderer->resetShader();
 }
 
-static void android_view_GLES20Canvas_setupBitmapShader(JNIEnv* env, jobject canvas,
-        OpenGLRenderer* renderer, SkShader* shader, SkBitmap* bitmap,
-        SkShader::TileMode tileX, SkShader::TileMode tileY, SkMatrix* matrix) {
-    renderer->setupBitmapShader(bitmap, tileX, tileY, matrix,
-            (shader->getFlags() & SkShader::kOpaqueAlpha_Flag) == 0);
-}
-
-static void android_view_GLES20Canvas_setupLinearShader(JNIEnv* env, jobject canvas,
-        OpenGLRenderer* renderer, SkShader* shader, float* bounds, uint32_t* colors,
-        float* positions, int count, SkShader::TileMode tileMode, SkMatrix* matrix) {
-    renderer->setupLinearGradientShader(shader, bounds, colors, positions, count,
-            tileMode, matrix, (shader->getFlags() & SkShader::kOpaqueAlpha_Flag) == 0);
+static void android_view_GLES20Canvas_setupShader(JNIEnv* env, jobject canvas,
+        OpenGLRenderer* renderer, SkiaShader* shader) {
+    renderer->setupShader(shader);
 }
 
 // ----------------------------------------------------------------------------
@@ -320,8 +312,7 @@
     {   "nDrawRect",          "(IFFFFI)V",       (void*) android_view_GLES20Canvas_drawRect },
 
     {   "nResetShader",       "(I)V",            (void*) android_view_GLES20Canvas_resetShader },
-    {   "nSetupBitmapShader", "(IIIIII)V",       (void*) android_view_GLES20Canvas_setupBitmapShader },
-    {   "nSetupLinearShader", "(IIIIIIII)V",     (void*) android_view_GLES20Canvas_setupLinearShader },
+    {   "nSetupShader",       "(II)V",           (void*) android_view_GLES20Canvas_setupShader },
 
     {   "nDrawText",          "(I[CIIFFII)V",    (void*) android_view_GLES20Canvas_drawTextArray },
     {   "nDrawText",          "(ILjava/lang/String;IIFFII)V",
diff --git a/graphics/java/android/graphics/BitmapShader.java b/graphics/java/android/graphics/BitmapShader.java
index 37b40e7..4c92942 100644
--- a/graphics/java/android/graphics/BitmapShader.java
+++ b/graphics/java/android/graphics/BitmapShader.java
@@ -22,21 +22,6 @@
  */
 public class BitmapShader extends Shader {
     /**
-     * We hold on just for the GC, since our native counterpart is using it.
-     * 
-     * @hide 
-     */
-    public Bitmap mBitmap;
-    /**
-     * @hide 
-     */
-    public int mTileX;
-    /**
-     * @hide 
-     */
-    public int mTileY;
-
-    /**
      * Call this to create a new shader that will draw with a bitmap.
      *
      * @param bitmap            The bitmap to use inside the shader
@@ -44,12 +29,13 @@
      * @param tileY             The tiling mode for y to draw the bitmap in.
      */
     public BitmapShader(Bitmap bitmap, TileMode tileX, TileMode tileY) {
-        mBitmap = bitmap;
-        mTileX = tileX.nativeInt;
-        mTileY = tileY.nativeInt;
-        native_instance = nativeCreate(bitmap.ni(), mTileX, mTileY);
+        final int b = bitmap.ni();
+        native_instance = nativeCreate(b, tileX.nativeInt, tileY.nativeInt);
+        native_shader = nativePostCreate(native_instance, b, tileX.nativeInt, tileY.nativeInt);
     }
 
     private static native int nativeCreate(int native_bitmap, int shaderTileModeX,
-            int shaderTileModeY);    
+            int shaderTileModeY);
+    private static native int nativePostCreate(int native_shader, int native_bitmap,
+            int shaderTileModeX, int shaderTileModeY);
 }
diff --git a/graphics/java/android/graphics/ComposeShader.java b/graphics/java/android/graphics/ComposeShader.java
index a06d30b..9b57ea4 100644
--- a/graphics/java/android/graphics/ComposeShader.java
+++ b/graphics/java/android/graphics/ComposeShader.java
@@ -30,7 +30,9 @@
     */
     public ComposeShader(Shader shaderA, Shader shaderB, Xfermode mode) {
         native_instance = nativeCreate1(shaderA.native_instance, shaderB.native_instance,
-                                        (mode != null) ? mode.native_instance : 0);
+                (mode != null) ? mode.native_instance : 0);
+        native_shader = nativePostCreate1(native_instance, shaderA.native_shader,
+                shaderB.native_shader, (mode != null) ? mode.native_instance : 0);
     }
 
     /** Create a new compose shader, given shaders A, B, and a combining PorterDuff mode.
@@ -42,10 +44,17 @@
     */
     public ComposeShader(Shader shaderA, Shader shaderB, PorterDuff.Mode mode) {
         native_instance = nativeCreate2(shaderA.native_instance, shaderB.native_instance,
-                                        mode.nativeInt);
+                mode.nativeInt);
+        native_shader = nativePostCreate2(native_instance, shaderA.native_shader,
+                shaderB.native_shader, mode.nativeInt);
     }
 
-    private static native int nativeCreate1(int native_shaderA, int native_shaderB, int native_mode);
-    private static native int nativeCreate2(int native_shaderA, int native_shaderB, int porterDuffMode);
+    private static native int nativeCreate1(int native_shaderA, int native_shaderB,
+            int native_mode);
+    private static native int nativeCreate2(int native_shaderA, int native_shaderB,
+            int porterDuffMode);
+    private static native int nativePostCreate1(int native_shader, int native_skiaShaderA,
+            int native_skiaShaderB, int native_mode);
+    private static native int nativePostCreate2(int native_shader, int native_skiaShaderA,
+            int native_skiaShaderB, int porterDuffMode);
 }
-
diff --git a/graphics/java/android/graphics/LinearGradient.java b/graphics/java/android/graphics/LinearGradient.java
index fd57591..82ed199 100644
--- a/graphics/java/android/graphics/LinearGradient.java
+++ b/graphics/java/android/graphics/LinearGradient.java
@@ -17,28 +17,6 @@
 package android.graphics;
 
 public class LinearGradient extends Shader {
-    /**
-     * These fields are manipulated by the JNI layer, don't touch!
-     * @hide
-     */
-    public int bounds;
-    /**
-     * @hide
-     */
-    public int colors;
-    /**
-     * @hide
-     */
-    public int positions;
-    /**
-     * @hide
-     */
-    public int count;
-    /**
-     * @hide
-     */
-    public int tileMode;
-
 	/**	Create a shader that draws a linear gradient along a line.
         @param x0           The x-coordinate for the start of the gradient line
         @param y0           The y-coordinate for the start of the gradient line
@@ -59,8 +37,8 @@
             throw new IllegalArgumentException("color and position arrays must be of equal length");
         }
         native_instance = nativeCreate1(x0, y0, x1, y1, colors, positions, tile.nativeInt);
-        count = colors.length;
-        tileMode = tile.nativeInt;
+        native_shader = nativePostCreate1(native_instance, x0, y0, x1, y1, colors, positions,
+                tile.nativeInt);
     }
 
 	/**	Create a shader that draws a linear gradient along a line.
@@ -75,21 +53,16 @@
 	public LinearGradient(float x0, float y0, float x1, float y1,
                           int color0, int color1, TileMode tile) {
         native_instance = nativeCreate2(x0, y0, x1, y1, color0, color1, tile.nativeInt);
-        count = 2;
-        tileMode = tile.nativeInt;
-    }
-    
-    protected void finalize() throws Throwable {
-        try {
-            super.finalize();
-        } finally {
-            nativeDestructor(native_instance);
-        }
+        native_shader = nativePostCreate2(native_instance, x0, y0, x1, y1, color0, color1,
+                tile.nativeInt);
     }
 
-    private native void nativeDestructor(int native_shader);
 	private native int nativeCreate1(float x0, float y0, float x1, float y1,
             int colors[], float positions[], int tileMode);
 	private native int nativeCreate2(float x0, float y0, float x1, float y1,
             int color0, int color1, int tileMode);
+    private native int nativePostCreate1(int native_shader, float x0, float y0, float x1, float y1,
+            int colors[], float positions[], int tileMode);
+    private native int nativePostCreate2(int native_shader, float x0, float y0, float x1, float y1,
+            int color0, int color1, int tileMode);
 }
diff --git a/graphics/java/android/graphics/Shader.java b/graphics/java/android/graphics/Shader.java
index 86c71e3..b397662 100644
--- a/graphics/java/android/graphics/Shader.java
+++ b/graphics/java/android/graphics/Shader.java
@@ -24,18 +24,15 @@
  */
 public class Shader {
     /**
-     * Local matrix native instance.
-     * 
-     * @hide
-     */
-    public int mLocalMatrix;
-
-    /**
      * This is set by subclasses, but don't make it public.
      * 
      * @hide 
      */
     public int native_instance;
+    /**
+     * @hide
+     */
+    public int native_shader;
 
     public enum TileMode {
         /**
@@ -74,21 +71,20 @@
      * @param localM The shader's new local matrix, or null to specify identity
      */
     public void setLocalMatrix(Matrix localM) {
-        mLocalMatrix = localM != null ? localM.native_instance : 0;
-        nativeSetLocalMatrix(native_instance, mLocalMatrix);
+        nativeSetLocalMatrix(native_instance, native_shader, localM.native_instance);
     }
 
     protected void finalize() throws Throwable {
         try {
             super.finalize();
         } finally {
-            nativeDestructor(native_instance);
+            nativeDestructor(native_instance, native_shader);
         }
     }
 
-    private static native void nativeDestructor(int native_shader);
+    private static native void nativeDestructor(int native_shader, int native_skiaShader);
     private static native boolean nativeGetLocalMatrix(int native_shader,
-                                                       int matrix_instance);
+            int matrix_instance); 
     private static native void nativeSetLocalMatrix(int native_shader,
-                                                    int matrix_instance);
+            int native_skiaShader, int matrix_instance);
 }
diff --git a/libs/hwui/Android.mk b/libs/hwui/Android.mk
index a9714c7..fe1b524 100644
--- a/libs/hwui/Android.mk
+++ b/libs/hwui/Android.mk
@@ -11,6 +11,7 @@
 	PatchCache.cpp \
 	Program.cpp \
 	ProgramCache.cpp \
+	SkiaShader.cpp \
 	TextureCache.cpp
 
 LOCAL_C_INCLUDES += \
diff --git a/libs/hwui/OpenGLRenderer.cpp b/libs/hwui/OpenGLRenderer.cpp
index cc8e6bc..187e9d8 100644
--- a/libs/hwui/OpenGLRenderer.cpp
+++ b/libs/hwui/OpenGLRenderer.cpp
@@ -41,6 +41,8 @@
 #define DEFAULT_PATCH_CACHE_SIZE 100
 #define DEFAULT_GRADIENT_CACHE_SIZE 0.5f
 
+#define REQUIRED_TEXTURE_UNITS_COUNT 3
+
 // Converts a number of mega-bytes into bytes
 #define MB(s) s * 1024 * 1024
 
@@ -88,16 +90,10 @@
         { SkXfermode::kXor_Mode,     GL_ONE_MINUS_DST_ALPHA,  GL_ONE_MINUS_SRC_ALPHA }
 };
 
-static const GLint gTileModes[] = {
-        GL_CLAMP_TO_EDGE,   // SkShader::kClamp_TileMode
-        GL_REPEAT,          // SkShader::kRepeat_Mode
-        GL_MIRRORED_REPEAT  // SkShader::kMirror_TileMode
-};
-
 static const GLenum gTextureUnits[] = {
-        GL_TEXTURE0,        // Bitmap or text
-        GL_TEXTURE1,        // Gradient
-        GL_TEXTURE2         // Bitmap shader
+        GL_TEXTURE0,
+        GL_TEXTURE1,
+        GL_TEXTURE2
 };
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -135,12 +131,7 @@
     }
 
     mCurrentProgram = NULL;
-
-    mShader = kShaderNone;
-    mShaderTileX = GL_CLAMP_TO_EDGE;
-    mShaderTileY = GL_CLAMP_TO_EDGE;
-    mShaderMatrix = NULL;
-    mShaderBitmap = NULL;
+    mShader = NULL;
 
     memcpy(mMeshVertices, gMeshVertices, sizeof(gMeshVertices));
 
@@ -560,16 +551,21 @@
 
     mModelView.loadIdentity();
 
+    GLuint textureUnit = 0;
+
     ProgramDescription description;
     description.hasTexture = true;
     description.hasAlpha8Texture = true;
+    if (mShader) {
+        mShader->describe(description, mExtensions);
+    }
 
     useProgram(mProgramCache.get(description));
     mCurrentProgram->set(mOrthoMatrix, mModelView, mSnapshot->transform);
 
     chooseBlending(true, mode);
-    bindTexture(mFontRenderer.getTexture(), GL_CLAMP_TO_EDGE, GL_CLAMP_TO_EDGE, 0);
-    glUniform1i(mCurrentProgram->getUniform("sampler"), 0);
+    bindTexture(mFontRenderer.getTexture(), GL_CLAMP_TO_EDGE, GL_CLAMP_TO_EDGE, textureUnit);
+    glUniform1i(mCurrentProgram->getUniform("sampler"), textureUnit);
 
     int texCoordsSlot = mCurrentProgram->getAttrib("texCoords");
     glEnableVertexAttribArray(texCoordsSlot);
@@ -577,6 +573,12 @@
     // Always premultiplied
     glUniform4f(mCurrentProgram->color, r, g, b, a);
 
+    textureUnit++;
+    // Setup attributes and uniforms required by the shaders
+    if (mShader) {
+        mShader->setupProgram(mCurrentProgram, mModelView, *mSnapshot, &textureUnit);
+    }
+
     // TODO: Implement scale properly
     const Rect& clip = mSnapshot->getLocalClip();
     mFontRenderer.setFont(paint, SkTypeface::UniqueID(paint->getTypeface()), paint->getTextSize());
@@ -591,36 +593,14 @@
 ///////////////////////////////////////////////////////////////////////////////
 
 void OpenGLRenderer::resetShader() {
-    mShader = OpenGLRenderer::kShaderNone;
-    mShaderKey = NULL;
-    mShaderBlend = false;
-    mShaderTileX = GL_CLAMP_TO_EDGE;
-    mShaderTileY = GL_CLAMP_TO_EDGE;
+    mShader = NULL;
 }
 
-void OpenGLRenderer::setupBitmapShader(SkBitmap* bitmap, SkShader::TileMode tileX,
-        SkShader::TileMode tileY, SkMatrix* matrix, bool hasAlpha) {
-    mShader = OpenGLRenderer::kShaderBitmap;
-    mShaderBlend = hasAlpha;
-    mShaderBitmap = bitmap;
-    mShaderTileX = gTileModes[tileX];
-    mShaderTileY = gTileModes[tileY];
-    mShaderMatrix = matrix;
-}
-
-void OpenGLRenderer::setupLinearGradientShader(SkShader* shader, float* bounds, uint32_t* colors,
-        float* positions, int count, SkShader::TileMode tileMode, SkMatrix* matrix, bool hasAlpha) {
-    // TODO: We should use a struct to describe each shader
-    mShader = OpenGLRenderer::kShaderLinearGradient;
-    mShaderKey = shader;
-    mShaderBlend = hasAlpha;
-    mShaderTileX = gTileModes[tileMode];
-    mShaderTileY = gTileModes[tileMode];
-    mShaderMatrix = matrix;
-    mShaderBounds = bounds;
-    mShaderColors = colors;
-    mShaderPositions = positions;
-    mShaderCount = count;
+void OpenGLRenderer::setupShader(SkiaShader* shader) {
+    mShader = shader;
+    if (mShader) {
+        mShader->set(&mTextureCache, &mGradientCache);
+    }
 }
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -630,165 +610,52 @@
 void OpenGLRenderer::drawColorRect(float left, float top, float right, float bottom,
         int color, SkXfermode::Mode mode, bool ignoreTransform) {
     // If a shader is set, preserve only the alpha
-    if (mShader != kShaderNone) {
+    if (mShader) {
         color |= 0x00ffffff;
     }
 
     // Render using pre-multiplied alpha
     const int alpha = (color >> 24) & 0xFF;
     const GLfloat a = alpha / 255.0f;
-
-    switch (mShader) {
-        case OpenGLRenderer::kShaderBitmap:
-            drawBitmapShader(left, top, right, bottom, a, mode);
-            return;
-        case OpenGLRenderer::kShaderLinearGradient:
-            drawLinearGradientShader(left, top, right, bottom, a, mode);
-            return;
-        default:
-            break;
-    }
-
     const GLfloat r = a * ((color >> 16) & 0xFF) / 255.0f;
     const GLfloat g = a * ((color >>  8) & 0xFF) / 255.0f;
     const GLfloat b = a * ((color      ) & 0xFF) / 255.0f;
 
-    // Pre-multiplication happens when setting the shader color
-    chooseBlending(alpha < 255 || mShaderBlend, mode);
+    GLuint textureUnit = 0;
 
-    mModelView.loadTranslate(left, top, 0.0f);
-    mModelView.scale(right - left, bottom - top, 1.0f);
+    // Setup the blending mode
+    chooseBlending(alpha < 255 || (mShader && mShader->blend()), mode);
 
+    // Describe the required shaders
     ProgramDescription description;
-    Program* program = mProgramCache.get(description);
-    if (!useProgram(program)) {
-        const GLvoid* vertices = &mMeshVertices[0].position[0];
-        const GLvoid* texCoords = &mMeshVertices[0].texture[0];
-
-        glVertexAttribPointer(mCurrentProgram->position, 2, GL_FLOAT, GL_FALSE,
-                gMeshStride, vertices);
+    if (mShader) {
+        mShader->describe(description, mExtensions);
     }
 
+    // Build and use the appropriate shader
+    useProgram(mProgramCache.get(description));
+
+    // Setup attributes
+    glVertexAttribPointer(mCurrentProgram->position, 2, GL_FLOAT, GL_FALSE,
+            gMeshStride, &mMeshVertices[0].position[0]);
+
+    // Setup uniforms
+    mModelView.loadTranslate(left, top, 0.0f);
+    mModelView.scale(right - left, bottom - top, 1.0f);
     if (!ignoreTransform) {
         mCurrentProgram->set(mOrthoMatrix, mModelView, mSnapshot->transform);
     } else {
         mat4 identity;
         mCurrentProgram->set(mOrthoMatrix, mModelView, identity);
     }
-
     glUniform4f(mCurrentProgram->color, r, g, b, a);
 
-    glDrawArrays(GL_TRIANGLE_STRIP, 0, gMeshCount);
-}
-
-void OpenGLRenderer::drawLinearGradientShader(float left, float top, float right, float bottom,
-        float alpha, SkXfermode::Mode mode) {
-    Texture* texture = mGradientCache.get(mShaderKey);
-    if (!texture) {
-        SkShader::TileMode tileMode = SkShader::kClamp_TileMode;
-        switch (mShaderTileX) {
-            case GL_REPEAT:
-                tileMode = SkShader::kRepeat_TileMode;
-                break;
-            case GL_MIRRORED_REPEAT:
-                tileMode = SkShader::kMirror_TileMode;
-                break;
-        }
-
-        texture = mGradientCache.addLinearGradient(mShaderKey, mShaderBounds, mShaderColors,
-                mShaderPositions, mShaderCount, tileMode);
+    // Setup attributes and uniforms required by the shaders
+    if (mShader) {
+        mShader->setupProgram(mCurrentProgram, mModelView, *mSnapshot, &textureUnit);
     }
 
-    ProgramDescription description;
-    description.hasGradient = true;
-
-    mModelView.loadTranslate(left, top, 0.0f);
-    mModelView.scale(right - left, bottom - top, 1.0f);
-
-    useProgram(mProgramCache.get(description));
-    mCurrentProgram->set(mOrthoMatrix, mModelView, mSnapshot->transform);
-
-    chooseBlending(mShaderBlend || alpha < 1.0f, mode);
-    bindTexture(texture->id, mShaderTileX, mShaderTileY, 0);
-    glUniform1i(mCurrentProgram->getUniform("gradientSampler"), 0);
-
-    Rect start(mShaderBounds[0], mShaderBounds[1], mShaderBounds[2], mShaderBounds[3]);
-    if (mShaderMatrix) {
-        mat4 shaderMatrix(*mShaderMatrix);
-        shaderMatrix.mapRect(start);
-    }
-    mSnapshot->transform.mapRect(start);
-
-    const float gradientX = start.right - start.left;
-    const float gradientY = start.bottom - start.top;
-
-    mat4 screenSpace(mSnapshot->transform);
-    screenSpace.multiply(mModelView);
-
-    // Always premultiplied
-    glUniform4f(mCurrentProgram->color, alpha, alpha, alpha, alpha);
-    glUniform2f(mCurrentProgram->getUniform("gradientStart"), start.left, start.top);
-    glUniform2f(mCurrentProgram->getUniform("gradient"), gradientX, gradientY);
-    glUniform1f(mCurrentProgram->getUniform("gradientLength"),
-            1.0f / (gradientX * gradientX + gradientY * gradientY));
-    glUniformMatrix4fv(mCurrentProgram->getUniform("screenSpace"), 1, GL_FALSE,
-            &screenSpace.data[0]);
-
-    glVertexAttribPointer(mCurrentProgram->position, 2, GL_FLOAT, GL_FALSE,
-            gMeshStride, &mMeshVertices[0].position[0]);
-
-    glDrawArrays(GL_TRIANGLE_STRIP, 0, gMeshCount);
-}
-
-void OpenGLRenderer::drawBitmapShader(float left, float top, float right, float bottom,
-        float alpha, SkXfermode::Mode mode) {
-    const Texture* texture = mTextureCache.get(mShaderBitmap);
-
-    const float width = texture->width;
-    const float height = texture->height;
-
-    mModelView.loadTranslate(left, top, 0.0f);
-    mModelView.scale(right - left, bottom - top, 1.0f);
-
-    mat4 textureTransform;
-    if (mShaderMatrix) {
-        SkMatrix inverse;
-        mShaderMatrix->invert(&inverse);
-        textureTransform.load(inverse);
-        textureTransform.multiply(mModelView);
-    } else {
-        textureTransform.load(mModelView);
-    }
-
-    ProgramDescription description;
-    description.hasBitmap = true;
-    // The driver does not support non-power of two mirrored/repeated
-    // textures, so do it ourselves
-    if (!mExtensions.hasNPot()) {
-        description.isBitmapNpot = true;
-        description.bitmapWrapS = mShaderTileX;
-        description.bitmapWrapT = mShaderTileY;
-    }
-
-    useProgram(mProgramCache.get(description));
-    mCurrentProgram->set(mOrthoMatrix, mModelView, mSnapshot->transform);
-
-    chooseBlending(texture->blend || alpha < 1.0f, mode);
-
-    // Texture
-    bindTexture(texture->id, mShaderTileX, mShaderTileY, 0);
-    glUniform1i(mCurrentProgram->getUniform("bitmapSampler"), 0);
-    glUniformMatrix4fv(mCurrentProgram->getUniform("textureTransform"), 1,
-            GL_FALSE, &textureTransform.data[0]);
-    glUniform2f(mCurrentProgram->getUniform("textureDimension"), 1.0f / width, 1.0f / height);
-
-    // Always premultiplied
-    glUniform4f(mCurrentProgram->color, alpha, alpha, alpha, alpha);
-
-    // Mesh
-    glVertexAttribPointer(mCurrentProgram->position, 2, GL_FLOAT, GL_FALSE,
-            gMeshStride, &mMeshVertices[0].position[0]);
-
+    // Draw the mesh
     glDrawArrays(GL_TRIANGLE_STRIP, 0, gMeshCount);
 }
 
@@ -823,7 +690,7 @@
     chooseBlending(blend || alpha < 1.0f, mode);
 
     // Texture
-    bindTexture(texture, mShaderTileX, mShaderTileY, 0);
+    bindTexture(texture, GL_CLAMP_TO_EDGE, GL_CLAMP_TO_EDGE, 0);
     glUniform1i(mCurrentProgram->getUniform("sampler"), 0);
 
     // Always premultiplied
diff --git a/libs/hwui/OpenGLRenderer.h b/libs/hwui/OpenGLRenderer.h
index 937ff08..dc0f50f 100644
--- a/libs/hwui/OpenGLRenderer.h
+++ b/libs/hwui/OpenGLRenderer.h
@@ -42,6 +42,7 @@
 #include "Vertex.h"
 #include "FontRenderer.h"
 #include "ProgramCache.h"
+#include "SkiaShader.h"
 
 namespace android {
 namespace uirenderer {
@@ -50,8 +51,6 @@
 // Renderer
 ///////////////////////////////////////////////////////////////////////////////
 
-#define REQUIRED_TEXTURE_UNITS_COUNT 3
-
 /**
  * OpenGL renderer used to draw accelerated 2D graphics. The API is a
  * simplified version of Skia's Canvas API.
@@ -94,27 +93,12 @@
     void drawRect(float left, float top, float right, float bottom, const SkPaint* paint);
 
     void resetShader();
-    void setupBitmapShader(SkBitmap* bitmap, SkShader::TileMode tileX, SkShader::TileMode tileY,
-            SkMatrix* matrix, bool hasAlpha);
-    void setupLinearGradientShader(SkShader* shader, float* bounds, uint32_t* colors,
-            float* positions, int count, SkShader::TileMode tileMode,
-            SkMatrix* matrix, bool hasAlpha);
+    void setupShader(SkiaShader* shader);
 
     void drawText(const char* text, int bytesCount, int count, float x, float y, SkPaint* paint);
 
 private:
     /**
-     * Type of Skia shader in use.
-     */
-    enum ShaderType {
-        kShaderNone,
-        kShaderBitmap,
-        kShaderLinearGradient,
-        kShaderCircularGradient,
-        kShaderSweepGradient
-    };
-
-    /**
      * Saves the current state of the renderer as a new snapshot.
      * The new snapshot is saved in mSnapshot and the previous snapshot
      * is linked from mSnapshot->previous.
@@ -232,32 +216,6 @@
             GLvoid* vertices, GLvoid* texCoords, GLvoid* indices, GLsizei elementsCount = 0);
 
     /**
-     * Fills the specified rectangle with the currently set bitmap shader.
-     *
-     * @param left The left coordinate of the rectangle
-     * @param top The top coordinate of the rectangle
-     * @param right The right coordinate of the rectangle
-     * @param bottom The bottom coordinate of the rectangle
-     * @param alpha An additional translucency parameter, between 0.0f and 1.0f
-     * @param mode The blending mode
-     */
-    void drawBitmapShader(float left, float top, float right, float bottom, float alpha,
-            SkXfermode::Mode mode);
-
-    /**
-     * Fills the specified rectangle with the currently set linear gradient shader.
-     *
-     * @param left The left coordinate of the rectangle
-     * @param top The top coordinate of the rectangle
-     * @param right The right coordinate of the rectangle
-     * @param bottom The bottom coordinate of the rectangle
-     * @param alpha An additional translucency parameter, between 0.0f and 1.0f
-     * @param mode The blending mode
-     */
-    void drawLinearGradientShader(float left, float top, float right, float bottom, float alpha,
-            SkXfermode::Mode mode);
-
-    /**
      * Resets the texture coordinates stored in mMeshVertices. Setting the values
      * back to default is achieved by calling:
      *
@@ -321,6 +279,7 @@
 
     // Shaders
     Program* mCurrentProgram;
+    SkiaShader* mShader;
 
     // Used to draw textured quads
     TextureVertex mMeshVertices[4];
@@ -330,21 +289,6 @@
     GLenum mLastSrcMode;
     GLenum mLastDstMode;
 
-    // Skia shaders
-    ShaderType mShader;
-    SkShader* mShaderKey;
-    bool mShaderBlend;
-    GLenum mShaderTileX;
-    GLenum mShaderTileY;
-    SkMatrix* mShaderMatrix;
-    // Bitmaps
-    SkBitmap* mShaderBitmap;
-    // Gradients
-    float* mShaderBounds;
-    uint32_t* mShaderColors;
-    float* mShaderPositions;
-    int mShaderCount;
-
     // GL extensions
     Extensions mExtensions;
 
diff --git a/libs/hwui/ProgramCache.cpp b/libs/hwui/ProgramCache.cpp
index c9e2d2e..23923f6 100644
--- a/libs/hwui/ProgramCache.cpp
+++ b/libs/hwui/ProgramCache.cpp
@@ -106,9 +106,9 @@
 const char* gFS_Main_FetchBitmapNpot =
         "    vec4 bitmapColor = texture2D(bitmapSampler, wrap(outBitmapTexCoords));\n";
 const char* gFS_Main_BlendShadersBG =
-        "    fragColor = blendShaders(bitmapColor, gradientColor)";
-const char* gFS_Main_BlendShadersGB =
         "    fragColor = blendShaders(gradientColor, bitmapColor)";
+const char* gFS_Main_BlendShadersGB =
+        "    fragColor = blendShaders(bitmapColor, gradientColor)";
 const char* gFS_Main_BlendShaders_Modulate =
         " * fragColor.a;\n";
 const char* gFS_Main_GradientShader_Modulate =
@@ -144,23 +144,23 @@
         // Dst
         "return dst;\n",
         // SrcOver
-        "return vec4(src.rgb + (1.0 - src.a) * dst.rgb, src.a + dst.a - src.a * dst.a);\n",
+        "return src + dst * (1.0 - src.a);\n",
         // DstOver
-        "return vec4(dst.rgb + (1.0 - dst.a) * src.rgb, src.a + dst.a - src.a * dst.a);\n",
+        "return dst + src * (1.0 - dst.a);\n",
         // SrcIn
-        "return vec4(src.rgb * dst.a, src.a * dst.a);\n",
+        "return src * dst.a;\n",
         // DstIn
-        "return vec4(dst.rgb * src.a, src.a * dst.a);\n",
+        "return dst * src.a;\n",
         // SrcOut
-        "return vec4(src.rgb * (1.0 - dst.a), src.a * (1.0 - dst.a));\n",
+        "return src * (1.0 - dst.a);\n",
         // DstOut
-        "return vec4(dst.rgb * (1.0 - src.a), dst.a * (1.0 - src.a));\n",
+        "return dst * (1.0 - src.a);\n",
         // SrcAtop
         "return vec4(src.rgb * dst.a + (1.0 - src.a) * dst.rgb, dst.a);\n",
         // DstAtop
         "return vec4(dst.rgb * src.a + (1.0 - dst.a) * src.rgb, src.a);\n",
         // Xor
-        "return vec4(src.rgb * (1.0 - dst.a) + (1.0 - src.a) * dst.rgb, "
+        "return vec4(src.rgb * (1.0 - dst.a) + (1.0 - src.a) * dst.rgb, 1.0, "
                 "src.a + dst.a - 2.0 * src.a * dst.a);\n",
 };
 
diff --git a/libs/hwui/ProgramCache.h b/libs/hwui/ProgramCache.h
index 5a6ec33..d60f6ce 100644
--- a/libs/hwui/ProgramCache.h
+++ b/libs/hwui/ProgramCache.h
@@ -19,6 +19,7 @@
 
 #include <utils/KeyedVector.h>
 #include <utils/Log.h>
+#include <utils/String8.h>
 
 #include <GLES2/gl2.h>
 
diff --git a/libs/hwui/SkiaShader.cpp b/libs/hwui/SkiaShader.cpp
new file mode 100644
index 0000000..fedb56c
--- /dev/null
+++ b/libs/hwui/SkiaShader.cpp
@@ -0,0 +1,220 @@
+/*
+ * Copyright (C) 2010 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "OpenGLRenderer"
+
+#include <utils/Log.h>
+
+#include <SkMatrix.h>
+
+#include "SkiaShader.h"
+#include "Texture.h"
+#include "Matrix.h"
+
+namespace android {
+namespace uirenderer {
+
+///////////////////////////////////////////////////////////////////////////////
+// Support
+///////////////////////////////////////////////////////////////////////////////
+
+static const GLenum gTextureUnitsMap[] = {
+        GL_TEXTURE0,
+        GL_TEXTURE1,
+        GL_TEXTURE2
+};
+
+static const GLint gTileModes[] = {
+        GL_CLAMP_TO_EDGE,   // == SkShader::kClamp_TileMode
+        GL_REPEAT,          // == SkShader::kRepeat_Mode
+        GL_MIRRORED_REPEAT  // == SkShader::kMirror_TileMode
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Base shader
+///////////////////////////////////////////////////////////////////////////////
+
+SkiaShader::SkiaShader(Type type, SkShader* key, SkShader::TileMode tileX,
+        SkShader::TileMode tileY, SkMatrix* matrix, bool blend):
+        mType(type), mKey(key), mTileX(tileX), mTileY(tileY), mMatrix(matrix), mBlend(blend) {
+}
+
+SkiaShader::~SkiaShader() {
+}
+
+void SkiaShader::describe(ProgramDescription& description, const Extensions& extensions) {
+}
+
+void SkiaShader::setupProgram(Program* program, const mat4& modelView, const Snapshot& snapshot,
+        GLuint* textureUnit) {
+}
+
+void SkiaShader::bindTexture(GLuint texture, GLenum wrapS, GLenum wrapT, GLuint textureUnit) {
+    glActiveTexture(gTextureUnitsMap[textureUnit]);
+    glBindTexture(GL_TEXTURE_2D, texture);
+    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, wrapS);
+    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, wrapT);
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Bitmap shader
+///////////////////////////////////////////////////////////////////////////////
+
+SkiaBitmapShader::SkiaBitmapShader(SkBitmap* bitmap, SkShader* key, SkShader::TileMode tileX,
+        SkShader::TileMode tileY, SkMatrix* matrix, bool blend):
+        SkiaShader(kBitmap, key, tileX, tileY, matrix, blend), mBitmap(bitmap) {
+}
+
+SkiaBitmapShader::~SkiaBitmapShader() {
+}
+
+void SkiaBitmapShader::describe(ProgramDescription& description, const Extensions& extensions) {
+    const Texture* texture = mTextureCache->get(mBitmap);
+
+    const float width = texture->width;
+    const float height = texture->height;
+
+    description.hasBitmap = true;
+    // The driver does not support non-power of two mirrored/repeated
+    // textures, so do it ourselves
+    if (!extensions.hasNPot() && !isPowerOfTwo(width) && !isPowerOfTwo(height)) {
+        description.isBitmapNpot = true;
+        description.bitmapWrapS = gTileModes[mTileX];
+        description.bitmapWrapT = gTileModes[mTileY];
+    }
+}
+
+void SkiaBitmapShader::setupProgram(Program* program, const mat4& modelView,
+        const Snapshot& snapshot, GLuint* textureUnit) {
+    GLuint textureSlot = (*textureUnit)++;
+    glActiveTexture(gTextureUnitsMap[textureSlot]);
+    const Texture* texture = mTextureCache->get(mBitmap);
+
+    const float width = texture->width;
+    const float height = texture->height;
+
+    mat4 textureTransform;
+    if (mMatrix) {
+        SkMatrix inverse;
+        mMatrix->invert(&inverse);
+        textureTransform.load(inverse);
+        textureTransform.multiply(modelView);
+    } else {
+        textureTransform.load(modelView);
+    }
+
+    // Uniforms
+    bindTexture(texture->id, gTileModes[mTileX], gTileModes[mTileY], textureSlot);
+    glUniform1i(program->getUniform("bitmapSampler"), textureSlot);
+    glUniformMatrix4fv(program->getUniform("textureTransform"), 1,
+            GL_FALSE, &textureTransform.data[0]);
+    glUniform2f(program->getUniform("textureDimension"), 1.0f / width, 1.0f / height);
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Linear gradient shader
+///////////////////////////////////////////////////////////////////////////////
+
+SkiaLinearGradientShader::SkiaLinearGradientShader(float* bounds, uint32_t* colors,
+        float* positions, int count, SkShader* key, SkShader::TileMode tileMode,
+        SkMatrix* matrix, bool blend):
+        SkiaShader(kLinearGradient, key, tileMode, tileMode, matrix, blend),
+        mBounds(bounds), mColors(colors), mPositions(positions), mCount(count) {
+}
+
+SkiaLinearGradientShader::~SkiaLinearGradientShader() {
+    delete mBounds;
+    delete mColors;
+    delete mPositions;
+}
+
+void SkiaLinearGradientShader::describe(ProgramDescription& description,
+        const Extensions& extensions) {
+    description.hasGradient = true;
+}
+
+void SkiaLinearGradientShader::setupProgram(Program* program, const mat4& modelView,
+        const Snapshot& snapshot, GLuint* textureUnit) {
+    GLuint textureSlot = (*textureUnit)++;
+    glActiveTexture(gTextureUnitsMap[textureSlot]);
+
+    Texture* texture = mGradientCache->get(mKey);
+    if (!texture) {
+        texture = mGradientCache->addLinearGradient(mKey, mBounds, mColors, mPositions,
+                mCount, mTileX);
+    }
+
+    Rect start(mBounds[0], mBounds[1], mBounds[2], mBounds[3]);
+    if (mMatrix) {
+        mat4 shaderMatrix(*mMatrix);
+        shaderMatrix.mapRect(start);
+    }
+    snapshot.transform.mapRect(start);
+
+    const float gradientX = start.right - start.left;
+    const float gradientY = start.bottom - start.top;
+
+    mat4 screenSpace(snapshot.transform);
+    screenSpace.multiply(modelView);
+
+    // Uniforms
+    bindTexture(texture->id, gTileModes[mTileX], gTileModes[mTileY], textureSlot);
+    glUniform1i(program->getUniform("gradientSampler"), textureSlot);
+    glUniform2f(program->getUniform("gradientStart"), start.left, start.top);
+    glUniform2f(program->getUniform("gradient"), gradientX, gradientY);
+    glUniform1f(program->getUniform("gradientLength"),
+            1.0f / (gradientX * gradientX + gradientY * gradientY));
+    glUniformMatrix4fv(program->getUniform("screenSpace"), 1, GL_FALSE, &screenSpace.data[0]);
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// Compose shader
+///////////////////////////////////////////////////////////////////////////////
+
+SkiaComposeShader::SkiaComposeShader(SkiaShader* first, SkiaShader* second,
+        SkXfermode::Mode mode, SkShader* key):
+        SkiaShader(kCompose, key, SkShader::kClamp_TileMode, SkShader::kClamp_TileMode,
+        NULL, first->blend() || second->blend()), mFirst(first), mSecond(second), mMode(mode) {
+}
+
+SkiaComposeShader::~SkiaComposeShader() {
+    delete mFirst;
+    delete mSecond;
+}
+
+void SkiaComposeShader::set(TextureCache* textureCache, GradientCache* gradientCache) {
+    SkiaShader::set(textureCache, gradientCache);
+    mFirst->set(textureCache, gradientCache);
+    mSecond->set(textureCache, gradientCache);
+}
+
+void SkiaComposeShader::describe(ProgramDescription& description, const Extensions& extensions) {
+    mFirst->describe(description, extensions);
+    mSecond->describe(description, extensions);
+    if (mFirst->type() == kBitmap) {
+        description.isBitmapFirst = true;
+    }
+    description.shadersMode = mMode;
+}
+
+void SkiaComposeShader::setupProgram(Program* program, const mat4& modelView,
+        const Snapshot& snapshot, GLuint* textureUnit) {
+    mFirst->setupProgram(program, modelView, snapshot, textureUnit);
+    mSecond->setupProgram(program, modelView, snapshot, textureUnit);
+}
+
+}; // namespace uirenderer
+}; // namespace android
diff --git a/libs/hwui/SkiaShader.h b/libs/hwui/SkiaShader.h
new file mode 100644
index 0000000..b5e6aeb
--- /dev/null
+++ b/libs/hwui/SkiaShader.h
@@ -0,0 +1,163 @@
+/*
+ * Copyright (C) 2010 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef SKIA_SHADER_H
+#define SKIA_SHADER_H
+
+#include <SkShader.h>
+#include <SkXfermode.h>
+
+#include <GLES2/gl2.h>
+
+#include "Extensions.h"
+#include "ProgramCache.h"
+#include "TextureCache.h"
+#include "GradientCache.h"
+#include "Snapshot.h"
+
+namespace android {
+namespace uirenderer {
+
+///////////////////////////////////////////////////////////////////////////////
+// Base shader
+///////////////////////////////////////////////////////////////////////////////
+
+/**
+ * Represents a Skia shader. A shader will modify the GL context and active
+ * program to recreate the original effect.
+ */
+struct SkiaShader {
+    /**
+     * Type of Skia shader in use.
+     */
+    enum Type {
+        kNone,
+        kBitmap,
+        kLinearGradient,
+        kCircularGradient,
+        kSweepGradient,
+        kCompose
+    };
+
+    SkiaShader(Type type, SkShader* key, SkShader::TileMode tileX, SkShader::TileMode tileY,
+            SkMatrix* matrix, bool blend);
+    virtual ~SkiaShader();
+
+    virtual void describe(ProgramDescription& description, const Extensions& extensions);
+    virtual void setupProgram(Program* program, const mat4& modelView, const Snapshot& snapshot,
+            GLuint* textureUnit);
+
+    inline bool blend() const {
+        return mBlend;
+    }
+
+    Type type() const {
+        return mType;
+    }
+
+    virtual void set(TextureCache* textureCache, GradientCache* gradientCache) {
+        mTextureCache = textureCache;
+        mGradientCache = gradientCache;
+    }
+
+    void setMatrix(SkMatrix* matrix) {
+        mMatrix = matrix;
+    }
+
+protected:
+    inline void bindTexture(GLuint texture, GLenum wrapS, GLenum wrapT, GLuint textureUnit);
+
+    Type mType;
+    SkShader* mKey;
+    SkShader::TileMode mTileX;
+    SkShader::TileMode mTileY;
+    SkMatrix* mMatrix;
+    bool mBlend;
+
+    TextureCache* mTextureCache;
+    GradientCache* mGradientCache;
+}; // struct SkiaShader
+
+
+///////////////////////////////////////////////////////////////////////////////
+// Implementations
+///////////////////////////////////////////////////////////////////////////////
+
+/**
+ * A shader that draws a bitmap.
+ */
+struct SkiaBitmapShader: public SkiaShader {
+    SkiaBitmapShader(SkBitmap* bitmap, SkShader* key, SkShader::TileMode tileX,
+            SkShader::TileMode tileY, SkMatrix* matrix, bool blend);
+    ~SkiaBitmapShader();
+
+    void describe(ProgramDescription& description, const Extensions& extensions);
+    void setupProgram(Program* program, const mat4& modelView, const Snapshot& snapshot,
+            GLuint* textureUnit);
+
+private:
+    /**
+     * This method does not work for n == 0.
+     */
+    inline bool isPowerOfTwo(unsigned int n) {
+        return !(n & (n - 1));
+    }
+
+    SkBitmap* mBitmap;
+}; // struct SkiaBitmapShader
+
+/**
+ * A shader that draws a linear gradient.
+ */
+struct SkiaLinearGradientShader: public SkiaShader {
+    SkiaLinearGradientShader(float* bounds, uint32_t* colors, float* positions, int count,
+            SkShader* key, SkShader::TileMode tileMode, SkMatrix* matrix, bool blend);
+    ~SkiaLinearGradientShader();
+
+    void describe(ProgramDescription& description, const Extensions& extensions);
+    void setupProgram(Program* program, const mat4& modelView, const Snapshot& snapshot,
+            GLuint* textureUnit);
+
+private:
+    float* mBounds;
+    uint32_t* mColors;
+    float* mPositions;
+    int mCount;
+}; // struct SkiaLinearGradientShader
+
+/**
+ * A shader that draws two shaders, composited with an xfermode.
+ */
+struct SkiaComposeShader: public SkiaShader {
+    SkiaComposeShader(SkiaShader* first, SkiaShader* second, SkXfermode::Mode mode, SkShader* key);
+    ~SkiaComposeShader();
+
+    void set(TextureCache* textureCache, GradientCache* gradientCache);
+
+    void describe(ProgramDescription& description, const Extensions& extensions);
+    void setupProgram(Program* program, const mat4& modelView, const Snapshot& snapshot,
+            GLuint* textureUnit);
+
+private:
+    SkiaShader* mFirst;
+    SkiaShader* mSecond;
+    SkXfermode::Mode mMode;
+}; // struct SkiaComposeShader
+
+}; // namespace uirenderer
+}; // namespace android
+
+#endif // SKIA_SHADER_H
diff --git a/tests/HwAccelerationTest/src/com/google/android/test/hwui/MoreShadersActivity.java b/tests/HwAccelerationTest/src/com/google/android/test/hwui/MoreShadersActivity.java
index 8ee3117..cbf34a0 100644
--- a/tests/HwAccelerationTest/src/com/google/android/test/hwui/MoreShadersActivity.java
+++ b/tests/HwAccelerationTest/src/com/google/android/test/hwui/MoreShadersActivity.java
@@ -112,11 +112,10 @@
 
             canvas.restore();
 
-            canvas.drawText("OpenGL rendering", 0.0f, 20.0f, mLargePaint);
             canvas.save();
             canvas.translate(40.0f + mDrawWidth + 40.0f, 40.0f);
 
-            //mLargePaint.setShader(mHorGradient);
+            mLargePaint.setShader(mHorGradient);
             canvas.drawText("OpenGL rendering", 0.0f, 20.0f, mLargePaint);
             
             mLargePaint.setShader(mScaled2Shader);