Moved shared element motion to be in called Activity.

Bug 15198605

Change-Id: I07ce4f3366dbdc6957d789f724e49d260fe34faf
diff --git a/core/java/android/app/ActivityTransitionCoordinator.java b/core/java/android/app/ActivityTransitionCoordinator.java
index d08978bd..5e4ddd0 100644
--- a/core/java/android/app/ActivityTransitionCoordinator.java
+++ b/core/java/android/app/ActivityTransitionCoordinator.java
@@ -15,14 +15,23 @@
  */
 package android.app;
 
+import android.content.Context;
+import android.content.res.Resources;
+import android.graphics.Bitmap;
+import android.graphics.Canvas;
+import android.graphics.Matrix;
 import android.graphics.Rect;
+import android.graphics.drawable.BitmapDrawable;
+import android.os.Bundle;
 import android.os.Handler;
 import android.os.ResultReceiver;
 import android.transition.Transition;
 import android.transition.TransitionSet;
 import android.util.ArrayMap;
+import android.util.Pair;
 import android.view.View;
 import android.view.ViewGroup;
+import android.view.ViewTreeObserver;
 import android.view.Window;
 import android.widget.ImageView;
 
@@ -181,6 +190,11 @@
      */
     public static final int MSG_CANCEL = 106;
 
+    /**
+     * When returning, this is the destination location for the shared element.
+     */
+    public static final int MSG_SHARED_ELEMENT_DESTINATION = 107;
+
     final private Window mWindow;
     final protected ArrayList<String> mAllSharedElementNames;
     final protected ArrayList<View> mSharedElements = new ArrayList<View>();
@@ -334,6 +348,210 @@
 
     protected abstract Transition getViewsTransition();
 
+    private static void setSharedElementState(View view, String name, Bundle transitionArgs,
+            int[] parentLoc) {
+        Bundle sharedElementBundle = transitionArgs.getBundle(name);
+        if (sharedElementBundle == null) {
+            return;
+        }
+
+        if (view instanceof ImageView) {
+            int scaleTypeInt = sharedElementBundle.getInt(KEY_SCALE_TYPE, -1);
+            if (scaleTypeInt >= 0) {
+                ImageView imageView = (ImageView) view;
+                ImageView.ScaleType scaleType = SCALE_TYPE_VALUES[scaleTypeInt];
+                imageView.setScaleType(scaleType);
+                if (scaleType == ImageView.ScaleType.MATRIX) {
+                    float[] matrixValues = sharedElementBundle.getFloatArray(KEY_IMAGE_MATRIX);
+                    Matrix matrix = new Matrix();
+                    matrix.setValues(matrixValues);
+                    imageView.setImageMatrix(matrix);
+                }
+            }
+        }
+
+        float z = sharedElementBundle.getFloat(KEY_TRANSLATION_Z);
+        view.setTranslationZ(z);
+
+        int x = sharedElementBundle.getInt(KEY_SCREEN_X);
+        int y = sharedElementBundle.getInt(KEY_SCREEN_Y);
+        int width = sharedElementBundle.getInt(KEY_WIDTH);
+        int height = sharedElementBundle.getInt(KEY_HEIGHT);
+
+        int widthSpec = View.MeasureSpec.makeMeasureSpec(width, View.MeasureSpec.EXACTLY);
+        int heightSpec = View.MeasureSpec.makeMeasureSpec(height, View.MeasureSpec.EXACTLY);
+        view.measure(widthSpec, heightSpec);
+
+        int left = x - parentLoc[0];
+        int top = y - parentLoc[1];
+        int right = left + width;
+        int bottom = top + height;
+        view.layout(left, top, right, bottom);
+    }
+
+    protected ArrayMap<ImageView, Pair<ImageView.ScaleType, Matrix>> setSharedElementState(
+            Bundle sharedElementState, final ArrayList<View> snapshots) {
+        ArrayMap<ImageView, Pair<ImageView.ScaleType, Matrix>> originalImageState =
+                new ArrayMap<ImageView, Pair<ImageView.ScaleType, Matrix>>();
+        if (sharedElementState != null) {
+            int[] tempLoc = new int[2];
+            for (int i = 0; i < mSharedElementNames.size(); i++) {
+                View sharedElement = mSharedElements.get(i);
+                String name = mSharedElementNames.get(i);
+                Pair<ImageView.ScaleType, Matrix> originalState = getOldImageState(sharedElement,
+                        name, sharedElementState);
+                if (originalState != null) {
+                    originalImageState.put((ImageView) sharedElement, originalState);
+                }
+                View parent = (View) sharedElement.getParent();
+                parent.getLocationOnScreen(tempLoc);
+                setSharedElementState(sharedElement, name, sharedElementState, tempLoc);
+            }
+        }
+        mListener.setSharedElementStart(mSharedElementNames, mSharedElements, snapshots);
+
+        getDecor().getViewTreeObserver().addOnPreDrawListener(
+                new ViewTreeObserver.OnPreDrawListener() {
+                    @Override
+                    public boolean onPreDraw() {
+                        getDecor().getViewTreeObserver().removeOnPreDrawListener(this);
+                        mListener.setSharedElementEnd(mSharedElementNames, mSharedElements,
+                                snapshots);
+                        return true;
+                    }
+                }
+        );
+        return originalImageState;
+    }
+
+    private static Pair<ImageView.ScaleType, Matrix> getOldImageState(View view, String name,
+            Bundle transitionArgs) {
+        if (!(view instanceof ImageView)) {
+            return null;
+        }
+        Bundle bundle = transitionArgs.getBundle(name);
+        if (bundle == null) {
+            return null;
+        }
+        int scaleTypeInt = bundle.getInt(KEY_SCALE_TYPE, -1);
+        if (scaleTypeInt < 0) {
+            return null;
+        }
+
+        ImageView imageView = (ImageView) view;
+        ImageView.ScaleType originalScaleType = imageView.getScaleType();
+
+        Matrix originalMatrix = null;
+        if (originalScaleType == ImageView.ScaleType.MATRIX) {
+            originalMatrix = new Matrix(imageView.getImageMatrix());
+        }
+
+        return Pair.create(originalScaleType, originalMatrix);
+    }
+
+    protected ArrayList<View> createSnapshots(Bundle state, Collection<String> names) {
+        int numSharedElements = names.size();
+        if (numSharedElements == 0) {
+            return null;
+        }
+        ArrayList<View> snapshots = new ArrayList<View>(numSharedElements);
+        Context context = getWindow().getContext();
+        int[] parentLoc = new int[2];
+        getDecor().getLocationOnScreen(parentLoc);
+        for (String name: names) {
+            Bundle sharedElementBundle = state.getBundle(name);
+            if (sharedElementBundle != null) {
+                Bitmap bitmap = sharedElementBundle.getParcelable(KEY_BITMAP);
+                View snapshot = new View(context);
+                Resources resources = getWindow().getContext().getResources();
+                if (bitmap != null) {
+                    snapshot.setBackground(new BitmapDrawable(resources, bitmap));
+                }
+                snapshot.setViewName(name);
+                setSharedElementState(snapshot, name, state, parentLoc);
+                snapshots.add(snapshot);
+            }
+        }
+        return snapshots;
+    }
+
+    protected static void setOriginalImageViewState(
+            ArrayMap<ImageView, Pair<ImageView.ScaleType, Matrix>> originalState) {
+        for (int i = 0; i < originalState.size(); i++) {
+            ImageView imageView = originalState.keyAt(i);
+            Pair<ImageView.ScaleType, Matrix> state = originalState.valueAt(i);
+            imageView.setScaleType(state.first);
+            imageView.setImageMatrix(state.second);
+        }
+    }
+
+    protected Bundle captureSharedElementState() {
+        Bundle bundle = new Bundle();
+        int[] tempLoc = new int[2];
+        for (int i = 0; i < mSharedElementNames.size(); i++) {
+            View sharedElement = mSharedElements.get(i);
+            String name = mSharedElementNames.get(i);
+            captureSharedElementState(sharedElement, name, bundle, tempLoc);
+        }
+        return bundle;
+    }
+
+    /**
+     * Captures placement information for Views with a shared element name for
+     * Activity Transitions.
+     *
+     * @param view           The View to capture the placement information for.
+     * @param name           The shared element name in the target Activity to apply the placement
+     *                       information for.
+     * @param transitionArgs Bundle to store shared element placement information.
+     * @param tempLoc        A temporary int[2] for capturing the current location of views.
+     */
+    private static void captureSharedElementState(View view, String name, Bundle transitionArgs,
+            int[] tempLoc) {
+        Bundle sharedElementBundle = new Bundle();
+        view.getLocationOnScreen(tempLoc);
+        float scaleX = view.getScaleX();
+        sharedElementBundle.putInt(KEY_SCREEN_X, tempLoc[0]);
+        int width = Math.round(view.getWidth() * scaleX);
+        sharedElementBundle.putInt(KEY_WIDTH, width);
+
+        float scaleY = view.getScaleY();
+        sharedElementBundle.putInt(KEY_SCREEN_Y, tempLoc[1]);
+        int height = Math.round(view.getHeight() * scaleY);
+        sharedElementBundle.putInt(KEY_HEIGHT, height);
+
+        sharedElementBundle.putFloat(KEY_TRANSLATION_Z, view.getTranslationZ());
+
+        if (width > 0 && height > 0) {
+            Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
+            Canvas canvas = new Canvas(bitmap);
+            view.draw(canvas);
+            sharedElementBundle.putParcelable(KEY_BITMAP, bitmap);
+        }
+
+        if (view instanceof ImageView) {
+            ImageView imageView = (ImageView) view;
+            int scaleTypeInt = scaleTypeToInt(imageView.getScaleType());
+            sharedElementBundle.putInt(KEY_SCALE_TYPE, scaleTypeInt);
+            if (imageView.getScaleType() == ImageView.ScaleType.MATRIX) {
+                float[] matrix = new float[9];
+                imageView.getImageMatrix().getValues(matrix);
+                sharedElementBundle.putFloatArray(KEY_IMAGE_MATRIX, matrix);
+            }
+        }
+
+        transitionArgs.putBundle(name, sharedElementBundle);
+    }
+
+    private static int scaleTypeToInt(ImageView.ScaleType scaleType) {
+        for (int i = 0; i < SCALE_TYPE_VALUES.length; i++) {
+            if (scaleType == SCALE_TYPE_VALUES[i]) {
+                return i;
+            }
+        }
+        return -1;
+    }
+
     private static class FixedEpicenterCallback extends Transition.EpicenterCallback {
         private Rect mEpicenter;
 
diff --git a/core/java/android/app/EnterTransitionCoordinator.java b/core/java/android/app/EnterTransitionCoordinator.java
index bc97852..a8617b8 100644
--- a/core/java/android/app/EnterTransitionCoordinator.java
+++ b/core/java/android/app/EnterTransitionCoordinator.java
@@ -18,11 +18,7 @@
 import android.animation.Animator;
 import android.animation.AnimatorListenerAdapter;
 import android.animation.ObjectAnimator;
-import android.content.Context;
-import android.content.res.Resources;
-import android.graphics.Bitmap;
 import android.graphics.Matrix;
-import android.graphics.drawable.BitmapDrawable;
 import android.graphics.drawable.Drawable;
 import android.os.Bundle;
 import android.os.Handler;
@@ -38,7 +34,6 @@
 import android.widget.ImageView;
 
 import java.util.ArrayList;
-import java.util.Collection;
 
 /**
  * This ActivityTransitionCoordinator is created by the Activity to manage
@@ -56,6 +51,7 @@
     private Handler mHandler;
     private boolean mIsCanceled;
     private ObjectAnimator mBackgroundAnimator;
+    private boolean mIsExitTransitionComplete;
 
     public EnterTransitionCoordinator(Activity activity, ResultReceiver resultReceiver,
             ArrayList<String> sharedElementNames,
@@ -76,6 +72,8 @@
                 }
             };
             mHandler.sendEmptyMessageDelayed(MSG_CANCEL, MAX_WAIT_MS);
+            Bundle state = captureSharedElementState();
+            mResultReceiver.send(MSG_SHARED_ELEMENT_DESTINATION, state);
         }
     }
 
@@ -98,9 +96,8 @@
                 break;
             case MSG_EXIT_TRANSITION_COMPLETE:
                 if (!mIsCanceled) {
-                    if (!mSharedElementTransitionStarted) {
-                        send(resultCode, resultData);
-                    } else {
+                    mIsExitTransitionComplete = true;
+                    if (mSharedElementTransitionStarted) {
                         onRemoteExitTransitionComplete();
                     }
                 }
@@ -183,6 +180,7 @@
         setViewVisibility(mSharedElements, View.VISIBLE);
         ArrayMap<ImageView, Pair<ImageView.ScaleType, Matrix>> originalImageViewState =
                 setSharedElementState(sharedElementState, sharedElementSnapshots);
+        requestLayoutForSharedElements();
 
         boolean startEnterTransition = allowOverlappingTransitions();
         boolean startSharedElementTransition = true;
@@ -200,6 +198,13 @@
         mResultReceiver = null; // all done sending messages.
     }
 
+    private void requestLayoutForSharedElements() {
+        int numSharedElements = mSharedElements.size();
+        for (int i = 0; i < numSharedElements; i++) {
+            mSharedElements.get(i).requestLayout();
+        }
+    }
+
     private Transition beginTransition(boolean startEnterTransition,
             boolean startSharedElementTransition) {
         Transition sharedElementTransition = null;
@@ -213,6 +218,19 @@
         }
 
         Transition transition = mergeTransitions(sharedElementTransition, viewsTransition);
+        if (startSharedElementTransition) {
+            if (transition == null) {
+                sharedElementTransitionStarted();
+            } else {
+                transition.addListener(new Transition.TransitionListenerAdapter() {
+                    @Override
+                    public void onTransitionStart(Transition transition) {
+                        transition.removeListener(this);
+                        sharedElementTransitionStarted();
+                    }
+                });
+            }
+        }
         if (transition != null) {
             TransitionManager.beginDelayedTransition(getDecor(), transition);
             if (startSharedElementTransition && !mSharedElementNames.isEmpty()) {
@@ -224,6 +242,13 @@
         return transition;
     }
 
+    private void sharedElementTransitionStarted() {
+        mSharedElementTransitionStarted = true;
+        if (mIsExitTransitionComplete) {
+            send(MSG_EXIT_TRANSITION_COMPLETE, null);
+        }
+    }
+
     private void startEnterTransition(Transition transition) {
         setViewVisibility(mTransitioningViews, View.VISIBLE);
         if (!mIsReturning) {
@@ -310,142 +335,4 @@
             startEnterTransition(transition);
         }
     }
-
-    private ArrayList<View> createSnapshots(Bundle state, Collection<String> names) {
-        int numSharedElements = names.size();
-        if (numSharedElements == 0) {
-            return null;
-        }
-        ArrayList<View> snapshots = new ArrayList<View>(numSharedElements);
-        Context context = getWindow().getContext();
-        int[] parentLoc = new int[2];
-        getDecor().getLocationOnScreen(parentLoc);
-        for (String name: names) {
-            Bundle sharedElementBundle = state.getBundle(name);
-            if (sharedElementBundle != null) {
-                Bitmap bitmap = sharedElementBundle.getParcelable(KEY_BITMAP);
-                View snapshot = new View(context);
-                Resources resources = getWindow().getContext().getResources();
-                snapshot.setBackground(new BitmapDrawable(resources, bitmap));
-                snapshot.setViewName(name);
-                setSharedElementState(snapshot, name, state, parentLoc);
-                snapshots.add(snapshot);
-            }
-        }
-        return snapshots;
-    }
-
-    private static void setSharedElementState(View view, String name, Bundle transitionArgs,
-            int[] parentLoc) {
-        Bundle sharedElementBundle = transitionArgs.getBundle(name);
-        if (sharedElementBundle == null) {
-            return;
-        }
-
-        if (view instanceof ImageView) {
-            int scaleTypeInt = sharedElementBundle.getInt(KEY_SCALE_TYPE, -1);
-            if (scaleTypeInt >= 0) {
-                ImageView imageView = (ImageView) view;
-                ImageView.ScaleType scaleType = SCALE_TYPE_VALUES[scaleTypeInt];
-                imageView.setScaleType(scaleType);
-                if (scaleType == ImageView.ScaleType.MATRIX) {
-                    float[] matrixValues = sharedElementBundle.getFloatArray(KEY_IMAGE_MATRIX);
-                    Matrix matrix = new Matrix();
-                    matrix.setValues(matrixValues);
-                    imageView.setImageMatrix(matrix);
-                }
-            }
-        }
-
-        float z = sharedElementBundle.getFloat(KEY_TRANSLATION_Z);
-        view.setTranslationZ(z);
-
-        int x = sharedElementBundle.getInt(KEY_SCREEN_X);
-        int y = sharedElementBundle.getInt(KEY_SCREEN_Y);
-        int width = sharedElementBundle.getInt(KEY_WIDTH);
-        int height = sharedElementBundle.getInt(KEY_HEIGHT);
-
-        int widthSpec = View.MeasureSpec.makeMeasureSpec(width, View.MeasureSpec.EXACTLY);
-        int heightSpec = View.MeasureSpec.makeMeasureSpec(height, View.MeasureSpec.EXACTLY);
-        view.measure(widthSpec, heightSpec);
-
-        int left = x - parentLoc[0];
-        int top = y - parentLoc[1];
-        int right = left + width;
-        int bottom = top + height;
-        view.layout(left, top, right, bottom);
-    }
-
-    private ArrayMap<ImageView, Pair<ImageView.ScaleType, Matrix>> setSharedElementState(
-            Bundle sharedElementState, final ArrayList<View> snapshots) {
-        ArrayMap<ImageView, Pair<ImageView.ScaleType, Matrix>> originalImageState =
-                new ArrayMap<ImageView, Pair<ImageView.ScaleType, Matrix>>();
-        if (sharedElementState != null) {
-            int[] tempLoc = new int[2];
-            for (int i = 0; i < mSharedElementNames.size(); i++) {
-                View sharedElement = mSharedElements.get(i);
-                String name = mSharedElementNames.get(i);
-                Pair<ImageView.ScaleType, Matrix> originalState = getOldImageState(sharedElement,
-                        name, sharedElementState);
-                if (originalState != null) {
-                    originalImageState.put((ImageView) sharedElement, originalState);
-                }
-                View parent = (View) sharedElement.getParent();
-                parent.getLocationOnScreen(tempLoc);
-                setSharedElementState(sharedElement, name, sharedElementState, tempLoc);
-                sharedElement.requestLayout();
-            }
-        }
-        mListener.setSharedElementStart(mSharedElementNames, mSharedElements, snapshots);
-
-        getDecor().getViewTreeObserver().addOnPreDrawListener(
-                new ViewTreeObserver.OnPreDrawListener() {
-                    @Override
-                    public boolean onPreDraw() {
-                        getDecor().getViewTreeObserver().removeOnPreDrawListener(this);
-                        mListener.setSharedElementEnd(mSharedElementNames, mSharedElements,
-                                snapshots);
-                        mSharedElementTransitionStarted = true;
-                        return true;
-                    }
-                }
-        );
-        return originalImageState;
-    }
-
-    private static Pair<ImageView.ScaleType, Matrix> getOldImageState(View view, String name,
-            Bundle transitionArgs) {
-        if (!(view instanceof ImageView)) {
-            return null;
-        }
-        Bundle bundle = transitionArgs.getBundle(name);
-        if (bundle == null) {
-            return null;
-        }
-        int scaleTypeInt = bundle.getInt(KEY_SCALE_TYPE, -1);
-        if (scaleTypeInt < 0) {
-            return null;
-        }
-
-        ImageView imageView = (ImageView) view;
-        ImageView.ScaleType originalScaleType = imageView.getScaleType();
-
-        Matrix originalMatrix = null;
-        if (originalScaleType == ImageView.ScaleType.MATRIX) {
-            originalMatrix = new Matrix(imageView.getImageMatrix());
-        }
-
-        return Pair.create(originalScaleType, originalMatrix);
-    }
-
-    private static void setOriginalImageViewState(
-            ArrayMap<ImageView, Pair<ImageView.ScaleType, Matrix>> originalState) {
-        for (int i = 0; i < originalState.size(); i++) {
-            ImageView imageView = originalState.keyAt(i);
-            Pair<ImageView.ScaleType, Matrix> state = originalState.valueAt(i);
-            imageView.setScaleType(state.first);
-            imageView.setImageMatrix(state.second);
-        }
-    }
-
 }
diff --git a/core/java/android/app/ExitTransitionCoordinator.java b/core/java/android/app/ExitTransitionCoordinator.java
index 93eb53e..a71d649 100644
--- a/core/java/android/app/ExitTransitionCoordinator.java
+++ b/core/java/android/app/ExitTransitionCoordinator.java
@@ -19,8 +19,6 @@
 import android.animation.AnimatorListenerAdapter;
 import android.animation.ObjectAnimator;
 import android.content.Intent;
-import android.graphics.Bitmap;
-import android.graphics.Canvas;
 import android.graphics.drawable.ColorDrawable;
 import android.graphics.drawable.Drawable;
 import android.os.Bundle;
@@ -29,7 +27,7 @@
 import android.transition.Transition;
 import android.transition.TransitionManager;
 import android.view.View;
-import android.widget.ImageView;
+import android.view.ViewTreeObserver;
 
 import java.util.ArrayList;
 
@@ -62,6 +60,10 @@
 
     private boolean mIsHidden;
 
+    private boolean mExitTransitionStarted;
+
+    private Bundle mExitSharedElementBundle;
+
     public ExitTransitionCoordinator(Activity activity, ArrayList<String> names,
             ArrayList<String> accepted, ArrayList<String> mapped, boolean isReturning) {
         super(activity.getWindow(), names, accepted, mapped, getListener(activity, isReturning),
@@ -102,15 +104,32 @@
                 setViewVisibility(mSharedElements, View.VISIBLE);
                 mIsHidden = true;
                 break;
+            case MSG_SHARED_ELEMENT_DESTINATION:
+                mExitSharedElementBundle = resultData;
+                if (mExitTransitionStarted) {
+                    startSharedElementExit();
+                }
+                break;
+        }
+    }
+
+    private void startSharedElementExit() {
+        if (!mSharedElements.isEmpty() && getSharedElementTransition() != null) {
+            Transition transition = getSharedElementExitTransition();
+            TransitionManager.beginDelayedTransition(getDecor(), transition);
+            ArrayList<View> sharedElementSnapshots = createSnapshots(mExitSharedElementBundle,
+                    mSharedElementNames);
+            setSharedElementState(mExitSharedElementBundle, sharedElementSnapshots);
         }
     }
 
     private void hideSharedElements() {
         setViewVisibility(mSharedElements, View.INVISIBLE);
+        finishIfNecessary();
     }
 
     public void startExit() {
-        beginTransition();
+        beginTransitions();
         setViewVisibility(mTransitioningViews, View.INVISIBLE);
     }
 
@@ -140,7 +159,30 @@
                 }
             }
         }, options);
-        startExit();
+        Transition sharedElementTransition = mSharedElements.isEmpty()
+                ? null : getSharedElementTransition();
+        if (sharedElementTransition == null) {
+            sharedElementTransitionComplete();
+        }
+        Transition transition = mergeTransitions(sharedElementTransition, getExitTransition());
+        if (transition == null) {
+            mExitTransitionStarted = true;
+        } else {
+            TransitionManager.beginDelayedTransition(getDecor(), transition);
+            setViewVisibility(mTransitioningViews, View.INVISIBLE);
+            getDecor().getViewTreeObserver().addOnPreDrawListener(new ViewTreeObserver.OnPreDrawListener() {
+                @Override
+                public boolean onPreDraw() {
+                    getDecor().getViewTreeObserver().removeOnPreDrawListener(this);
+                    mExitTransitionStarted = true;
+                    if (mExitSharedElementBundle != null) {
+                        startSharedElementExit();
+                    }
+                    notifyComplete();
+                    return true;
+                }
+            });
+        }
     }
 
     private void fadeOutBackground() {
@@ -162,24 +204,13 @@
         }
     }
 
-    private void beginTransition() {
-        Transition sharedElementTransition = configureTransition(getSharedElementTransition());
-        Transition viewsTransition = configureTransition(getViewsTransition());
-        viewsTransition = addTargets(viewsTransition, mTransitioningViews);
-        if (sharedElementTransition == null || mSharedElements.isEmpty()) {
-            sharedElementTransitionComplete();
-            sharedElementTransition = null;
-        } else {
-            sharedElementTransition.addListener(new Transition.TransitionListenerAdapter() {
-                @Override
-                public void onTransitionEnd(Transition transition) {
-                    sharedElementTransitionComplete();
-                }
-            });
+    private Transition getExitTransition() {
+        Transition viewsTransition = null;
+        if (!mTransitioningViews.isEmpty()) {
+            viewsTransition = configureTransition(getViewsTransition());
         }
-        if (viewsTransition == null || mTransitioningViews.isEmpty()) {
+        if (viewsTransition == null) {
             exitTransitionComplete();
-            viewsTransition = null;
         } else {
             viewsTransition.addListener(new Transition.TransitionListenerAdapter() {
                 @Override
@@ -189,13 +220,46 @@
                         setViewVisibility(mTransitioningViews, View.VISIBLE);
                     }
                 }
+
+                @Override
+                public void onTransitionCancel(Transition transition) {
+                    super.onTransitionCancel(transition);
+                }
             });
         }
+        return viewsTransition;
+    }
+
+    private Transition getSharedElementExitTransition() {
+        Transition sharedElementTransition = null;
+        if (!mSharedElements.isEmpty()) {
+            sharedElementTransition = configureTransition(getSharedElementTransition());
+        }
+        if (sharedElementTransition == null) {
+            sharedElementTransitionComplete();
+        } else {
+            sharedElementTransition.addListener(new Transition.TransitionListenerAdapter() {
+                @Override
+                public void onTransitionEnd(Transition transition) {
+                    sharedElementTransitionComplete();
+                    if (mIsHidden) {
+                        setViewVisibility(mSharedElements, View.VISIBLE);
+                    }
+                }
+            });
+            mSharedElements.get(0).invalidate();
+        }
+        return sharedElementTransition;
+    }
+
+    private void beginTransitions() {
+        Transition sharedElementTransition = getSharedElementExitTransition();
+        Transition viewsTransition = getExitTransition();
 
         Transition transition = mergeTransitions(sharedElementTransition, viewsTransition);
-        TransitionManager.beginDelayedTransition(getDecor(), transition);
-        if (viewsTransition == null && sharedElementTransition != null) {
-            mSharedElements.get(0).requestLayout();
+        mExitTransitionStarted = true;
+        if (transition != null) {
+            TransitionManager.beginDelayedTransition(getDecor(), transition);
         }
     }
 
@@ -205,18 +269,12 @@
     }
 
     protected boolean isReadyToNotify() {
-        return mSharedElementBundle != null && mResultReceiver != null && mIsBackgroundReady;
+        return mSharedElementBundle != null && mResultReceiver != null && mIsBackgroundReady
+                && mExitTransitionStarted;
     }
 
     private void sharedElementTransitionComplete() {
-        Bundle bundle = new Bundle();
-        int[] tempLoc = new int[2];
-        for (int i = 0; i < mSharedElementNames.size(); i++) {
-            View sharedElement = mSharedElements.get(i);
-            String name = mSharedElementNames.get(i);
-            captureSharedElementState(sharedElement, name, bundle, tempLoc);
-        }
-        mSharedElementBundle = bundle;
+        mSharedElementBundle = captureSharedElementState();
         notifyComplete();
     }
 
@@ -230,15 +288,23 @@
                 mExitNotified = true;
                 mResultReceiver.send(MSG_EXIT_TRANSITION_COMPLETE, null);
                 mResultReceiver = null; // done talking
-                if (mIsReturning) {
-                    mActivity.finish();
-                    mActivity.overridePendingTransition(0, 0);
-                }
-                mActivity = null;
+                finishIfNecessary();
             }
         }
     }
 
+    private void finishIfNecessary() {
+        if (mIsReturning && mExitNotified && (mSharedElements.isEmpty()
+                || mSharedElements.get(0).getVisibility() == View.INVISIBLE)) {
+            mActivity.finish();
+            mActivity.overridePendingTransition(0, 0);
+            mActivity = null;
+        }
+        if (!mIsReturning && mExitNotified) {
+            mActivity = null; // don't need it anymore
+        }
+    }
+
     @Override
     protected Transition getViewsTransition() {
         if (mIsReturning) {
@@ -255,58 +321,4 @@
             return getWindow().getSharedElementExitTransition();
         }
     }
-
-    /**
-     * Captures placement information for Views with a shared element name for
-     * Activity Transitions.
-     *
-     * @param view           The View to capture the placement information for.
-     * @param name           The shared element name in the target Activity to apply the placement
-     *                       information for.
-     * @param transitionArgs Bundle to store shared element placement information.
-     * @param tempLoc        A temporary int[2] for capturing the current location of views.
-     */
-    private static void captureSharedElementState(View view, String name, Bundle transitionArgs,
-            int[] tempLoc) {
-        Bundle sharedElementBundle = new Bundle();
-        view.getLocationOnScreen(tempLoc);
-        float scaleX = view.getScaleX();
-        sharedElementBundle.putInt(KEY_SCREEN_X, tempLoc[0]);
-        int width = Math.round(view.getWidth() * scaleX);
-        sharedElementBundle.putInt(KEY_WIDTH, width);
-
-        float scaleY = view.getScaleY();
-        sharedElementBundle.putInt(KEY_SCREEN_Y, tempLoc[1]);
-        int height = Math.round(view.getHeight() * scaleY);
-        sharedElementBundle.putInt(KEY_HEIGHT, height);
-
-        sharedElementBundle.putFloat(KEY_TRANSLATION_Z, view.getTranslationZ());
-
-        Bitmap bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
-        Canvas canvas = new Canvas(bitmap);
-        view.draw(canvas);
-        sharedElementBundle.putParcelable(KEY_BITMAP, bitmap);
-
-        if (view instanceof ImageView) {
-            ImageView imageView = (ImageView) view;
-            int scaleTypeInt = scaleTypeToInt(imageView.getScaleType());
-            sharedElementBundle.putInt(KEY_SCALE_TYPE, scaleTypeInt);
-            if (imageView.getScaleType() == ImageView.ScaleType.MATRIX) {
-                float[] matrix = new float[9];
-                imageView.getImageMatrix().getValues(matrix);
-                sharedElementBundle.putFloatArray(KEY_IMAGE_MATRIX, matrix);
-            }
-        }
-
-        transitionArgs.putBundle(name, sharedElementBundle);
-    }
-
-    private static int scaleTypeToInt(ImageView.ScaleType scaleType) {
-        for (int i = 0; i < SCALE_TYPE_VALUES.length; i++) {
-            if (scaleType == SCALE_TYPE_VALUES[i]) {
-                return i;
-            }
-        }
-        return -1;
-    }
 }