diff --git a/core/java/android/view/IPinnedStackListener.aidl b/core/java/android/view/IPinnedStackListener.aidl
index 596d55a..071c259 100644
--- a/core/java/android/view/IPinnedStackListener.aidl
+++ b/core/java/android/view/IPinnedStackListener.aidl
@@ -60,20 +60,12 @@
     void onActionsChanged(in ParceledListSlice actions);
 
     /**
-     * Called by the window manager to notify the listener to save the reentry fraction and size,
-     * typically when an Activity leaves PiP (picture-in-picture) mode to fullscreen.
-     * {@param componentName} represents the application component of PiP window
-     * while {@param bounds} is the current PiP bounds used to calculate the
-     * reentry snap fraction and size.
-     */
-    void onSaveReentryBounds(in ComponentName componentName, in Rect bounds);
-
-    /**
-     * Called by the window manager to notify the listener to reset saved reentry fraction and size,
-     * typically when an Activity enters PiP (picture-in-picture) mode from fullscreen.
+     * Called by the window manager to notify the listener that an Activity (which was or is in
+     * pinned mode) has been hidden, i.e. either stopped or removed. This is generally used as a
+     * signal to reset the saved reentry fraction and size.
      * {@param componentName} represents the application component of PiP window.
      */
-    void onResetReentryBounds(in ComponentName componentName);
+    void onActivityHidden(in ComponentName componentName);
 
     /**
      * Called when the window manager has detected change on DisplayInfo,  or
diff --git a/packages/SystemUI/shared/src/com/android/systemui/shared/system/PinnedStackListenerForwarder.java b/packages/SystemUI/shared/src/com/android/systemui/shared/system/PinnedStackListenerForwarder.java
index 360244c..34a0268 100644
--- a/packages/SystemUI/shared/src/com/android/systemui/shared/system/PinnedStackListenerForwarder.java
+++ b/packages/SystemUI/shared/src/com/android/systemui/shared/system/PinnedStackListenerForwarder.java
@@ -74,16 +74,9 @@
     }
 
     @Override
-    public void onSaveReentryBounds(ComponentName componentName, Rect bounds) {
+    public void onActivityHidden(ComponentName componentName) {
         for (PinnedStackListener listener : mListeners) {
-            listener.onSaveReentryBounds(componentName, bounds);
-        }
-    }
-
-    @Override
-    public void onResetReentryBounds(ComponentName componentName) {
-        for (PinnedStackListener listener : mListeners) {
-            listener.onResetReentryBounds(componentName);
+            listener.onActivityHidden(componentName);
         }
     }
 
@@ -121,9 +114,7 @@
 
         public void onActionsChanged(ParceledListSlice actions) {}
 
-        public void onSaveReentryBounds(ComponentName componentName, Rect bounds) {}
-
-        public void onResetReentryBounds(ComponentName componentName) {}
+        public void onActivityHidden(ComponentName componentName) {}
 
         public void onDisplayInfoChanged(DisplayInfo displayInfo) {}
 
diff --git a/packages/SystemUI/src/com/android/systemui/pip/PipAnimationController.java b/packages/SystemUI/src/com/android/systemui/pip/PipAnimationController.java
index 232c23d..1533592 100644
--- a/packages/SystemUI/src/com/android/systemui/pip/PipAnimationController.java
+++ b/packages/SystemUI/src/com/android/systemui/pip/PipAnimationController.java
@@ -49,10 +49,10 @@
     @Retention(RetentionPolicy.SOURCE)
     public @interface AnimationType {}
 
-    static final int TRANSITION_DIRECTION_NONE = 0;
-    static final int TRANSITION_DIRECTION_SAME = 1;
-    static final int TRANSITION_DIRECTION_TO_PIP = 2;
-    static final int TRANSITION_DIRECTION_TO_FULLSCREEN = 3;
+    public static final int TRANSITION_DIRECTION_NONE = 0;
+    public static final int TRANSITION_DIRECTION_SAME = 1;
+    public static final int TRANSITION_DIRECTION_TO_PIP = 2;
+    public static final int TRANSITION_DIRECTION_TO_FULLSCREEN = 3;
 
     @IntDef(prefix = { "TRANSITION_DIRECTION_" }, value = {
             TRANSITION_DIRECTION_NONE,
@@ -61,7 +61,7 @@
             TRANSITION_DIRECTION_TO_FULLSCREEN
     })
     @Retention(RetentionPolicy.SOURCE)
-    @interface TransitionDirection {}
+    public @interface TransitionDirection {}
 
     private final Interpolator mFastOutSlowInInterpolator;
     private final PipSurfaceTransactionHelper mSurfaceTransactionHelper;
diff --git a/packages/SystemUI/src/com/android/systemui/pip/PipTaskOrganizer.java b/packages/SystemUI/src/com/android/systemui/pip/PipTaskOrganizer.java
index af8b184..da171b2 100644
--- a/packages/SystemUI/src/com/android/systemui/pip/PipTaskOrganizer.java
+++ b/packages/SystemUI/src/com/android/systemui/pip/PipTaskOrganizer.java
@@ -30,6 +30,7 @@
 import android.annotation.Nullable;
 import android.app.ActivityManager;
 import android.app.PictureInPictureParams;
+import android.content.ComponentName;
 import android.content.Context;
 import android.content.pm.ActivityInfo;
 import android.graphics.Rect;
@@ -94,7 +95,8 @@
             mMainHandler.post(() -> {
                 for (int i = mPipTransitionCallbacks.size() - 1; i >= 0; i--) {
                     final PipTransitionCallback callback = mPipTransitionCallbacks.get(i);
-                    callback.onPipTransitionStarted();
+                    callback.onPipTransitionStarted(mTaskInfo.baseActivity,
+                            animator.getTransitionDirection());
                 }
             });
         }
@@ -105,7 +107,8 @@
             mMainHandler.post(() -> {
                 for (int i = mPipTransitionCallbacks.size() - 1; i >= 0; i--) {
                     final PipTransitionCallback callback = mPipTransitionCallbacks.get(i);
-                    callback.onPipTransitionFinished();
+                    callback.onPipTransitionFinished(mTaskInfo.baseActivity,
+                            animator.getTransitionDirection());
                 }
             });
             finishResize(tx, animator.getDestinationBounds(), animator.getTransitionDirection());
@@ -116,7 +119,8 @@
             mMainHandler.post(() -> {
                 for (int i = mPipTransitionCallbacks.size() - 1; i >= 0; i--) {
                     final PipTransitionCallback callback = mPipTransitionCallbacks.get(i);
-                    callback.onPipTransitionCanceled();
+                    callback.onPipTransitionCanceled(mTaskInfo.baseActivity,
+                            animator.getTransitionDirection());
                 }
             });
         }
@@ -201,6 +205,13 @@
         return mUpdateHandler;
     }
 
+    /**
+     * Returns a copy of the last reported bounds, so callers may freely mutate the result.
+     */
+    public Rect getLastReportedBounds() {
+        return new Rect(mLastReportedBounds);
+    }
+
     /**
      * Registers {@link PipTransitionCallback} to receive transition callbacks.
      */
@@ -532,16 +543,28 @@
         /**
          * Callback when the pip transition is started.
+         *
+         * @param activity the base activity of the task in PiP.
+         * @param direction one of {@code PipAnimationController.TRANSITION_DIRECTION_*}.
          */
-        void onPipTransitionStarted();
+        void onPipTransitionStarted(ComponentName activity,
+                @PipAnimationController.TransitionDirection int direction);
 
         /**
          * Callback when the pip transition is finished.
+         *
+         * @param activity the base activity of the task in PiP.
+         * @param direction one of {@code PipAnimationController.TRANSITION_DIRECTION_*}.
          */
-        void onPipTransitionFinished();
+        void onPipTransitionFinished(ComponentName activity,
+                @PipAnimationController.TransitionDirection int direction);
 
         /**
          * Callback when the pip transition is cancelled.
+         *
+         * @param activity the base activity of the task in PiP.
+         * @param direction one of {@code PipAnimationController.TRANSITION_DIRECTION_*}.
          */
-        void onPipTransitionCanceled();
+        void onPipTransitionCanceled(ComponentName activity,
+                @PipAnimationController.TransitionDirection int direction);
     }
 }
diff --git a/packages/SystemUI/src/com/android/systemui/pip/phone/PipManager.java b/packages/SystemUI/src/com/android/systemui/pip/phone/PipManager.java
index 9722c08..ea19532 100644
--- a/packages/SystemUI/src/com/android/systemui/pip/phone/PipManager.java
+++ b/packages/SystemUI/src/com/android/systemui/pip/phone/PipManager.java
@@ -20,6 +20,8 @@
 import static android.app.WindowConfiguration.WINDOWING_MODE_PINNED;
 import static android.window.WindowOrganizer.TaskOrganizer;
 
+import static com.android.systemui.pip.PipAnimationController.TRANSITION_DIRECTION_TO_FULLSCREEN;
+
 import android.app.ActivityManager;
 import android.app.ActivityTaskManager;
 import android.app.IActivityManager;
@@ -170,24 +172,7 @@
         }
 
         @Override
-        public void onSaveReentryBounds(ComponentName componentName, Rect bounds) {
-            mHandler.post(() -> {
-                // On phones, the expansion animation that happens on pip tap before restoring
-                // to fullscreen makes it so that the bounds received here are the expanded
-                // bounds. We want to restore to the unexpanded bounds when re-entering pip,
-                // so we save the bounds before expansion (normal) instead of the current
-                // bounds.
-                mReentryBounds.set(mTouchHandler.getNormalBounds());
-                // Apply the snap fraction of the current bounds to the normal bounds.
-                float snapFraction = mPipBoundsHandler.getSnapFraction(bounds);
-                mPipBoundsHandler.applySnapFraction(mReentryBounds, snapFraction);
-                // Save reentry bounds (normal non-expand bounds with current position applied).
-                mPipBoundsHandler.onSaveReentryBounds(componentName, mReentryBounds);
-            });
-        }
-
-        @Override
-        public void onResetReentryBounds(ComponentName componentName) {
+        public void onActivityHidden(ComponentName componentName) {
             mHandler.post(() -> mPipBoundsHandler.onResetReentryBounds(componentName));
         }
 
@@ -325,7 +310,21 @@
     }
 
     @Override
-    public void onPipTransitionStarted() {
+    public void onPipTransitionStarted(ComponentName activity, int direction) {
+        if (direction == TRANSITION_DIRECTION_TO_FULLSCREEN) {
+            // Leaving PiP for fullscreen: save the bounds to restore on the next re-entry.
+            // On phones, tapping PiP first expands it, so the last reported bounds may be
+            // the expanded bounds. We want to restore to the unexpanded bounds when
+            // re-entering pip, so start from the bounds before expansion (normal) and
+            // only take the position of the last reported bounds via its snap fraction.
+            mReentryBounds.set(mTouchHandler.getNormalBounds());
+            // Apply the snap fraction of the last reported bounds to the normal bounds.
+            final Rect bounds = mPipTaskOrganizer.getLastReportedBounds();
+            float snapFraction = mPipBoundsHandler.getSnapFraction(bounds);
+            mPipBoundsHandler.applySnapFraction(mReentryBounds, snapFraction);
+            // Save reentry bounds (normal non-expand bounds with current position applied).
+            mPipBoundsHandler.onSaveReentryBounds(activity, mReentryBounds);
+        }
         // Disable touches while the animation is running
         mTouchHandler.setTouchEnabled(false);
         if (mPinnedStackAnimationRecentsListener != null) {
@@ -338,12 +337,12 @@
     }
 
     @Override
-    public void onPipTransitionFinished() {
+    public void onPipTransitionFinished(ComponentName activity, int direction) {
         onPipTransitionFinishedOrCanceled();
     }
 
     @Override
-    public void onPipTransitionCanceled() {
+    public void onPipTransitionCanceled(ComponentName activity, int direction) {
         onPipTransitionFinishedOrCanceled();
     }
 
diff --git a/packages/SystemUI/src/com/android/systemui/pip/tv/PipManager.java b/packages/SystemUI/src/com/android/systemui/pip/tv/PipManager.java
index 18dde9d..c6e6da1 100644
--- a/packages/SystemUI/src/com/android/systemui/pip/tv/PipManager.java
+++ b/packages/SystemUI/src/com/android/systemui/pip/tv/PipManager.java
@@ -699,15 +699,15 @@
     };
 
     @Override
-    public void onPipTransitionStarted() { }
+    public void onPipTransitionStarted(ComponentName activity, int direction) { }
 
     @Override
-    public void onPipTransitionFinished() {
+    public void onPipTransitionFinished(ComponentName activity, int direction) {
         onPipTransitionFinishedOrCanceled();
     }
 
     @Override
-    public void onPipTransitionCanceled() {
+    public void onPipTransitionCanceled(ComponentName activity, int direction) {
         onPipTransitionFinishedOrCanceled();
     }
 
diff --git a/services/core/java/com/android/server/wm/ActivityRecord.java b/services/core/java/com/android/server/wm/ActivityRecord.java
index 78d6e27..62ec936 100644
--- a/services/core/java/com/android/server/wm/ActivityRecord.java
+++ b/services/core/java/com/android/server/wm/ActivityRecord.java
@@ -3160,7 +3160,7 @@
         }
 
         // Reset the last saved PiP snap fraction on removal.
-        mDisplayContent.mPinnedStackControllerLocked.resetReentryBounds(mActivityComponent);
+        mDisplayContent.mPinnedStackControllerLocked.onActivityHidden(mActivityComponent);
         mWmService.mEmbeddedWindowController.onActivityRemoved(this);
         mRemovingFromDisplay = false;
     }
@@ -4426,7 +4426,7 @@
         ProtoLog.v(WM_DEBUG_ADD_REMOVE, "notifyAppStopped: %s", this);
         mAppStopped = true;
         // Reset the last saved PiP snap fraction on app stop.
-        mDisplayContent.mPinnedStackControllerLocked.resetReentryBounds(mActivityComponent);
+        mDisplayContent.mPinnedStackControllerLocked.onActivityHidden(mActivityComponent);
         destroySurfaces();
         // Remove any starting window that was added for this app if they are still around.
         removeStartingWindow();
@@ -6651,17 +6651,6 @@
         }
     }
 
-    void savePinnedStackBounds() {
-        // Leaving PiP to fullscreen, save the snap fraction based on the pre-animation bounds
-        // for the next re-entry into PiP (assuming the activity is not hidden or destroyed)
-        final ActivityStack pinnedStack = mDisplayContent.getRootPinnedTask();
-        if (pinnedStack == null) return;
-        final Rect stackBounds = mTmpRect;
-        pinnedStack.getBounds(stackBounds);
-        mDisplayContent.mPinnedStackControllerLocked.saveReentryBounds(
-                mActivityComponent, stackBounds);
-    }
-
     /** Returns true if the configuration is compatible with this activity. */
     boolean isConfigurationCompatible(Configuration config) {
         final int orientation = getRequestedOrientation();
diff --git a/services/core/java/com/android/server/wm/ActivityStack.java b/services/core/java/com/android/server/wm/ActivityStack.java
index 64c14cf..9815d6d 100644
--- a/services/core/java/com/android/server/wm/ActivityStack.java
+++ b/services/core/java/com/android/server/wm/ActivityStack.java
@@ -3391,12 +3391,6 @@
                     "Can't exit pinned mode if it's not pinned already.");
         }
 
-        // give pinned stack a chance to save current bounds, this should happen before reparent.
-        final ActivityRecord top = topRunningNonOverlayTaskActivity();
-        if (top != null && top.isVisible()) {
-            top.savePinnedStackBounds();
-        }
-
         mWmService.inSurfaceTransaction(() -> {
             final Task task = getBottomMostTask();
             setWindowingMode(WINDOWING_MODE_UNDEFINED);
diff --git a/services/core/java/com/android/server/wm/PinnedStackController.java b/services/core/java/com/android/server/wm/PinnedStackController.java
index e14b8ae..66dbfd5 100644
--- a/services/core/java/com/android/server/wm/PinnedStackController.java
+++ b/services/core/java/com/android/server/wm/PinnedStackController.java
@@ -163,24 +163,13 @@
     }
 
     /**
-     * Saves the current snap fraction for re-entry of the current activity into PiP.
+     * Notifies the listener that the activity is hidden (either stopped or removed), so that it
+     * can reset the saved reentry bounds and the default bounds are used for the next session.
      */
-    void saveReentryBounds(final ComponentName componentName, final Rect stackBounds) {
+    void onActivityHidden(ComponentName componentName) {
         if (mPinnedStackListener == null) return;
         try {
-            mPinnedStackListener.onSaveReentryBounds(componentName, stackBounds);
-        } catch (RemoteException e) {
-            Slog.e(TAG_WM, "Error delivering save reentry fraction event.", e);
-        }
-    }
-
-    /**
-     * Resets the last saved snap fraction so that the default bounds will be returned.
-     */
-    void resetReentryBounds(ComponentName componentName) {
-        if (mPinnedStackListener == null) return;
-        try {
-            mPinnedStackListener.onResetReentryBounds(componentName);
+            mPinnedStackListener.onActivityHidden(componentName);
         } catch (RemoteException e) {
             Slog.e(TAG_WM, "Error delivering reset reentry fraction event.", e);
         }
