Rename WorkSource methods on Binder and IPCThreadState.

This change only renames methods, there is no behavior changes except
using the new restore methods instead of clear.

Test: unit tests
Change-Id: I35ae966461657e2e2a67e916d752b9ee53381c83
diff --git a/core/java/android/os/Binder.java b/core/java/android/os/Binder.java
index da4b823..c7184c0 100644
--- a/core/java/android/os/Binder.java
+++ b/core/java/android/os/Binder.java
@@ -382,7 +382,9 @@
     /**
      * Sets the work source for this thread.
      *
-     * <p>All the following binder calls on this thread will use the provided work source.
+     * <p>All the following binder calls on this thread will use the provided work source. If this
+     * is called during an on-going binder transaction, all the following binder calls will use the
+     * work source until the end of the transaction.
      *
      * <p>The concept of worksource is similar to {@link WorkSource}. However, for performance
      * reasons, we only support one UID. This UID represents the original user responsible for the
@@ -390,20 +392,20 @@
      *
      * <p>A typical use case would be
      * <pre>
-     * Binder.setThreadWorkSource(uid);
+     * long token = Binder.setCallingWorkSourceUid(uid);
      * try {
      *   // Call an API.
      * } finally {
-     *   Binder.clearThreadWorkSource();
+     *   Binder.restoreCallingWorkSource(token);
      * }
      * </pre>
      *
      * @param workSource The original UID responsible for the binder call.
-     * @return The previously set work source.
+     * @return token to restore original work source.
      * @hide
      **/
     @CriticalNative
-    public static final native int setThreadWorkSource(int workSource);
+    public static final native long setCallingWorkSourceUid(int workSource);
 
     /**
      * Returns the work source set by the caller.
@@ -416,16 +418,34 @@
      * @hide
      */
     @CriticalNative
-    public static final native int getThreadWorkSource();
+    public static final native int getCallingWorkSourceUid();
 
     /**
      * Clears the work source on this thread.
      *
-     * @return The previously set work source.
+     * @return token to restore original work source.
      * @hide
      **/
     @CriticalNative
-    public static final native int clearThreadWorkSource();
+    public static final native long clearCallingWorkSource();
+
+    /**
+     * Restores the work source on this thread using a token returned by
+     * {@link #setCallingWorkSourceUid(int) or {@link clearCallingWorkSource()}.
+     *
+     * <p>A typical use case would be
+     * <pre>
+     * long token = Binder.setCallingWorkSourceUid(uid);
+     * try {
+     *   // Call an API.
+     * } finally {
+     *   Binder.restoreCallingWorkSource(token);
+     * }
+     * </pre>
+     * @hide
+     **/
+    @CriticalNative
+    public static final native void restoreCallingWorkSource(long token);
 
     /**
      * Flush any Binder commands pending in the current thread to the kernel
@@ -586,7 +606,7 @@
      *
      * <li>By default, this listener will propagate the worksource if the outgoing call happens on
      * the same thread as the incoming binder call.
-     * <li>Custom attribution can be done by calling {@link ThreadLocalWorkSourceUid#set(int)}.
+     * <li>Custom attribution can be done by calling {@link ThreadLocalWorkSource#setUid(int)}.
      * @hide
      */
     public static class PropagateWorkSourceTransactListener implements ProxyTransactListener {
@@ -595,12 +615,11 @@
            // Note that {@link Binder#getCallingUid()} is already set to the UID of the current
            // process when this method is called.
            //
-           // We use ThreadLocalWorkSourceUid instead. It also allows feature owners to set
-           // {@link ThreadLocalWorkSourceUid#set(int) manually to attribute resources to a UID.
-            int uid = ThreadLocalWorkSourceUid.get();
-            if (uid >= 0) {
-                int originalUid = Binder.setThreadWorkSource(uid);
-                return Integer.valueOf(originalUid);
+           // We use ThreadLocalWorkSource instead. It also allows feature owners to set
+           // {@link ThreadLocalWorkSource#set(int) manually to attribute resources to a UID.
+            int uid = ThreadLocalWorkSource.getUid();
+            if (uid != ThreadLocalWorkSource.UID_NONE) {
+                return Binder.setCallingWorkSourceUid(uid);
             }
             return null;
         }
@@ -608,8 +627,8 @@
         @Override
         public void onTransactEnded(Object session) {
             if (session != null) {
-                int uid = (int) session;
-                Binder.setThreadWorkSource(uid);
+                long token = (long) session;
+                Binder.restoreCallingWorkSource(token);
             }
         }
     }
@@ -897,11 +916,11 @@
         // Log any exceptions as warnings, don't silently suppress them.
         // If the call was FLAG_ONEWAY then these exceptions disappear into the ether.
         final boolean tracingEnabled = Binder.isTracingEnabled();
+        final long origWorkSource = ThreadLocalWorkSource.setUid(Binder.getCallingUid());
         try {
             if (tracingEnabled) {
                 Trace.traceBegin(Trace.TRACE_TAG_ALWAYS, getClass().getName() + ":" + code);
             }
-            ThreadLocalWorkSourceUid.set(Binder.getCallingUid());
             res = onTransact(code, data, reply, flags);
         } catch (RemoteException|RuntimeException e) {
             if (observer != null) {
@@ -922,7 +941,7 @@
             }
             res = true;
         } finally {
-            ThreadLocalWorkSourceUid.clear();
+            ThreadLocalWorkSource.restore(origWorkSource);
             if (tracingEnabled) {
                 Trace.traceEnd(Trace.TRACE_TAG_ALWAYS);
             }
diff --git a/core/java/android/os/Handler.java b/core/java/android/os/Handler.java
index f3a9a50..e8704af 100644
--- a/core/java/android/os/Handler.java
+++ b/core/java/android/os/Handler.java
@@ -739,7 +739,7 @@
 
     private boolean enqueueMessage(MessageQueue queue, Message msg, long uptimeMillis) {
         msg.target = this;
-        msg.workSourceUid = ThreadLocalWorkSourceUid.get();
+        msg.workSourceUid = ThreadLocalWorkSource.getUid();
 
         if (mAsynchronous) {
             msg.setAsynchronous(true);
diff --git a/core/java/android/os/Looper.java b/core/java/android/os/Looper.java
index 5b8abab..a8d1215 100644
--- a/core/java/android/os/Looper.java
+++ b/core/java/android/os/Looper.java
@@ -204,8 +204,8 @@
             if (observer != null) {
                 token = observer.messageDispatchStarting();
             }
+            long origWorkSource = ThreadLocalWorkSource.setUid(msg.workSourceUid);
             try {
-                ThreadLocalWorkSourceUid.set(msg.workSourceUid);
                 msg.target.dispatchMessage(msg);
                 if (observer != null) {
                     observer.messageDispatched(token, msg);
@@ -217,7 +217,7 @@
                 }
                 throw exception;
             } finally {
-                ThreadLocalWorkSourceUid.clear();
+                ThreadLocalWorkSource.restore(origWorkSource);
                 if (traceTag != 0) {
                     Trace.traceEnd(traceTag);
                 }
diff --git a/core/java/android/os/ThreadLocalWorkSourceUid.java b/core/java/android/os/ThreadLocalWorkSource.java
similarity index 64%
rename from core/java/android/os/ThreadLocalWorkSourceUid.java
rename to core/java/android/os/ThreadLocalWorkSource.java
index df1d275..53dd460 100644
--- a/core/java/android/os/ThreadLocalWorkSourceUid.java
+++ b/core/java/android/os/ThreadLocalWorkSource.java
@@ -19,26 +19,41 @@
 /**
  * @hide Only for use within system server.
  */
-public final class ThreadLocalWorkSourceUid {
+public final class ThreadLocalWorkSource {
     public static final int UID_NONE = Message.UID_NONE;
     private static final ThreadLocal<Integer> sWorkSourceUid =
             ThreadLocal.withInitial(() -> UID_NONE);
 
     /** Returns the original work source uid. */
-    public static int get() {
+    public static int getUid() {
         return sWorkSourceUid.get();
     }
 
     /** Sets the original work source uid. */
-    public static void set(int uid) {
+    public static long setUid(int uid) {
+        final long token = getToken();
         sWorkSourceUid.set(uid);
+        return token;
+    }
+
+    /** Restores the state using the provided token. */
+    public static void restore(long token) {
+        sWorkSourceUid.set(parseUidFromToken(token));
     }
 
     /** Clears the stored work source uid. */
-    public static void clear() {
-        sWorkSourceUid.set(UID_NONE);
+    public static long clear() {
+        return setUid(UID_NONE);
     }
 
-    private ThreadLocalWorkSourceUid() {
+    private static int parseUidFromToken(long token) {
+        return (int) token;
+    }
+
+    private static long getToken() {
+        return sWorkSourceUid.get();
+    }
+
+    private ThreadLocalWorkSource() {
     }
 }
diff --git a/core/java/com/android/internal/os/BinderCallsStats.java b/core/java/com/android/internal/os/BinderCallsStats.java
index 70fc72f..34e8ed4 100644
--- a/core/java/com/android/internal/os/BinderCallsStats.java
+++ b/core/java/com/android/internal/os/BinderCallsStats.java
@@ -421,7 +421,7 @@
     }
 
     protected int getWorkSourceUid() {
-        return Binder.getThreadWorkSource();
+        return Binder.getCallingWorkSourceUid();
     }
 
     protected long getElapsedRealtimeMicro() {
diff --git a/core/jni/android_util_Binder.cpp b/core/jni/android_util_Binder.cpp
index fd042b3..4f8bbc1 100644
--- a/core/jni/android_util_Binder.cpp
+++ b/core/jni/android_util_Binder.cpp
@@ -904,19 +904,24 @@
     return IPCThreadState::self()->getStrictModePolicy();
 }
 
-static jint android_os_Binder_setThreadWorkSource(jint workSource)
+static jlong android_os_Binder_setCallingWorkSourceUid(jint workSource)
 {
-    return IPCThreadState::self()->setWorkSource(workSource);
+    return IPCThreadState::self()->setCallingWorkSourceUid(workSource);
 }
 
-static jint android_os_Binder_getThreadWorkSource()
+static jlong android_os_Binder_getCallingWorkSourceUid()
 {
-    return IPCThreadState::self()->getWorkSource();
+    return IPCThreadState::self()->getCallingWorkSourceUid();
 }
 
-static jint android_os_Binder_clearThreadWorkSource()
+static jlong android_os_Binder_clearCallingWorkSource()
 {
-    return IPCThreadState::self()->clearWorkSource();
+    return IPCThreadState::self()->clearCallingWorkSource();
+}
+
+static void android_os_Binder_restoreCallingWorkSource(long token)
+{
+    IPCThreadState::self()->restoreCallingWorkSource(token);
 }
 
 static void android_os_Binder_flushPendingCommands(JNIEnv* env, jobject clazz)
@@ -962,11 +967,12 @@
     // @CriticalNative
     { "getThreadStrictModePolicy", "()I", (void*)android_os_Binder_getThreadStrictModePolicy },
     // @CriticalNative
-    { "setThreadWorkSource", "(I)I", (void*)android_os_Binder_setThreadWorkSource },
+    { "setCallingWorkSourceUid", "(I)J", (void*)android_os_Binder_setCallingWorkSourceUid },
     // @CriticalNative
-    { "getThreadWorkSource", "()I", (void*)android_os_Binder_getThreadWorkSource },
+    { "getCallingWorkSourceUid", "()I", (void*)android_os_Binder_getCallingWorkSourceUid },
     // @CriticalNative
-    { "clearThreadWorkSource", "()I", (void*)android_os_Binder_clearThreadWorkSource },
+    { "clearCallingWorkSource", "()J", (void*)android_os_Binder_clearCallingWorkSource },
+    { "restoreCallingWorkSource", "(J)V", (void*)android_os_Binder_restoreCallingWorkSource },
     { "flushPendingCommands", "()V", (void*)android_os_Binder_flushPendingCommands },
     { "getNativeBBinderHolder", "()J", (void*)android_os_Binder_getNativeBBinderHolder },
     { "getNativeFinalizer", "()J", (void*)android_os_Binder_getNativeFinalizer },
diff --git a/core/tests/coretests/src/android/os/BinderTest.java b/core/tests/coretests/src/android/os/BinderTest.java
index 1beb598..534c5cd 100644
--- a/core/tests/coretests/src/android/os/BinderTest.java
+++ b/core/tests/coretests/src/android/os/BinderTest.java
@@ -21,17 +21,26 @@
 import junit.framework.TestCase;
 
 public class BinderTest extends TestCase {
+    private static final int UID = 100;
 
     @SmallTest
     public void testSetWorkSource() throws Exception {
-        Binder.setThreadWorkSource(100);
-        assertEquals(100, Binder.getThreadWorkSource());
+        Binder.setCallingWorkSourceUid(UID);
+        assertEquals(UID, Binder.getCallingWorkSourceUid());
     }
 
     @SmallTest
     public void testClearWorkSource() throws Exception {
-        Binder.setThreadWorkSource(100);
-        Binder.clearThreadWorkSource();
-        assertEquals(-1, Binder.getThreadWorkSource());
+        Binder.setCallingWorkSourceUid(UID);
+        Binder.clearCallingWorkSource();
+        assertEquals(-1, Binder.getCallingWorkSourceUid());
+    }
+
+    @SmallTest
+    public void testRestoreWorkSource() throws Exception {
+        Binder.setCallingWorkSourceUid(UID);
+        long token = Binder.clearCallingWorkSource();
+        Binder.restoreCallingWorkSource(token);
+        assertEquals(UID, Binder.getCallingWorkSourceUid());
     }
 }