Restricted profiles use Owner's VPN

Restricted profiles cannot start their own VPN and will use the Owner's
VPN if one is running.

Change-Id: I1fc153742047f9149acb414c4c9d35305e97d8d0
diff --git a/services/java/com/android/server/connectivity/Vpn.java b/services/java/com/android/server/connectivity/Vpn.java
index efa1c8e..3a2391f 100644
--- a/services/java/com/android/server/connectivity/Vpn.java
+++ b/services/java/com/android/server/connectivity/Vpn.java
@@ -32,6 +32,7 @@
 import android.content.pm.IPackageManager;
 import android.content.pm.PackageManager;
 import android.content.pm.ResolveInfo;
+import android.content.pm.UserInfo;
 import android.graphics.Bitmap;
 import android.graphics.Canvas;
 import android.graphics.drawable.Drawable;
@@ -57,12 +58,14 @@
 import android.os.SystemClock;
 import android.os.SystemService;
 import android.os.UserHandle;
-import android.os.RemoteException;
+import android.os.UserManager;
 import android.security.Credentials;
 import android.security.KeyStore;
 import android.util.Log;
+import android.util.SparseBooleanArray;
 import android.widget.Toast;
 
+import com.android.internal.annotations.GuardedBy;
 import com.android.internal.R;
 import com.android.internal.net.LegacyVpnInfo;
 import com.android.internal.net.VpnConfig;
@@ -103,6 +106,13 @@
     private volatile boolean mEnableNotif = true;
     private volatile boolean mEnableTeardown = true;
     private final IConnectivityManager mConnService;
+    private VpnConfig mConfig;
+
+    /* list of users using this VPN. */
+    @GuardedBy("this")
+    private SparseBooleanArray mVpnUsers = null;
+    private BroadcastReceiver mUserIntentReceiver = null;
+
     private final int mUserId;
 
     public Vpn(Context context, VpnCallback callback, INetworkManagementService netService,
@@ -119,6 +129,30 @@
         } catch (RemoteException e) {
             Log.wtf(TAG, "Problem registering observer", e);
         }
+        if (userId == UserHandle.USER_OWNER) {
+            // Owner's VPN also needs to handle restricted users
+            mUserIntentReceiver = new BroadcastReceiver() {
+                @Override
+                public void onReceive(Context context, Intent intent) {
+                    final String action = intent.getAction();
+                    final int userId = intent.getIntExtra(Intent.EXTRA_USER_HANDLE,
+                            UserHandle.USER_NULL);
+                    if (userId == UserHandle.USER_NULL) return;
+
+                    if (Intent.ACTION_USER_ADDED.equals(action)) {
+                        onUserAdded(userId);
+                    } else if (Intent.ACTION_USER_REMOVED.equals(action)) {
+                        onUserRemoved(userId);
+                    }
+                }
+            };
+
+            IntentFilter intentFilter = new IntentFilter();
+            intentFilter.addAction(Intent.ACTION_USER_ADDED);
+            intentFilter.addAction(Intent.ACTION_USER_REMOVED);
+            mContext.registerReceiverAsUser(
+                    mUserIntentReceiver, UserHandle.ALL, intentFilter, null, null);
+        }
     }
 
     /**
@@ -207,15 +241,20 @@
             final long token = Binder.clearCallingIdentity();
             try {
                 mCallback.restore();
-                mCallback.clearUserForwarding(mInterface, mUserId);
+                final int size = mVpnUsers.size();
+                for (int i = 0; i < size; i++) {
+                    int user = mVpnUsers.keyAt(i);
+                    mCallback.clearUserForwarding(mInterface, user);
+                    hideNotification(user);
+                }
 
                 mCallback.clearMarkedForwarding(mInterface);
-                hideNotification();
             } finally {
                 Binder.restoreCallingIdentity(token);
             }
             jniReset(mInterface);
             mInterface = null;
+            mVpnUsers = null;
         }
 
         // Revoke the connection or stop LegacyVpnRunner.
@@ -235,6 +274,7 @@
 
         Log.i(TAG, "Switched from " + mPackage + " to " + newPackage);
         mPackage = newPackage;
+        mConfig = null;
         updateState(DetailedState.IDLE, "prepare");
         return true;
     }
@@ -253,14 +293,14 @@
         if (Binder.getCallingUid() != appUid) {
             throw new SecurityException("Unauthorized Caller");
         }
-        //protect the socket from routing rules
+        // protect the socket from routing rules
         final long token = Binder.clearCallingIdentity();
         try {
             mCallback.protect(socket);
         } finally {
             Binder.restoreCallingIdentity(token);
         }
-        //bind the socket to the interface
+        // bind the socket to the interface
         jniProtect(socket.getFd(), interfaze);
 
     }
@@ -275,6 +315,7 @@
      */
     public synchronized ParcelFileDescriptor establish(VpnConfig config) {
         // Check if the caller is already prepared.
+        UserManager mgr = UserManager.get(mContext);
         PackageManager pm = mContext.getPackageManager();
         ApplicationInfo app = null;
         try {
@@ -285,12 +326,17 @@
         } catch (Exception e) {
             return null;
         }
-
         // Check if the service is properly declared.
         Intent intent = new Intent(VpnConfig.SERVICE_INTERFACE);
         intent.setClassName(mPackage, config.user);
         long token = Binder.clearCallingIdentity();
         try {
+            // Restricted users are not allowed to create VPNs, they are tied to Owner
+            UserInfo user = mgr.getUserInfo(mUserId);
+            if (user.isRestricted()) {
+                throw new SecurityException("Restricted users cannot establish VPNs");
+            }
+
             ResolveInfo info = AppGlobals.getPackageManager().resolveService(intent,
                                                                         null, 0, mUserId);
             if (info == null) {
@@ -305,31 +351,13 @@
             Binder.restoreCallingIdentity(token);
         }
 
-        // Load the label.
-        String label = app.loadLabel(pm).toString();
-
-        // Load the icon and convert it into a bitmap.
-        Drawable icon = app.loadIcon(pm);
-        Bitmap bitmap = null;
-        if (icon.getIntrinsicWidth() > 0 && icon.getIntrinsicHeight() > 0) {
-            int width = mContext.getResources().getDimensionPixelSize(
-                    android.R.dimen.notification_large_icon_width);
-            int height = mContext.getResources().getDimensionPixelSize(
-                    android.R.dimen.notification_large_icon_height);
-            icon.setBounds(0, 0, width, height);
-            bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
-            Canvas c = new Canvas(bitmap);
-            icon.draw(c);
-            c.setBitmap(null);
-        }
-
         // Configure the interface. Abort if any of these steps fails.
         ParcelFileDescriptor tun = ParcelFileDescriptor.adoptFd(jniCreate(config.mtu));
         try {
             updateState(DetailedState.CONNECTING, "establish");
             String interfaze = jniGetName(tun.getFd());
 
-            //TEMP use the old jni calls until there is support for netd address setting
+            // TEMP use the old jni calls until there is support for netd address setting
             StringBuilder builder = new StringBuilder();
             for (LinkAddress address : config.addresses) {
                 builder.append(" " + address);
@@ -354,14 +382,16 @@
             // Fill more values.
             config.user = mPackage;
             config.interfaze = mInterface;
+            config.startTime = SystemClock.elapsedRealtime();
+            mConfig = config;
             // Set up forwarding and DNS rules.
+            mVpnUsers = new SparseBooleanArray();
             token = Binder.clearCallingIdentity();
             try {
                 mCallback.setMarkedForwarding(mInterface);
                 mCallback.setRoutes(interfaze, config.routes);
                 mCallback.override(mInterface, config.dnsServers, config.searchDomains);
-                mCallback.addUserForwarding(mInterface, mUserId);
-                showNotification(config, label, bitmap);
+                addVpnUserLocked(mUserId);
 
             } finally {
                 Binder.restoreCallingIdentity(token);
@@ -370,22 +400,125 @@
         } catch (RuntimeException e) {
             updateState(DetailedState.FAILED, "establish");
             IoUtils.closeQuietly(tun);
-            //make sure marked forwarding is cleared if it was set
+            // make sure marked forwarding is cleared if it was set
             try {
                 mCallback.clearMarkedForwarding(mInterface);
             } catch (Exception ingored) {
-                //ignored
+                // ignored
             }
             throw e;
         }
         Log.i(TAG, "Established by " + config.user + " on " + mInterface);
 
 
+        // If we are owner assign all Restricted Users to this VPN
+        if (mUserId == UserHandle.USER_OWNER) {
+            token = Binder.clearCallingIdentity();
+            try {
+                for (UserInfo user : mgr.getUsers()) {
+                    if (user.isRestricted()) {
+                        try {
+                            addVpnUserLocked(user.id);
+                        } catch (Exception e) {
+                            Log.wtf(TAG, "Failed to add user " + user.id + " to owner's VPN");
+                        }
+                    }
+                }
+            } finally {
+                Binder.restoreCallingIdentity(token);
+            }
+        }
         // TODO: ensure that contract class eventually marks as connected
         updateState(DetailedState.AUTHENTICATING, "establish");
         return tun;
     }
 
+    private boolean isRunningLocked() {
+        return mVpnUsers != null; // list is allocated at VPN setup, nulled on teardown
+    }
+
+    /**
+     * Adds {@code user} to the running VPN (forwarding + user list) and posts its notification.
+     * @throws IllegalStateException if no VPN is running or mPackage's app info is unavailable
+     */
+    private void addVpnUserLocked(int user) {
+        enforceControlPermission();
+        if (!isRunningLocked()) {
+            throw new IllegalStateException("VPN is not active");
+        }
+        mCallback.addUserForwarding(mInterface, user);
+        mVpnUsers.put(user, true);
+        if (mPackage.equals(VpnConfig.LEGACY_VPN)) {
+            showNotification(null, null, user);
+            return;
+        }
+        // The label and icon are always looked up for mUserId, not for the user being added.
+        PackageManager pm = mContext.getPackageManager();
+        ApplicationInfo app = null;
+        try {
+            app = AppGlobals.getPackageManager().getApplicationInfo(mPackage, 0, mUserId);
+        } catch (RemoteException e) {
+            throw new IllegalStateException("Invalid application", e);
+        }
+        if (app == null) throw new IllegalStateException("Invalid application");
+        String label = app.loadLabel(pm).toString();
+        Drawable icon = app.loadIcon(pm);
+        Bitmap bitmap = null;
+        if (icon.getIntrinsicWidth() > 0 && icon.getIntrinsicHeight() > 0) {
+            int width = mContext.getResources().getDimensionPixelSize(
+                    android.R.dimen.notification_large_icon_width);
+            int height = mContext.getResources().getDimensionPixelSize(
+                    android.R.dimen.notification_large_icon_height);
+            icon.setBounds(0, 0, width, height);
+            bitmap = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
+            Canvas c = new Canvas(bitmap);
+            icon.draw(c);
+            c.setBitmap(null);
+        }
+        showNotification(label, bitmap, user);
+    }
+
+    /** Drops {@code user} from the running VPN and cancels that user's notification. */
+    private void removeVpnUserLocked(int user) {
+        enforceControlPermission();
+        if (!isRunningLocked()) {
+            throw new IllegalStateException("VPN is not active");
+        }
+        mCallback.clearUserForwarding(mInterface, user);
+        mVpnUsers.delete(user);
+        hideNotification(user);
+    }
+
+    /** Adds a newly created restricted user to this VPN; does nothing when no VPN is running. */
+    private void onUserAdded(int userId) {
+        synchronized(Vpn.this) {
+            UserInfo user = UserManager.get(mContext).getUserInfo(userId);
+            // A user being created while the VPN is down is routine, so skip it quietly.
+            if (user != null && user.isRestricted() && isRunningLocked()) {
+                try {
+                    addVpnUserLocked(userId);
+                } catch (Exception e) {
+                    Log.wtf(TAG, "Failed to add restricted user to owner", e);
+                }
+            }
+        }
+    }
+
+    /** Detaches a removed restricted user from this VPN; does nothing when no VPN is running. */
+    private void onUserRemoved(int userId) {
+        synchronized(Vpn.this) {
+            UserInfo user = UserManager.get(mContext).getUserInfo(userId);
+            // Users come and go while no VPN is up; that is not an error worth a wtf.
+            if (user != null && user.isRestricted() && isRunningLocked()) {
+                try {
+                    removeVpnUserLocked(userId);
+                } catch (Exception e) {
+                    Log.wtf(TAG, "Failed to remove restricted user to owner", e);
+                }
+            }
+        }
+    }
+
     @Deprecated
     public synchronized void interfaceStatusChanged(String iface, boolean up) {
         try {
@@ -411,11 +544,16 @@
                 if (interfaze.equals(mInterface) && jniCheck(interfaze) == 0) {
                     final long token = Binder.clearCallingIdentity();
                     try {
-                        mCallback.clearUserForwarding(mInterface, mUserId);
+                        final int size = mVpnUsers.size();
+                        for (int i = 0; i < size; i++) {
+                            int user = mVpnUsers.keyAt(i);
+                            mCallback.clearUserForwarding(mInterface, user);
+                            hideNotification(user);
+                        }
+                        mVpnUsers = null;
                         mCallback.clearMarkedForwarding(mInterface);
 
                         mCallback.restore();
-                        hideNotification();
                     } finally {
                         Binder.restoreCallingIdentity(token);
                     }
@@ -470,9 +608,9 @@
         }
     }
 
-    private void showNotification(VpnConfig config, String label, Bitmap icon) {
+    private void showNotification(String label, Bitmap icon, int user) {
         if (!mEnableNotif) return;
-        mStatusIntent = VpnConfig.getIntentForStatusPanel(mContext, config);
+        mStatusIntent = VpnConfig.getIntentForStatusPanel(mContext, mConfig);
 
         NotificationManager nm = (NotificationManager)
                 mContext.getSystemService(Context.NOTIFICATION_SERVICE);
@@ -480,9 +618,8 @@
         if (nm != null) {
             String title = (label == null) ? mContext.getString(R.string.vpn_title) :
                     mContext.getString(R.string.vpn_title_long, label);
-            String text = (config.session == null) ? mContext.getString(R.string.vpn_text) :
-                    mContext.getString(R.string.vpn_text_long, config.session);
-            config.startTime = SystemClock.elapsedRealtime();
+            String text = (mConfig.session == null) ? mContext.getString(R.string.vpn_text) :
+                    mContext.getString(R.string.vpn_text_long, mConfig.session);
 
             Notification notification = new Notification.Builder(mContext)
                     .setSmallIcon(R.drawable.vpn_connected)
@@ -493,11 +630,11 @@
                     .setDefaults(0)
                     .setOngoing(true)
                     .build();
-            nm.notifyAsUser(null, R.drawable.vpn_connected, notification,new UserHandle(mUserId));
+            nm.notifyAsUser(null, R.drawable.vpn_connected, notification, new UserHandle(user));
         }
     }
 
-    private void hideNotification() {
+    private void hideNotification(int user) {
         if (!mEnableNotif) return;
         mStatusIntent = null;
 
@@ -505,7 +642,7 @@
                 mContext.getSystemService(Context.NOTIFICATION_SERVICE);
 
         if (nm != null) {
-            nm.cancelAsUser(null, R.drawable.vpn_connected, new UserHandle(mUserId));
+            nm.cancelAsUser(null, R.drawable.vpn_connected, new UserHandle(user));
         }
     }
 
@@ -671,7 +808,7 @@
         if (mLegacyVpnRunner == null) return null;
 
         final LegacyVpnInfo info = new LegacyVpnInfo();
-        info.key = mLegacyVpnRunner.mConfig.user;
+        info.key = mConfig.user;
         info.state = LegacyVpnInfo.stateFromNetworkInfo(mNetworkInfo);
         if (mNetworkInfo.isConnected()) {
             info.intent = mStatusIntent;
@@ -681,7 +818,7 @@
 
     public VpnConfig getLegacyVpnConfig() {
         if (mLegacyVpnRunner != null) {
-            return mLegacyVpnRunner.mConfig;
+            return mConfig;
         } else {
             return null;
         }
@@ -697,7 +834,6 @@
     private class LegacyVpnRunner extends Thread {
         private static final String TAG = "LegacyVpnRunner";
 
-        private final VpnConfig mConfig;
         private final String[] mDaemons;
         private final String[][] mArguments;
         private final LocalSocket[] mSockets;
@@ -962,22 +1098,40 @@
 
                     // Now INetworkManagementEventObserver is watching our back.
                     mInterface = mConfig.interfaze;
+                    mVpnUsers = new SparseBooleanArray();
 
                     token = Binder.clearCallingIdentity();
                     try {
                         mCallback.override(mInterface, mConfig.dnsServers, mConfig.searchDomains);
-                        mCallback.addUserForwarding(mInterface, mUserId);
-                        showNotification(mConfig, null, null);
+                        addVpnUserLocked(mUserId);
                     } finally {
                         Binder.restoreCallingIdentity(token);
                     }
 
+                    // Assign all restricted users to this VPN
+                    // (Legacy VPNs are Owner only)
+                    UserManager mgr = UserManager.get(mContext);
+                    token = Binder.clearCallingIdentity();
+                    try {
+                        for (UserInfo user : mgr.getUsers()) {
+                            if (user.isRestricted()) {
+                                try {
+                                    addVpnUserLocked(user.id);
+                                } catch (Exception e) {
+                                    Log.wtf(TAG, "Failed to add user " + user.id
+                                            + " to owner's VPN");
+                                }
+                            }
+                        }
+                    } finally {
+                        Binder.restoreCallingIdentity(token);
+                    }
                     Log.i(TAG, "Connected!");
                     updateState(DetailedState.CONNECTED, "execute");
                 }
             } catch (Exception e) {
                 Log.i(TAG, "Aborting", e);
-                //make sure the routing is cleared
+                // make sure the routing is cleared
                 try {
                     mCallback.clearMarkedForwarding(mConfig.interfaze);
                 } catch (Exception ignored) {