Merge "Change BlockingSocketReader to use MessageQueue fd handling" am: 48c7d27f64
am: b9b0886335

Change-Id: Id543589394c230b657ac150ab8fd4b9e388283eb
diff --git a/services/net/java/android/net/ip/ConnectivityPacketTracker.java b/services/net/java/android/net/ip/ConnectivityPacketTracker.java
index 884a8a7..0230f36 100644
--- a/services/net/java/android/net/ip/ConnectivityPacketTracker.java
+++ b/services/net/java/android/net/ip/ConnectivityPacketTracker.java
@@ -61,11 +61,11 @@
     private static final String MARK_STOP = "--- STOP ---";
 
     private final String mTag;
-    private final Handler mHandler;
     private final LocalLog mLog;
     private final BlockingSocketReader mPacketListener;
+    private boolean mRunning;
 
-    public ConnectivityPacketTracker(NetworkInterface netif, LocalLog log) {
+    public ConnectivityPacketTracker(Handler h, NetworkInterface netif, LocalLog log) {
         final String ifname;
         final int ifindex;
         final byte[] hwaddr;
@@ -81,44 +81,40 @@
         }
 
         mTag = TAG + "." + ifname;
-        mHandler = new Handler();
         mLog = log;
-        mPacketListener = new PacketListener(ifindex, hwaddr, mtu);
+        mPacketListener = new PacketListener(h, ifindex, hwaddr, mtu);
     }
 
     public void start() {
-        mLog.log(MARK_START);
+        mRunning = true;
         mPacketListener.start();
     }
 
     public void stop() {
         mPacketListener.stop();
-        mLog.log(MARK_STOP);
+        mRunning = false;
     }
 
     private final class PacketListener extends BlockingSocketReader {
         private final int mIfIndex;
         private final byte mHwAddr[];
 
-        PacketListener(int ifindex, byte[] hwaddr, int mtu) {
-            super(mtu);
+        PacketListener(Handler h, int ifindex, byte[] hwaddr, int mtu) {
+            super(h, mtu);
             mIfIndex = ifindex;
             mHwAddr = hwaddr;
         }
 
         @Override
-        protected FileDescriptor createSocket() {
+        protected FileDescriptor createFd() {
             FileDescriptor s = null;
             try {
-                // TODO: Evaluate switching to SOCK_DGRAM and changing the
-                // BlockingSocketReader's read() to recvfrom(), so that this
-                // might work on non-ethernet-like links (via SLL).
                 s = Os.socket(AF_PACKET, SOCK_RAW, 0);
                 NetworkUtils.attachControlPacketFilter(s, ARPHRD_ETHER);
                 Os.bind(s, new PacketSocketAddress((short) ETH_P_ALL, mIfIndex));
             } catch (ErrnoException | IOException e) {
                 logError("Failed to create packet tracking socket: ", e);
-                closeSocket(s);
+                closeFd(s);
                 return null;
             }
             return s;
@@ -136,13 +132,27 @@
         }
 
         @Override
+        protected void onStart() {
+            mLog.log(MARK_START);
+        }
+
+        @Override
+        protected void onStop() {
+            if (mRunning) {
+                mLog.log(MARK_STOP);
+            } else {
+                mLog.log(MARK_STOP + " (packet listener stopped unexpectedly)");
+            }
+        }
+
+        @Override
         protected void logError(String msg, Exception e) {
             Log.e(mTag, msg, e);
             addLogEntry(msg + e);
         }
 
         private void addLogEntry(String entry) {
-            mHandler.post(() -> mLog.log(entry));
+            mLog.log(entry);
         }
     }
 }
diff --git a/services/net/java/android/net/ip/IpManager.java b/services/net/java/android/net/ip/IpManager.java
index b1eb085..bc07b81 100644
--- a/services/net/java/android/net/ip/IpManager.java
+++ b/services/net/java/android/net/ip/IpManager.java
@@ -1515,7 +1515,8 @@
 
         private ConnectivityPacketTracker createPacketTracker() {
             try {
-                return new ConnectivityPacketTracker(mNetworkInterface, mConnectivityPacketLog);
+                return new ConnectivityPacketTracker(
+                        getHandler(), mNetworkInterface, mConnectivityPacketLog);
             } catch (IllegalArgumentException e) {
                 return null;
             }
diff --git a/services/net/java/android/net/util/BlockingSocketReader.java b/services/net/java/android/net/util/BlockingSocketReader.java
index 12fa1e5..99bf469 100644
--- a/services/net/java/android/net/util/BlockingSocketReader.java
+++ b/services/net/java/android/net/util/BlockingSocketReader.java
@@ -16,81 +16,106 @@
 
 package android.net.util;
 
+import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_INPUT;
+import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_ERROR;
+
 import android.annotation.Nullable;
+import android.os.Handler;
+import android.os.Looper;
+import android.os.MessageQueue;
+import android.os.MessageQueue.OnFileDescriptorEventListener;
 import android.system.ErrnoException;
 import android.system.Os;
 import android.system.OsConstants;
 
-import libcore.io.IoBridge;
+import libcore.io.IoUtils;
 
 import java.io.FileDescriptor;
-import java.io.InterruptedIOException;
 import java.io.IOException;
 
 
 /**
- * A thread that reads from a socket and passes the received packets to a
- * subclass's handlePacket() method.  The packet receive buffer is recycled
- * on every read call, so subclasses should make any copies they would like
- * inside their handlePacket() implementation.
+ * This class encapsulates the mechanics of registering a file descriptor
+ * with a thread's Looper and handling read events (and errors).
  *
- * All public methods may be called from any thread.
+ * Subclasses MUST implement createFd() and SHOULD override handlePacket().
+ *
+ * Subclasses can expect a call life-cycle like the following:
+ *
+ *     [1] start() calls createFd() and (if all goes well) onStart()
+ *
+ *     [2] yield, waiting for read event or error notification:
+ *
+ *             [a] readPacket() && handlePacket()
+ *
+ *             [b] if (no error):
+ *                     goto 2
+ *                 else:
+ *                     goto 3
+ *
+ *     [3] stop() calls onStop() if not previously stopped
+ *
+ * The packet receive buffer is recycled on every read call, so subclasses
+ * should make any copies they would like inside their handlePacket()
+ * implementation.
+ *
+ * All public methods MUST be called only from the thread associated with
+ * the Handler passed to the constructor.
+ *
+ * TODO: rename this class to something more correctly descriptive (something
+ * like [or less horrible than] FdReadEventsHandler?).
  *
  * @hide
  */
 public abstract class BlockingSocketReader {
+    private static final int FD_EVENTS = EVENT_INPUT | EVENT_ERROR;
+    private static final int UNREGISTER_THIS_FD = 0;
+
     public static final int DEFAULT_RECV_BUF_SIZE = 2 * 1024;
 
+    private final Handler mHandler;
+    private final MessageQueue mQueue;
     private final byte[] mPacket;
-    private final Thread mThread;
-    private volatile FileDescriptor mSocket;
-    private volatile boolean mRunning;
-    private volatile long mPacketsReceived;
+    private FileDescriptor mFd;
+    private long mPacketsReceived;
 
-    // Make it slightly easier for subclasses to properly close a socket
-    // without having to know this incantation.
-    public static final void closeSocket(@Nullable FileDescriptor fd) {
-        try {
-            IoBridge.closeAndSignalBlockedThreads(fd);
-        } catch (IOException ignored) {}
+    protected static void closeFd(FileDescriptor fd) {
+        IoUtils.closeQuietly(fd);
     }
 
-    protected BlockingSocketReader() {
-        this(DEFAULT_RECV_BUF_SIZE);
+    protected BlockingSocketReader(Handler h) {
+        this(h, DEFAULT_RECV_BUF_SIZE);
     }
 
-    protected BlockingSocketReader(int recvbufsize) {
-        if (recvbufsize < DEFAULT_RECV_BUF_SIZE) {
-            recvbufsize = DEFAULT_RECV_BUF_SIZE;
+    protected BlockingSocketReader(Handler h, int recvbufsize) {
+        mHandler = h;
+        mQueue = mHandler.getLooper().getQueue();
+        mPacket = new byte[Math.max(recvbufsize, DEFAULT_RECV_BUF_SIZE)];
+    }
+
+    public final void start() {
+        if (onCorrectThread()) {
+            createAndRegisterFd();
+        } else {
+            mHandler.post(() -> {
+                logError("start() called from off-thread", null);
+                createAndRegisterFd();
+            });
         }
-        mPacket = new byte[recvbufsize];
-        mThread = new Thread(() -> { mainLoop(); });
-    }
-
-    public final boolean start() {
-        if (mSocket != null) return false;
-
-        try {
-            mSocket = createSocket();
-        } catch (Exception e) {
-            logError("Failed to create socket: ", e);
-            return false;
-        }
-
-        if (mSocket == null) return false;
-
-        mRunning = true;
-        mThread.start();
-        return true;
     }
 
     public final void stop() {
-        mRunning = false;
-        closeSocket(mSocket);
-        mSocket = null;
+        if (onCorrectThread()) {
+            unregisterAndDestroyFd();
+        } else {
+            mHandler.post(() -> {
+                logError("stop() called from off-thread", null);
+                unregisterAndDestroyFd();
+            });
+        }
     }
 
-    public final boolean isRunning() { return mRunning; }
+    public final int recvBufSize() { return mPacket.length; }
 
     public final long numPacketsReceived() { return mPacketsReceived; }
 
@@ -98,11 +123,21 @@
      * Subclasses MUST create the listening socket here, including setting
      * all desired socket options, interface or address/port binding, etc.
      */
-    protected abstract FileDescriptor createSocket();
+    protected abstract FileDescriptor createFd();
+
+    /**
+     * Subclasses MAY override this to change the default read() implementation
+     * in favour of, say, recvfrom().
+     *
+     * Implementations MUST return the number of bytes read or throw an Exception.
+     */
+    protected int readPacket(FileDescriptor fd, byte[] packetBuffer) throws Exception {
+        return Os.read(fd, packetBuffer, 0, packetBuffer.length);
+    }
 
     /**
      * Called by the main loop for every packet.  Any desired copies of
-     * |recvbuf| should be made in here, and the underlying byte array is
+     * |recvbuf| should be made in here, as the underlying byte array is
      * reused across all reads.
      */
     protected void handlePacket(byte[] recvbuf, int length) {}
@@ -113,43 +148,102 @@
     protected void logError(String msg, Exception e) {}
 
     /**
-     * Called by the main loop just prior to exiting.
+     * Called by start(), if successful, just prior to returning.
      */
-    protected void onExit() {}
+    protected void onStart() {}
 
-    private final void mainLoop() {
+    /**
+     * Called by stop() just prior to returning.
+     */
+    protected void onStop() {}
+
+    private void createAndRegisterFd() {
+        if (mFd != null) return;
+
+        try {
+            mFd = createFd();
+            if (mFd != null) {
+                // Force the socket to be non-blocking.
+                IoUtils.setBlocking(mFd, false);
+            }
+        } catch (Exception e) {
+            logError("Failed to create socket: ", e);
+            closeFd(mFd);
+            mFd = null;
+            return;
+        }
+
+        if (mFd == null) return;
+
+        mQueue.addOnFileDescriptorEventListener(
+                mFd,
+                FD_EVENTS,
+                new OnFileDescriptorEventListener() {
+                    @Override
+                    public int onFileDescriptorEvents(FileDescriptor fd, int events) {
+                        // Always call handleInput() so read/recvfrom are given
+                        // a proper chance to encounter a meaningful errno and
+                        // perhaps log a useful error message.
+                        if (!isRunning() || !handleInput()) {
+                            unregisterAndDestroyFd();
+                            return UNREGISTER_THIS_FD;
+                        }
+                        return FD_EVENTS;
+                    }
+                });
+        onStart();
+    }
+
+    private boolean isRunning() { return (mFd != null) && mFd.valid(); }
+
+    // Keep trying to read until we get EAGAIN/EWOULDBLOCK or some fatal error.
+    private boolean handleInput() {
         while (isRunning()) {
             final int bytesRead;
 
             try {
-                // Blocking read.
-                // TODO: See if this can be converted to recvfrom.
-                bytesRead = Os.read(mSocket, mPacket, 0, mPacket.length);
+                bytesRead = readPacket(mFd, mPacket);
                 if (bytesRead < 1) {
                     if (isRunning()) logError("Socket closed, exiting", null);
                     break;
                 }
                 mPacketsReceived++;
             } catch (ErrnoException e) {
-                if (e.errno != OsConstants.EINTR) {
-                    if (isRunning()) logError("read error: ", e);
+                if (e.errno == OsConstants.EAGAIN) {
+                    // We've read everything there is to read this time around.
+                    return true;
+                } else if (e.errno == OsConstants.EINTR) {
+                    continue;
+                } else {
+                    if (isRunning()) logError("readPacket error: ", e);
                     break;
                 }
-                continue;
-            } catch (IOException ioe) {
-                if (isRunning()) logError("read error: ", ioe);
-                continue;
+            } catch (Exception e) {
+                if (isRunning()) logError("readPacket error: ", e);
+                break;
             }
 
             try {
                 handlePacket(mPacket, bytesRead);
             } catch (Exception e) {
-                logError("Unexpected exception: ", e);
+                logError("handlePacket error: ", e);
                 break;
             }
         }
 
-        stop();
-        onExit();
+        return false;
+    }
+
+    private void unregisterAndDestroyFd() {
+        if (mFd == null) return;
+
+        mQueue.removeOnFileDescriptorEventListener(mFd);
+        closeFd(mFd);
+        mFd = null;
+        onStop();
+    }
+
+    private boolean onCorrectThread() {
+        return (mHandler.getLooper() == Looper.myLooper());
     }
 }
diff --git a/tests/net/java/android/net/util/BlockingSocketReaderTest.java b/tests/net/java/android/net/util/BlockingSocketReaderTest.java
index e03350f..1aad453 100644
--- a/tests/net/java/android/net/util/BlockingSocketReaderTest.java
+++ b/tests/net/java/android/net/util/BlockingSocketReaderTest.java
@@ -16,8 +16,11 @@
 
 package android.net.util;
 
+import static android.net.util.BlockingSocketReader.DEFAULT_RECV_BUF_SIZE;
 import static android.system.OsConstants.*;
 
+import android.os.Handler;
+import android.os.HandlerThread;
 import android.system.ErrnoException;
 import android.system.Os;
 import android.system.StructTimeval;
@@ -27,6 +30,7 @@
 import java.io.FileDescriptor;
 import java.io.FileInputStream;
 import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.net.DatagramPacket;
 import java.net.DatagramSocket;
 import java.net.Inet6Address;
@@ -53,61 +57,83 @@
     protected FileDescriptor mLocalSocket;
     protected InetSocketAddress mLocalSockName;
     protected byte[] mLastRecvBuf;
-    protected boolean mExited;
+    protected boolean mStopped;
+    protected HandlerThread mHandlerThread;
     protected BlockingSocketReader mReceiver;
 
+    class UdpLoopbackReader extends BlockingSocketReader {
+        public UdpLoopbackReader(Handler h) {
+            super(h);
+        }
+
+        @Override
+        protected FileDescriptor createFd() {
+            FileDescriptor s = null;
+            try {
+                s = Os.socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
+                Os.bind(s, LOOPBACK6, 0);
+                mLocalSockName = (InetSocketAddress) Os.getsockname(s);
+                Os.setsockoptTimeval(s, SOL_SOCKET, SO_SNDTIMEO, TIMEO);
+            } catch (ErrnoException|SocketException e) {
+                closeFd(s);
+                fail();
+                return null;
+            }
+
+            mLocalSocket = s;
+            return s;
+        }
+
+        @Override
+        protected void handlePacket(byte[] recvbuf, int length) {
+            mLastRecvBuf = Arrays.copyOf(recvbuf, length);
+            mLatch.countDown();
+        }
+
+        @Override
+        protected void onStart() {
+            mStopped = false;
+            mLatch.countDown();
+        }
+
+        @Override
+        protected void onStop() {
+            mStopped = true;
+            mLatch.countDown();
+        }
+    };
+
     @Override
     public void setUp() {
         resetLatch();
         mLocalSocket = null;
         mLocalSockName = null;
         mLastRecvBuf = null;
-        mExited = false;
+        mStopped = false;
 
-        mReceiver = new BlockingSocketReader() {
-            @Override
-            protected FileDescriptor createSocket() {
-                FileDescriptor s = null;
-                try {
-                    s = Os.socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
-                    Os.bind(s, LOOPBACK6, 0);
-                    mLocalSockName = (InetSocketAddress) Os.getsockname(s);
-                    Os.setsockoptTimeval(s, SOL_SOCKET, SO_SNDTIMEO, TIMEO);
-                } catch (ErrnoException|SocketException e) {
-                    closeSocket(s);
-                    fail();
-                    return null;
-                }
-
-                mLocalSocket = s;
-                return s;
-            }
-
-            @Override
-            protected void handlePacket(byte[] recvbuf, int length) {
-                mLastRecvBuf = Arrays.copyOf(recvbuf, length);
-                mLatch.countDown();
-            }
-
-            @Override
-            protected void onExit() {
-                mExited = true;
-                mLatch.countDown();
-            }
-        };
+        mHandlerThread = new HandlerThread(BlockingSocketReaderTest.class.getSimpleName());
+        mHandlerThread.start();
     }
 
     @Override
-    public void tearDown() {
-        if (mReceiver != null) mReceiver.stop();
+    public void tearDown() throws Exception {
+        if (mReceiver != null) {
+            mHandlerThread.getThreadHandler().post(() -> { mReceiver.stop(); });
+            waitForActivity();
+        }
         mReceiver = null;
+        mHandlerThread.quit();
+        mHandlerThread = null;
     }
 
     void resetLatch() { mLatch = new CountDownLatch(1); }
 
     void waitForActivity() throws Exception {
-        assertTrue(mLatch.await(500, TimeUnit.MILLISECONDS));
-        resetLatch();
+        try {
+            mLatch.await(1000, TimeUnit.MILLISECONDS);
+        } finally {
+            resetLatch();
+        }
     }
 
     void sendPacket(byte[] contents) throws Exception {
@@ -118,31 +144,54 @@
     }
 
     public void testBasicWorking() throws Exception {
-        assertTrue(mReceiver.start());
+        final Handler h = mHandlerThread.getThreadHandler();
+        mReceiver = new UdpLoopbackReader(h);
+
+        h.post(() -> { mReceiver.start(); });
+        waitForActivity();
         assertTrue(mLocalSockName != null);
         assertEquals(LOOPBACK6, mLocalSockName.getAddress());
         assertTrue(0 < mLocalSockName.getPort());
         assertTrue(mLocalSocket != null);
-        assertFalse(mExited);
+        assertFalse(mStopped);
 
         final byte[] one = "one 1".getBytes("UTF-8");
         sendPacket(one);
         waitForActivity();
         assertEquals(1, mReceiver.numPacketsReceived());
         assertTrue(Arrays.equals(one, mLastRecvBuf));
-        assertFalse(mExited);
+        assertFalse(mStopped);
 
         final byte[] two = "two 2".getBytes("UTF-8");
         sendPacket(two);
         waitForActivity();
         assertEquals(2, mReceiver.numPacketsReceived());
         assertTrue(Arrays.equals(two, mLastRecvBuf));
-        assertFalse(mExited);
+        assertFalse(mStopped);
 
         mReceiver.stop();
         waitForActivity();
         assertEquals(2, mReceiver.numPacketsReceived());
         assertTrue(Arrays.equals(two, mLastRecvBuf));
-        assertTrue(mExited);
+        assertTrue(mStopped);
+        mReceiver = null;
+    }
+
+    class NullBlockingSocketReader extends BlockingSocketReader {
+        public NullBlockingSocketReader(Handler h, int recvbufsize) {
+            super(h, recvbufsize);
+        }
+
+        @Override
+        public FileDescriptor createFd() { return null; }
+    }
+
+    public void testMinimalRecvBufSize() throws Exception {
+        final Handler h = mHandlerThread.getThreadHandler();
+
+        for (int i : new int[]{-1, 0, 1, DEFAULT_RECV_BUF_SIZE-1}) {
+            final BlockingSocketReader b = new NullBlockingSocketReader(h, i);
+            assertEquals(DEFAULT_RECV_BUF_SIZE, b.recvBufSize());
+        }
     }
 }