MediaSessionService: Keep Media1 and Media2 session in the one place

This helps to prioritize all sessions for handle key events.

Bug: 147279043
Test: Build and run CTS
Change-Id: I3f2171bf05cd14761294620e1ae7454e32ddb1e2
diff --git a/services/core/java/com/android/server/media/MediaSession2Record.java b/services/core/java/com/android/server/media/MediaSession2Record.java
index f82d840..f3241ee 100644
--- a/services/core/java/com/android/server/media/MediaSession2Record.java
+++ b/services/core/java/com/android/server/media/MediaSession2Record.java
@@ -138,6 +138,12 @@
         pw.println(indent + "playbackActive=" + mController.isPlaybackActive());
     }
 
+    @Override
+    public String toString() {
+        // TODO(jaewan): Also add getId().
+        return getPackageName() + " (userId=" + getUserId() + ")";
+    }
+
     private class Controller2Callback extends MediaController2.ControllerCallback {
         @Override
         public void onConnected(MediaController2 controller, Session2CommandGroup allowedCommands) {
@@ -147,7 +153,7 @@
             synchronized (mLock) {
                 mIsConnected = true;
             }
-            mService.pushSession2TokensChanged(MediaSession2Record.this);
+            mService.onSessionActiveStateChanged(MediaSession2Record.this);
         }
 
         @Override
@@ -158,7 +164,16 @@
             synchronized (mLock) {
                 mIsConnected = false;
             }
-            mService.sessionDied(MediaSession2Record.this);
+            mService.onSessionDied(MediaSession2Record.this);
+        }
+
+        @Override
+        public void onPlaybackActiveChanged(MediaController2 controller, boolean playbackActive) {
+            if (DEBUG) {
+                Log.d(TAG, "playback active changed, " + mSessionToken + ", active="
+                        + playbackActive);
+            }
+            mService.onSessionPlaybackStateChanged(MediaSession2Record.this, playbackActive);
         }
     }
 }
diff --git a/services/core/java/com/android/server/media/MediaSessionRecord.java b/services/core/java/com/android/server/media/MediaSessionRecord.java
index c49be9c..df115d0 100644
--- a/services/core/java/com/android/server/media/MediaSessionRecord.java
+++ b/services/core/java/com/android/server/media/MediaSessionRecord.java
@@ -56,6 +56,7 @@
 
 import java.io.PrintWriter;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
 /**
@@ -73,6 +74,24 @@
      */
     private static final int OPTIMISTIC_VOLUME_TIMEOUT = 1000;
 
+    /**
+     * These are states that usually indicate the user took an action and should
+     * bump priority regardless of the old state.
+     */
+    private static final List<Integer> ALWAYS_PRIORITY_STATES = Arrays.asList(
+            PlaybackState.STATE_FAST_FORWARDING,
+            PlaybackState.STATE_REWINDING,
+            PlaybackState.STATE_SKIPPING_TO_PREVIOUS,
+            PlaybackState.STATE_SKIPPING_TO_NEXT);
+    /**
+     * These are states that usually indicate the user took an action if they
+     * were entered from a non-priority state.
+     */
+    private static final List<Integer> TRANSITION_PRIORITY_STATES = Arrays.asList(
+            PlaybackState.STATE_BUFFERING,
+            PlaybackState.STATE_CONNECTING,
+            PlaybackState.STATE_PLAYING);
+
     private final MessageHandler mHandler;
 
     private final int mOwnerPid;
@@ -225,7 +244,6 @@
      * @param opPackageName The op package that made the original volume request.
      * @param pid The pid that made the original volume request.
      * @param uid The uid that made the original volume request.
-     * @param caller caller binder. can be {@code null} if it's from the volume key.
      * @param asSystemService {@code true} if the event sent to the session as if it was come from
      *          the system service instead of the app process. This helps sessions to distinguish
      *          between the key injection by the app and key events from the hardware devices.
@@ -362,7 +380,7 @@
 
     @Override
     public void binderDied() {
-        mService.sessionDied(this);
+        mService.onSessionDied(this);
     }
 
     /**
@@ -724,7 +742,7 @@
         public void destroySession() throws RemoteException {
             final long token = Binder.clearCallingIdentity();
             try {
-                mService.sessionDied(MediaSessionRecord.this);
+                mService.onSessionDied(MediaSessionRecord.this);
             } finally {
                 Binder.restoreCallingIdentity(token);
             }
@@ -746,7 +764,7 @@
             mIsActive = active;
             final long token = Binder.clearCallingIdentity();
             try {
-                mService.updateSession(MediaSessionRecord.this);
+                mService.onSessionActiveStateChanged(MediaSessionRecord.this);
             } finally {
                 Binder.restoreCallingIdentity(token);
             }
@@ -813,12 +831,16 @@
                     ? PlaybackState.STATE_NONE : mPlaybackState.getState();
             int newState = state == null
                     ? PlaybackState.STATE_NONE : state.getState();
+            boolean shouldUpdatePriority = ALWAYS_PRIORITY_STATES.contains(newState)
+                    || (!TRANSITION_PRIORITY_STATES.contains(oldState)
+                    && TRANSITION_PRIORITY_STATES.contains(newState));
             synchronized (mLock) {
                 mPlaybackState = state;
             }
             final long token = Binder.clearCallingIdentity();
             try {
-                mService.onSessionPlaystateChanged(MediaSessionRecord.this, oldState, newState);
+                mService.onSessionPlaybackStateChanged(
+                        MediaSessionRecord.this, shouldUpdatePriority);
             } finally {
                 Binder.restoreCallingIdentity(token);
             }
diff --git a/services/core/java/com/android/server/media/MediaSessionService.java b/services/core/java/com/android/server/media/MediaSessionService.java
index 29bfc3b..f71fb58 100644
--- a/services/core/java/com/android/server/media/MediaSessionService.java
+++ b/services/core/java/com/android/server/media/MediaSessionService.java
@@ -119,12 +119,6 @@
     @GuardedBy("mLock")
     private final ArrayList<SessionsListenerRecord> mSessionsListeners =
             new ArrayList<SessionsListenerRecord>();
-    // Map user id as index to list of Session2Tokens
-    // TODO: Keep session2 info in MediaSessionStack for prioritizing both session1 and session2 in
-    //       one place.
-    @GuardedBy("mLock")
-    private final SparseArray<List<MediaSession2Record>> mSession2RecordsPerUser =
-            new SparseArray<>();
     @GuardedBy("mLock")
     private final List<Session2TokensListenerRecord> mSession2TokensListenerRecords =
             new ArrayList<>();
@@ -190,7 +184,7 @@
         return mGlobalPrioritySession != null && mGlobalPrioritySession.isActive();
     }
 
-    void updateSession(MediaSessionRecord record) {
+    void onSessionActiveStateChanged(MediaSessionRecordImpl record) {
         synchronized (mLock) {
             FullUserRecord user = getFullUserRecordLocked(record.getUserId());
             if (user == null) {
@@ -207,12 +201,14 @@
                     Log.w(TAG, "Unknown session updated. Ignoring.");
                     return;
                 }
-                user.mPriorityStack.onSessionStateChange(record);
+                user.mPriorityStack.onSessionActiveStateChanged(record);
             }
-            mHandler.postSessionsChanged(record.getUserId());
+
+            mHandler.postSessionsChanged(record);
         }
     }
 
+    // Currently only media1 can become global priority session.
     void setGlobalPrioritySession(MediaSessionRecord record) {
         synchronized (mLock) {
             FullUserRecord user = getFullUserRecordLocked(record.getUserId());
@@ -258,21 +254,13 @@
     List<Session2Token> getSession2TokensLocked(int userId) {
         List<Session2Token> list = new ArrayList<>();
         if (userId == USER_ALL) {
-            for (int i = 0; i < mSession2RecordsPerUser.size(); i++) {
-                List<MediaSession2Record> records = mSession2RecordsPerUser.valueAt(i);
-                for (MediaSession2Record record: records) {
-                    if (record.isActive()) {
-                        list.add(record.getSession2Token());
-                    }
-                }
+            int size = mUserRecords.size();
+            for (int i = 0; i < size; i++) {
+                list.addAll(mUserRecords.valueAt(i).mPriorityStack.getSession2Tokens(userId));
             }
         } else {
-            List<MediaSession2Record> records = mSession2RecordsPerUser.get(userId);
-            for (MediaSession2Record record: records) {
-                if (record.isActive()) {
-                    list.add(record.getSession2Token());
-                }
-            }
+            FullUserRecord user = getFullUserRecordLocked(userId);
+            list.addAll(user.mPriorityStack.getSession2Tokens(userId));
         }
         return list;
     }
@@ -299,14 +287,15 @@
         }
     }
 
-    void onSessionPlaystateChanged(MediaSessionRecord record, int oldState, int newState) {
+    void onSessionPlaybackStateChanged(MediaSessionRecordImpl record,
+            boolean shouldUpdatePriority) {
         synchronized (mLock) {
             FullUserRecord user = getFullUserRecordLocked(record.getUserId());
             if (user == null || !user.mPriorityStack.contains(record)) {
                 Log.d(TAG, "Unknown session changed playback state. Ignoring.");
                 return;
             }
-            user.mPriorityStack.onPlaystateChanged(record, oldState, newState);
+            user.mPriorityStack.onPlaybackStateChanged(record, shouldUpdatePriority);
         }
     }
 
@@ -349,13 +338,6 @@
                     user.destroySessionsForUserLocked(userId);
                 }
             }
-            List<MediaSession2Record> list = mSession2RecordsPerUser.get(userId);
-            if (list != null) {
-                for (MediaSession2Record session : list) {
-                    session.close();
-                }
-                mSession2RecordsPerUser.remove(userId);
-            }
             updateUser();
         }
     }
@@ -374,26 +356,12 @@
         }
     }
 
-    void sessionDied(MediaSessionRecord session) {
+    void onSessionDied(MediaSessionRecordImpl session) {
         synchronized (mLock) {
             destroySessionLocked(session);
         }
     }
 
-    void pushSession2TokensChanged(MediaSession2Record sessionRecord) {
-        synchronized (mLock) {
-            pushSession2TokensChangedLocked(sessionRecord.getUserId());
-        }
-    }
-
-    void sessionDied(MediaSession2Record sessionRecord) {
-        synchronized (mLock) {
-            int userId = sessionRecord.getUserId();
-            mSession2RecordsPerUser.get(userId).remove(sessionRecord);
-            pushSession2TokensChangedLocked(userId);
-        }
-    }
-
     private void updateUser() {
         synchronized (mLock) {
             UserManager manager = (UserManager) mContext.getSystemService(Context.USER_SERVICE);
@@ -409,9 +377,6 @@
                             mUserRecords.put(userInfo.id, new FullUserRecord(userInfo.id));
                         }
                     }
-                    if (mSession2RecordsPerUser.get(userInfo.id) == null) {
-                        mSession2RecordsPerUser.put(userInfo.id, new ArrayList<>());
-                    }
                 }
             }
             // Ensure that the current full user exists.
@@ -421,9 +386,6 @@
                 Log.w(TAG, "Cannot find FullUserInfo for the current user " + currentFullUserId);
                 mCurrentFullUserRecord = new FullUserRecord(currentFullUserId);
                 mUserRecords.put(currentFullUserId, mCurrentFullUserRecord);
-                if (mSession2RecordsPerUser.get(currentFullUserId) == null) {
-                    mSession2RecordsPerUser.put(currentFullUserId, new ArrayList<>());
-                }
             }
             mFullUserIds.put(currentFullUserId, currentFullUserId);
         }
@@ -460,7 +422,7 @@
      * 5. We need to unlink to death from the cb binder
      * 6. We need to tell the session to do any final cleanup (onDestroy)
      */
-    private void destroySessionLocked(MediaSessionRecord session) {
+    private void destroySessionLocked(MediaSessionRecordImpl session) {
         if (DEBUG) {
             Log.d(TAG, "Destroying " + session);
         }
@@ -477,7 +439,7 @@
         }
 
         session.close();
-        mHandler.postSessionsChanged(session.getUserId());
+        mHandler.postSessionsChanged(session);
     }
 
     private void enforcePackageName(String packageName, int uid) {
@@ -582,7 +544,7 @@
             }
 
             user.mPriorityStack.addSession(session);
-            mHandler.postSessionsChanged(userId);
+            mHandler.postSessionsChanged(session);
 
             if (DEBUG) {
                 Log.d(TAG, "Created session for " + callerPackageName + " with tag " + tag);
@@ -609,16 +571,16 @@
         return -1;
     }
 
-    private void pushSessionsChanged(int userId) {
+    private void pushSession1Changed(int userId) {
         synchronized (mLock) {
             FullUserRecord user = getFullUserRecordLocked(userId);
             if (user == null) {
-                Log.w(TAG, "pushSessionsChanged failed. No user with id=" + userId);
+                Log.w(TAG, "pushSession1ChangedOnHandler failed. No user with id=" + userId);
                 return;
             }
             List<MediaSessionRecord> records = getActiveSessionsLocked(userId);
             int size = records.size();
-            ArrayList<MediaSession.Token> tokens = new ArrayList<MediaSession.Token>();
+            ArrayList<MediaSession.Token> tokens = new ArrayList<>();
             for (int i = 0; i < size; i++) {
                 tokens.add(records.get(i).getSessionToken());
             }
@@ -638,6 +600,27 @@
         }
     }
 
+    void pushSession2Changed(int userId) {
+        synchronized (mLock) {
+            List<Session2Token> allSession2Tokens = getSession2TokensLocked(USER_ALL);
+            List<Session2Token> session2Tokens = getSession2TokensLocked(userId);
+
+            for (int i = mSession2TokensListenerRecords.size() - 1; i >= 0; i--) {
+                Session2TokensListenerRecord listenerRecord = mSession2TokensListenerRecords.get(i);
+                try {
+                    if (listenerRecord.userId == USER_ALL) {
+                        listenerRecord.listener.onSession2TokensChanged(allSession2Tokens);
+                    } else if (listenerRecord.userId == userId) {
+                        listenerRecord.listener.onSession2TokensChanged(session2Tokens);
+                    }
+                } catch (RemoteException e) {
+                    Log.w(TAG, "Failed to notify Session2Token change. Removing listener.", e);
+                    mSession2TokensListenerRecords.remove(i);
+                }
+            }
+        }
+    }
+
     private void pushRemoteVolumeUpdateLocked(int userId) {
         FullUserRecord user = getFullUserRecordLocked(userId);
         if (user == null) {
@@ -647,8 +630,13 @@
 
         synchronized (mLock) {
             int size = mRemoteVolumeControllers.beginBroadcast();
-            MediaSessionRecord record = user.mPriorityStack.getDefaultRemoteSession(userId);
-            MediaSession.Token token = record == null ? null : record.getSessionToken();
+            MediaSessionRecordImpl record = user.mPriorityStack.getDefaultRemoteSession(userId);
+            if (record instanceof MediaSession2Record) {
+                // TODO(jaewan): Implement
+                return;
+            }
+            MediaSession.Token token = record == null
+                    ? null : ((MediaSessionRecord) record).getSessionToken();
 
             for (int i = size - 1; i >= 0; i--) {
                 try {
@@ -662,34 +650,15 @@
         }
     }
 
-    void pushSession2TokensChangedLocked(int userId) {
-        List<Session2Token> allSession2Tokens = getSession2TokensLocked(USER_ALL);
-        List<Session2Token> session2Tokens = getSession2TokensLocked(userId);
-
-        for (int i = mSession2TokensListenerRecords.size() - 1; i >= 0; i--) {
-            Session2TokensListenerRecord listenerRecord = mSession2TokensListenerRecords.get(i);
-            try {
-                if (listenerRecord.userId == USER_ALL) {
-                    listenerRecord.listener.onSession2TokensChanged(allSession2Tokens);
-                } else if (listenerRecord.userId == userId) {
-                    listenerRecord.listener.onSession2TokensChanged(session2Tokens);
-                }
-            } catch (RemoteException e) {
-                Log.w(TAG, "Failed to notify Session2Token change. Removing listener.", e);
-                mSession2TokensListenerRecords.remove(i);
-            }
-        }
-    }
-
     /**
      * Called when the media button receiver for the {@code record} is changed.
      *
      * @param record the media session whose media button receiver is updated.
      */
-    public void onMediaButtonReceiverChanged(MediaSessionRecord record) {
+    public void onMediaButtonReceiverChanged(MediaSessionRecordImpl record) {
         synchronized (mLock) {
             FullUserRecord user = getFullUserRecordLocked(record.getUserId());
-            MediaSessionRecord mediaButtonSession =
+            MediaSessionRecordImpl mediaButtonSession =
                     user.mPriorityStack.getMediaButtonSession();
             if (record == mediaButtonSession) {
                 user.rememberMediaButtonReceiverLocked(mediaButtonSession);
@@ -877,39 +846,34 @@
             pw.println(indent + "Restored MediaButtonReceiverComponentType: "
                     + mRestoredMediaButtonReceiverComponentType);
             mPriorityStack.dump(pw, indent);
-            pw.println(indent + "Session2Tokens:");
-            for (int i = 0; i < mSession2RecordsPerUser.size(); i++) {
-                List<MediaSession2Record> list = mSession2RecordsPerUser.valueAt(i);
-                if (list == null || list.size() == 0) {
-                    continue;
-                }
-                for (MediaSession2Record record : list) {
-                    record.dump(pw, indent);
-                }
-            }
         }
 
         @Override
-        public void onMediaButtonSessionChanged(MediaSessionRecord oldMediaButtonSession,
-                MediaSessionRecord newMediaButtonSession) {
+        public void onMediaButtonSessionChanged(MediaSessionRecordImpl oldMediaButtonSession,
+                MediaSessionRecordImpl newMediaButtonSession) {
             if (DEBUG_KEY_EVENT) {
                 Log.d(TAG, "Media button session is changed to " + newMediaButtonSession);
             }
             synchronized (mLock) {
                 if (oldMediaButtonSession != null) {
-                    mHandler.postSessionsChanged(oldMediaButtonSession.getUserId());
+                    mHandler.postSessionsChanged(oldMediaButtonSession);
                 }
                 if (newMediaButtonSession != null) {
                     rememberMediaButtonReceiverLocked(newMediaButtonSession);
-                    mHandler.postSessionsChanged(newMediaButtonSession.getUserId());
+                    mHandler.postSessionsChanged(newMediaButtonSession);
                 }
                 pushAddressedPlayerChangedLocked();
             }
         }
 
         // Remember media button receiver and keep it in the persistent storage.
-        public void rememberMediaButtonReceiverLocked(MediaSessionRecord record) {
-            PendingIntent receiver = record.getMediaButtonReceiver();
+        public void rememberMediaButtonReceiverLocked(MediaSessionRecordImpl record) {
+            if (record instanceof MediaSession2Record) {
+                // TODO(jaewan): Implement
+                return;
+            }
+            MediaSessionRecord sessionRecord = (MediaSessionRecord) record;
+            PendingIntent receiver = sessionRecord.getMediaButtonReceiver();
             mLastMediaButtonReceiver = receiver;
             mRestoredMediaButtonReceiver = null;
             mRestoredMediaButtonReceiverComponentType = COMPONENT_TYPE_INVALID;
@@ -934,10 +898,15 @@
         private void pushAddressedPlayerChangedLocked(
                 IOnMediaKeyEventSessionChangedListener callback) {
             try {
-                MediaSessionRecord mediaButtonSession = getMediaButtonSessionLocked();
+                MediaSessionRecordImpl mediaButtonSession = getMediaButtonSessionLocked();
                 if (mediaButtonSession != null) {
-                    callback.onMediaKeyEventSessionChanged(mediaButtonSession.getPackageName(),
-                            mediaButtonSession.getSessionToken());
+                    if (mediaButtonSession instanceof MediaSessionRecord) {
+                        MediaSessionRecord session1 = (MediaSessionRecord) mediaButtonSession;
+                        callback.onMediaKeyEventSessionChanged(session1.getPackageName(),
+                                session1.getSessionToken());
+                    } else {
+                        // TODO(jaewan): Implement
+                    }
                 } else if (mCurrentFullUserRecord.mLastMediaButtonReceiver != null) {
                     callback.onMediaKeyEventSessionChanged(
                             mCurrentFullUserRecord.mLastMediaButtonReceiver
@@ -960,7 +929,7 @@
             }
         }
 
-        private MediaSessionRecord getMediaButtonSessionLocked() {
+        private MediaSessionRecordImpl getMediaButtonSessionLocked() {
             return isGlobalPriorityActiveLocked()
                     ? mGlobalPrioritySession : mPriorityStack.getMediaButtonSession();
         }
@@ -1143,7 +1112,10 @@
                 }
                 MediaSession2Record record = new MediaSession2Record(
                         sessionToken, MediaSessionService.this, mHandler.getLooper());
-                mSession2RecordsPerUser.get(record.getUserId()).add(record);
+                synchronized (mLock) {
+                    FullUserRecord user = getFullUserRecordLocked(record.getUserId());
+                    user.mPriorityStack.addSession(record);
+                }
                 // Do not immediately notify changes -- do so when framework can dispatch command
             } finally {
                 Binder.restoreCallingIdentity(token);
@@ -1185,7 +1157,8 @@
                         null /* optional packageName */);
                 List<Session2Token> result;
                 synchronized (mLock) {
-                    result = getSession2TokensLocked(resolvedUserId);
+                    FullUserRecord user = getFullUserRecordLocked(userId);
+                    result = user.mPriorityStack.getSession2Tokens(resolvedUserId);
                 }
                 return new ParceledListSlice(result);
             } finally {
@@ -2023,7 +1996,7 @@
 
         private void dispatchAdjustVolumeLocked(String packageName, String opPackageName, int pid,
                 int uid, boolean asSystemService, int suggestedStream, int direction, int flags) {
-            MediaSessionRecord session = isGlobalPriorityActiveLocked() ? mGlobalPrioritySession
+            MediaSessionRecordImpl session = isGlobalPriorityActiveLocked() ? mGlobalPrioritySession
                     : mCurrentFullUserRecord.mPriorityStack.getDefaultVolumeSession();
 
             boolean preferSuggestedStream = false;
@@ -2114,7 +2087,13 @@
 
         private void dispatchMediaKeyEventLocked(String packageName, int pid, int uid,
                 boolean asSystemService, KeyEvent keyEvent, boolean needWakeLock) {
-            MediaSessionRecord session = mCurrentFullUserRecord.getMediaButtonSessionLocked();
+            if (mCurrentFullUserRecord.getMediaButtonSessionLocked()
+                    instanceof MediaSession2Record) {
+                // TODO(jaewan): Implement
+                return;
+            }
+            MediaSessionRecord session =
+                    (MediaSessionRecord) mCurrentFullUserRecord.getMediaButtonSessionLocked();
             if (session != null) {
                 if (DEBUG_KEY_EVENT) {
                     Log.d(TAG, "Sending " + keyEvent + " to " + session);
@@ -2394,15 +2373,19 @@
     }
 
     final class MessageHandler extends Handler {
-        private static final int MSG_SESSIONS_CHANGED = 1;
-        private static final int MSG_VOLUME_INITIAL_DOWN = 2;
+        private static final int MSG_SESSIONS_1_CHANGED = 1;
+        private static final int MSG_SESSIONS_2_CHANGED = 2;
+        private static final int MSG_VOLUME_INITIAL_DOWN = 3;
         private final SparseArray<Integer> mIntegerCache = new SparseArray<>();
 
         @Override
         public void handleMessage(Message msg) {
             switch (msg.what) {
-                case MSG_SESSIONS_CHANGED:
-                    pushSessionsChanged((int) msg.obj);
+                case MSG_SESSIONS_1_CHANGED:
+                    pushSession1Changed((int) msg.obj);
+                    break;
+                case MSG_SESSIONS_2_CHANGED:
+                    pushSession2Changed((int) msg.obj);
                     break;
                 case MSG_VOLUME_INITIAL_DOWN:
                     synchronized (mLock) {
@@ -2417,15 +2400,18 @@
             }
         }
 
-        public void postSessionsChanged(int userId) {
+        public void postSessionsChanged(MediaSessionRecordImpl record) {
             // Use object instead of the arguments when posting message to remove pending requests.
-            Integer userIdInteger = mIntegerCache.get(userId);
+            Integer userIdInteger = mIntegerCache.get(record.getUserId());
             if (userIdInteger == null) {
-                userIdInteger = Integer.valueOf(userId);
-                mIntegerCache.put(userId, userIdInteger);
+                userIdInteger = Integer.valueOf(record.getUserId());
+                mIntegerCache.put(record.getUserId(), userIdInteger);
             }
-            removeMessages(MSG_SESSIONS_CHANGED, userIdInteger);
-            obtainMessage(MSG_SESSIONS_CHANGED, userIdInteger).sendToTarget();
+
+            int msg = (record instanceof MediaSessionRecord)
+                    ? MSG_SESSIONS_1_CHANGED : MSG_SESSIONS_2_CHANGED;
+            removeMessages(msg, userIdInteger);
+            obtainMessage(msg, userIdInteger).sendToTarget();
         }
     }
 
diff --git a/services/core/java/com/android/server/media/MediaSessionStack.java b/services/core/java/com/android/server/media/MediaSessionStack.java
index 74613c9..7bb7cf4 100644
--- a/services/core/java/com/android/server/media/MediaSessionStack.java
+++ b/services/core/java/com/android/server/media/MediaSessionStack.java
@@ -16,8 +16,8 @@
 
 package com.android.server.media;
 
+import android.media.Session2Token;
 import android.media.session.MediaSession;
-import android.media.session.PlaybackState;
 import android.os.Debug;
 import android.os.UserHandle;
 import android.util.IntArray;
@@ -45,51 +45,30 @@
         /**
          * Called when the media button session is changed.
          */
-        void onMediaButtonSessionChanged(MediaSessionRecord oldMediaButtonSession,
-                MediaSessionRecord newMediaButtonSession);
+        void onMediaButtonSessionChanged(MediaSessionRecordImpl oldMediaButtonSession,
+                MediaSessionRecordImpl newMediaButtonSession);
     }
 
     /**
-     * These are states that usually indicate the user took an action and should
-     * bump priority regardless of the old state.
+     * Sorted list of the media sessions
      */
-    private static final int[] ALWAYS_PRIORITY_STATES = {
-            PlaybackState.STATE_FAST_FORWARDING,
-            PlaybackState.STATE_REWINDING,
-            PlaybackState.STATE_SKIPPING_TO_PREVIOUS,
-            PlaybackState.STATE_SKIPPING_TO_NEXT };
-    /**
-     * These are states that usually indicate the user took an action if they
-     * were entered from a non-priority state.
-     */
-    private static final int[] TRANSITION_PRIORITY_STATES = {
-            PlaybackState.STATE_BUFFERING,
-            PlaybackState.STATE_CONNECTING,
-            PlaybackState.STATE_PLAYING };
-
-    /**
-     * Sorted list of the media sessions.
-     * The session of which PlaybackState is changed to ALWAYS_PRIORITY_STATES or
-     * TRANSITION_PRIORITY_STATES comes first.
-     * @see #shouldUpdatePriority
-     */
-    private final List<MediaSessionRecord> mSessions = new ArrayList<MediaSessionRecord>();
+    private final List<MediaSessionRecordImpl> mSessions = new ArrayList<>();
 
     private final AudioPlayerStateMonitor mAudioPlayerStateMonitor;
     private final OnMediaButtonSessionChangedListener mOnMediaButtonSessionChangedListener;
 
     /**
      * The media button session which receives media key events.
-     * It could be null if the previous media buttion session is released.
+     * It could be null if the previous media button session is released.
      */
-    private MediaSessionRecord mMediaButtonSession;
+    private MediaSessionRecordImpl mMediaButtonSession;
 
-    private MediaSessionRecord mCachedVolumeDefault;
+    private MediaSessionRecordImpl mCachedVolumeDefault;
 
     /**
      * Cache the result of the {@link #getActiveSessions} per user.
      */
-    private final SparseArray<ArrayList<MediaSessionRecord>> mCachedActiveLists =
+    private final SparseArray<List<MediaSessionRecord>> mCachedActiveLists =
             new SparseArray<>();
 
     MediaSessionStack(AudioPlayerStateMonitor monitor, OnMediaButtonSessionChangedListener listener) {
@@ -102,7 +81,7 @@
      *
      * @param record The record to add.
      */
-    public void addSession(MediaSessionRecord record) {
+    public void addSession(MediaSessionRecordImpl record) {
         mSessions.add(record);
         clearCache(record.getUserId());
 
@@ -117,7 +96,7 @@
      *
      * @param record The record to remove.
      */
-    public void removeSession(MediaSessionRecord record) {
+    public void removeSession(MediaSessionRecordImpl record) {
         mSessions.remove(record);
         if (mMediaButtonSession == record) {
             // When the media button session is removed, nullify the media button session and do not
@@ -131,7 +110,7 @@
     /**
      * Return if the record exists in the priority tracker.
      */
-    public boolean contains(MediaSessionRecord record) {
+    public boolean contains(MediaSessionRecordImpl record) {
         return mSessions.contains(record);
     }
 
@@ -142,9 +121,12 @@
      * @return the MediaSessionRecord. Can be {@code null} if the session is gone meanwhile.
      */
     public MediaSessionRecord getMediaSessionRecord(MediaSession.Token sessionToken) {
-        for (MediaSessionRecord record : mSessions) {
-            if (Objects.equals(record.getSessionToken(), sessionToken)) {
-                return record;
+        for (MediaSessionRecordImpl record : mSessions) {
+            if (record instanceof MediaSessionRecord) {
+                MediaSessionRecord session1 = (MediaSessionRecord) record;
+                if (Objects.equals(session1.getSessionToken(), sessionToken)) {
+                    return session1;
+                }
             }
         }
         return null;
@@ -154,15 +136,15 @@
      * Notify the priority tracker that a session's playback state changed.
      *
      * @param record The record that changed.
-     * @param oldState Its old playback state.
-     * @param newState Its new playback state.
+     * @param shouldUpdatePriority {@code true} if the record needs to prioritized
      */
-    public void onPlaystateChanged(MediaSessionRecord record, int oldState, int newState) {
-        if (shouldUpdatePriority(oldState, newState)) {
+    public void onPlaybackStateChanged(
+            MediaSessionRecordImpl record, boolean shouldUpdatePriority) {
+        if (shouldUpdatePriority) {
             mSessions.remove(record);
             mSessions.add(0, record);
             clearCache(record.getUserId());
-        } else if (!MediaSession.isActiveState(newState)) {
+        } else if (record.checkPlaybackActiveState(false)) {
             // Just clear the volume cache when a state goes inactive
             mCachedVolumeDefault = null;
         }
@@ -172,7 +154,7 @@
         // In that case, we pick the media session whose PlaybackState matches
         // the audio playback configuration.
         if (mMediaButtonSession != null && mMediaButtonSession.getUid() == record.getUid()) {
-            MediaSessionRecord newMediaButtonSession =
+            MediaSessionRecordImpl newMediaButtonSession =
                     findMediaButtonSession(mMediaButtonSession.getUid());
             if (newMediaButtonSession != mMediaButtonSession) {
                 updateMediaButtonSession(newMediaButtonSession);
@@ -185,7 +167,7 @@
      *
      * @param record The record that changed.
      */
-    public void onSessionStateChange(MediaSessionRecord record) {
+    public void onSessionActiveStateChanged(MediaSessionRecordImpl record) {
         // For now just clear the cache. Eventually we'll selectively clear
         // depending on what changed.
         clearCache(record.getUserId());
@@ -203,7 +185,7 @@
         }
         IntArray audioPlaybackUids = mAudioPlayerStateMonitor.getSortedAudioPlaybackClientUids();
         for (int i = 0; i < audioPlaybackUids.size(); i++) {
-            MediaSessionRecord mediaButtonSession =
+            MediaSessionRecordImpl mediaButtonSession =
                     findMediaButtonSession(audioPlaybackUids.get(i));
             if (mediaButtonSession != null) {
                 // Found the media button session.
@@ -225,9 +207,9 @@
      * @return The media button session. Returns {@code null} if the app doesn't have a media
      *   session.
      */
-    private MediaSessionRecord findMediaButtonSession(int uid) {
-        MediaSessionRecord mediaButtonSession = null;
-        for (MediaSessionRecord session : mSessions) {
+    private MediaSessionRecordImpl findMediaButtonSession(int uid) {
+        MediaSessionRecordImpl mediaButtonSession = null;
+        for (MediaSessionRecordImpl session : mSessions) {
             if (uid == session.getUid()) {
                 if (session.checkPlaybackActiveState(
                         mAudioPlayerStateMonitor.isPlaybackActive(session.getUid()))) {
@@ -253,8 +235,8 @@
      *    for all users in this {@link MediaSessionStack}.
      * @return All the active sessions in priority order.
      */
-    public ArrayList<MediaSessionRecord> getActiveSessions(int userId) {
-        ArrayList<MediaSessionRecord> cachedActiveList = mCachedActiveLists.get(userId);
+    public List<MediaSessionRecord> getActiveSessions(int userId) {
+        List<MediaSessionRecord> cachedActiveList = mCachedActiveLists.get(userId);
         if (cachedActiveList == null) {
             cachedActiveList = getPriorityList(true, userId);
             mCachedActiveLists.put(userId, cachedActiveList);
@@ -263,26 +245,46 @@
     }
 
     /**
+     * Gets the session2 tokens.
+     *
+     * @param userId The user to check. It can be {@link UserHandle#USER_ALL} to get all session2
+     *    tokens for all users in this {@link MediaSessionStack}.
+     * @return All session2 tokens.
+     */
+    public List<Session2Token> getSession2Tokens(int userId) {
+        ArrayList<Session2Token> session2Records = new ArrayList<>();
+        for (MediaSessionRecordImpl record : mSessions) {
+            if ((userId == UserHandle.USER_ALL || record.getUserId() == userId)
+                    && record.isActive()
+                    && record instanceof MediaSession2Record) {
+                MediaSession2Record session2 = (MediaSession2Record) record;
+                session2Records.add(session2.getSession2Token());
+            }
+        }
+        return session2Records;
+    }
+
+    /**
      * Get the media button session which receives the media button events.
      *
      * @return The media button session or null.
      */
-    public MediaSessionRecord getMediaButtonSession() {
+    public MediaSessionRecordImpl getMediaButtonSession() {
         return mMediaButtonSession;
     }
 
-    private void updateMediaButtonSession(MediaSessionRecord newMediaButtonSession) {
-        MediaSessionRecord oldMediaButtonSession = mMediaButtonSession;
+    private void updateMediaButtonSession(MediaSessionRecordImpl newMediaButtonSession) {
+        MediaSessionRecordImpl oldMediaButtonSession = mMediaButtonSession;
         mMediaButtonSession = newMediaButtonSession;
         mOnMediaButtonSessionChangedListener.onMediaButtonSessionChanged(
                 oldMediaButtonSession, newMediaButtonSession);
     }
 
-    public MediaSessionRecord getDefaultVolumeSession() {
+    public MediaSessionRecordImpl getDefaultVolumeSession() {
         if (mCachedVolumeDefault != null) {
             return mCachedVolumeDefault;
         }
-        ArrayList<MediaSessionRecord> records = getPriorityList(true, UserHandle.USER_ALL);
+        List<MediaSessionRecord> records = getPriorityList(true, UserHandle.USER_ALL);
         int size = records.size();
         for (int i = 0; i < size; i++) {
             MediaSessionRecord record = records.get(i);
@@ -294,8 +296,8 @@
         return null;
     }
 
-    public MediaSessionRecord getDefaultRemoteSession(int userId) {
-        ArrayList<MediaSessionRecord> records = getPriorityList(true, userId);
+    public MediaSessionRecordImpl getDefaultRemoteSession(int userId) {
+        List<MediaSessionRecord> records = getPriorityList(true, userId);
 
         int size = records.size();
         for (int i = 0; i < size; i++) {
@@ -308,16 +310,11 @@
     }
 
     public void dump(PrintWriter pw, String prefix) {
-        ArrayList<MediaSessionRecord> sortedSessions = getPriorityList(false,
-                UserHandle.USER_ALL);
-        int count = sortedSessions.size();
         pw.println(prefix + "Media button session is " + mMediaButtonSession);
-        pw.println(prefix + "Sessions Stack - have " + count + " sessions:");
+        pw.println(prefix + "Sessions Stack - have " + mSessions.size() + " sessions:");
         String indent = prefix + "  ";
-        for (int i = 0; i < count; i++) {
-            MediaSessionRecord record = sortedSessions.get(i);
+        for (MediaSessionRecordImpl record : mSessions) {
             record.dump(pw, indent);
-            pw.println();
         }
     }
 
@@ -335,17 +332,19 @@
      *            will return sessions for all users.
      * @return The priority sorted list of sessions.
      */
-    public ArrayList<MediaSessionRecord> getPriorityList(boolean activeOnly, int userId) {
-        ArrayList<MediaSessionRecord> result = new ArrayList<MediaSessionRecord>();
+    public List<MediaSessionRecord> getPriorityList(boolean activeOnly, int userId) {
+        List<MediaSessionRecord> result = new ArrayList<MediaSessionRecord>();
         int lastPlaybackActiveIndex = 0;
         int lastActiveIndex = 0;
 
-        int size = mSessions.size();
-        for (int i = 0; i < size; i++) {
-            final MediaSessionRecord session = mSessions.get(i);
+        for (MediaSessionRecordImpl record : mSessions) {
+            if (!(record instanceof MediaSessionRecord)) {
+                continue;
+            }
+            final MediaSessionRecord session = (MediaSessionRecord) record;
 
-            if (userId != UserHandle.USER_ALL && userId != session.getUserId()) {
-                // Filter out sessions for the wrong user
+            if ((userId != UserHandle.USER_ALL && userId != session.getUserId())) {
+                // Filter out sessions for the wrong user or session2.
                 continue;
             }
 
@@ -369,26 +368,6 @@
         return result;
     }
 
-    private boolean shouldUpdatePriority(int oldState, int newState) {
-        if (containsState(newState, ALWAYS_PRIORITY_STATES)) {
-            return true;
-        }
-        if (!containsState(oldState, TRANSITION_PRIORITY_STATES)
-                && containsState(newState, TRANSITION_PRIORITY_STATES)) {
-            return true;
-        }
-        return false;
-    }
-
-    private boolean containsState(int state, int[] states) {
-        for (int i = 0; i < states.length; i++) {
-            if (states[i] == state) {
-                return true;
-            }
-        }
-        return false;
-    }
-
     private void clearCache(int userId) {
         mCachedVolumeDefault = null;
         mCachedActiveLists.remove(userId);