Rewrite DnsPinger to be asynchronous and support concurrent pings

Change-Id: I93e1374ae857832935202614e34ce18f040fcfc7
diff --git a/core/java/android/net/DnsPinger.java b/core/java/android/net/DnsPinger.java
index f2d84eb..81738f3 100644
--- a/core/java/android/net/DnsPinger.java
+++ b/core/java/android/net/DnsPinger.java
@@ -17,20 +17,27 @@
 package android.net;
 
 import android.content.Context;
-import android.net.ConnectivityManager;
-import android.net.LinkProperties;
-import android.net.NetworkUtils;
+import android.os.Handler;
+import android.os.Looper;
+import android.os.Message;
 import android.os.SystemClock;
 import android.provider.Settings;
 import android.util.Slog;
 
+import com.android.internal.util.Protocol;
+
+import java.io.IOException;
 import java.net.DatagramPacket;
 import java.net.DatagramSocket;
 import java.net.InetAddress;
 import java.net.NetworkInterface;
 import java.net.SocketTimeoutException;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
 
 /**
  * Performs a simple DNS "ping" by sending a "server status" query packet to the
@@ -40,42 +47,236 @@
  * API may not differentiate between a time out and a failure lookup (which we
  * really care about).
  * <p>
- * TODO : More general API. Socket does not bind to specified connection type
- * TODO : Choice of DNS query location - current looks up www.android.com
+ * Pings are sent and polled on this Handler's looper; each result is posted to
+ * the target handler as a {@link #DNS_PING_RESULT} message.
  *
  * @hide
  */
-public final class DnsPinger {
+public final class DnsPinger extends Handler {
     private static final boolean V = true;

-    /** Number of bytes for the query */
-    private static final int DNS_QUERY_BASE_SIZE = 32;
-
+    /** How often, in ms, active pings are polled for a reply */
+    private static final int RECEIVE_POLL_INTERVAL_MS = 30;
+
     /** The DNS port */
     private static final int DNS_PORT = 53;

+    /** Short socket timeout, in ms, so we don't block on any 'receive' call */
+    private static final int SOCKET_TIMEOUT_MS = 1;
+
     /** Used to generate IDs */
-    private static Random sRandom = new Random();
+    private static final Random sRandom = new Random();
+    /** Source of the caller-visible IDs returned by {@link #pingDnsAsync} */
+    private static final AtomicInteger sCounter = new AtomicInteger();

     private ConnectivityManager mConnectivityManager = null;
-    private Context mContext;
-    private int mConnectionType;
-    private InetAddress mDefaultDns;
-
+    private final Context mContext;
+    private final int mConnectionType;
+    /** Handler that receives {@link #DNS_PING_RESULT} messages */
+    private final Handler mTarget;
+    private final InetAddress mDefaultDns;
     private String TAG;

+    private static final int BASE = Protocol.BASE_DNS_PINGER;
+
     /**
-     * @param connectionType The connection type from {@link ConnectivityManager}
+     * Async response packet for dns pings, sent to the target handler.
+     * arg1 is the ID of the ping, also returned by {@link #pingDnsAsync(InetAddress, int, int)}.
+     * arg2 is the round-trip time in ms, or {@link #TIMEOUT} / {@link #SOCKET_EXCEPTION}
+     * on error.
      */
-    public DnsPinger(String TAG, Context context, int connectionType) {
+    public static final int DNS_PING_RESULT = BASE;
+    /** An error code for a {@link #DNS_PING_RESULT} packet */
+    public static final int TIMEOUT = -1;
+    /** An error code for a {@link #DNS_PING_RESULT} packet */
+    public static final int SOCKET_EXCEPTION = -2;
+
+    /**
+     * Send a new ping via a socket.  arg1 is ID, arg2 is timeout, obj is InetAddress to ping
+     */
+    private static final int ACTION_PING_DNS = BASE + 1;
+    /** Poll active pings for replies.  arg1 is the value of mEventCounter when scheduled */
+    private static final int ACTION_LISTEN_FOR_RESPONSE = BASE + 2;
+    /** Close all active pings and drop delayed ones that have not been sent yet */
+    private static final int ACTION_CANCEL_ALL_PINGS = BASE + 3;
+
+    /** Pings that have been sent and are awaiting a reply or a timeout */
+    private final List<ActivePing> mActivePings = new ArrayList<ActivePing>();
+    /** Bumped per sent ping; only the poll scheduled with the latest value runs */
+    private int mEventCounter;
+
+    /** Bookkeeping for one in-flight ping */
+    private static class ActivePing {
+        DatagramSocket socket;
+        /** ID reported back to the target handler in arg1 */
+        int internalId;
+        /** DNS transaction ID written into the query */
+        short packetId;
+        /** Time in ms after {@link #start} at which the ping is reported as timed out */
+        int timeout;
+        /** Round-trip time or error code; null while still waiting */
+        Integer result;
+        /** {@link SystemClock#elapsedRealtime()} at which the query was sent */
+        long start;
+    }
+
+    /**
+     * @param context used to obtain the {@link ConnectivityManager}
+     * @param TAG tag for log messages
+     * @param looper looper on which pings are sent and polled
+     * @param target handler that receives {@link #DNS_PING_RESULT} messages
+     * @param connectionType connection type from {@link ConnectivityManager}
+     * @throws IllegalArgumentException if {@code connectionType} is not valid
+     */
+    public DnsPinger(Context context, String TAG, Looper looper,
+            Handler target, int connectionType) {
+        super(looper);
+        this.TAG = TAG;
         mContext = context;
+        mTarget = target;
         mConnectionType = connectionType;
         if (!ConnectivityManager.isNetworkTypeValid(connectionType)) {
-            Slog.e(TAG, "Invalid connectionType in constructor: " + connectionType);
+            throw new IllegalArgumentException("Invalid connectionType in constructor: "
+                    + connectionType);
         }
-        this.TAG = TAG;
-
         mDefaultDns = getDefaultDns();
+        mEventCounter = 0;
+    }
+
+    /**
+     * Dispatches the actions posted by {@link #pingDnsAsync}, {@link #cancelPings}
+     * and the polling loop.  Runs on the looper passed to the constructor.
+     */
+    @Override
+    public void handleMessage(Message msg) {
+        switch (msg.what) {
+            case ACTION_PING_DNS:
+                handlePingDns(msg.arg1, msg.arg2, (InetAddress) msg.obj);
+                break;
+            case ACTION_LISTEN_FOR_RESPONSE:
+                // Polls scheduled before the most recent ping are stale; drop them
+                if (msg.arg1 == mEventCounter) {
+                    handleListenForResponse();
+                }
+                break;
+            case ACTION_CANCEL_ALL_PINGS:
+                handleCancelAllPings();
+                break;
+        }
+    }
+
+    /**
+     * Opens a socket, sends one DNS query and starts polling for the reply.
+     * On an I/O error a {@link #SOCKET_EXCEPTION} result is sent right away.
+     *
+     * @param internalId ID returned to the caller by {@link #pingDnsAsync}
+     * @param timeout ms after the send before the ping is reported as {@link #TIMEOUT}
+     * @param dnsAddress address of the DNS server to ping
+     */
+    private void handlePingDns(int internalId, int timeout, InetAddress dnsAddress) {
+        ActivePing newActivePing = new ActivePing();
+        newActivePing.internalId = internalId;
+        newActivePing.timeout = timeout;
+        try {
+            newActivePing.socket = new DatagramSocket();
+            newActivePing.socket.setSoTimeout(SOCKET_TIMEOUT_MS);
+
+            // Try to bind but continue ping if bind fails
+            try {
+                newActivePing.socket.setNetworkInterface(NetworkInterface.getByName(
+                        getCurrentLinkProperties().getInterfaceName()));
+            } catch (Exception e) {
+                Slog.w(TAG,"sendDnsPing::Error binding to socket", e);
+            }
+
+            // Bytes [0-1] of the query are the DNS transaction ID, big-endian
+            newActivePing.packetId = (short) sRandom.nextInt();
+            byte[] buf = mDnsQuery.clone();
+            buf[0] = (byte) (newActivePing.packetId >> 8);
+            buf[1] = (byte) newActivePing.packetId;
+
+            DatagramPacket packet = new DatagramPacket(buf,
+                    buf.length, dnsAddress, DNS_PORT);
+            if (V) {
+                Slog.v(TAG, "Sending a ping to " + dnsAddress.getHostAddress()
+                        + " with ID " + newActivePing.packetId + ".");
+            }
+
+            // Time from the send itself, not from socket setup / interface lookup
+            newActivePing.start = SystemClock.elapsedRealtime();
+            newActivePing.socket.send(packet);
+            mActivePings.add(newActivePing);
+            mEventCounter++;
+            sendMessageDelayed(obtainMessage(ACTION_LISTEN_FOR_RESPONSE, mEventCounter, 0),
+                    RECEIVE_POLL_INTERVAL_MS);
+        } catch (IOException e) {
+            if (V) {
+                Slog.v(TAG, "DnsPinger.pingDns got socket exception: ", e);
+            }
+            // The ping was never registered, so nothing else will close its socket
+            if (newActivePing.socket != null) {
+                newActivePing.socket.close();
+            }
+            sendResponse(internalId, SOCKET_EXCEPTION);
+        }
+    }
+
+    /**
+     * Tries one receive() on every active ping, reports pings that got a
+     * matching reply, failed or timed out, and reschedules itself while any remain.
+     */
+    private void handleListenForResponse() {
+        for (ActivePing curPing : mActivePings) {
+            try {
+                // Each socket blocks for at most SOCKET_TIMEOUT_MS in receive()
+                byte[] responseBuf = new byte[2];
+                DatagramPacket replyPacket = new DatagramPacket(responseBuf, 2);
+                curPing.socket.receive(replyPacket);
+                // Check that ID field matches (we're throwing out the rest of the packet)
+                if (responseBuf[0] == (byte) (curPing.packetId >> 8) &&
+                        responseBuf[1] == (byte) curPing.packetId) {
+                    curPing.result =
+                            (int) (SystemClock.elapsedRealtime() - curPing.start);
+                } else if (V) {
+                    Slog.v(TAG, "response ID didn't match, ignoring packet");
+                }
+            } catch (SocketTimeoutException e) {
+                // No datagram yet; the overall ping timeout is checked below
+            } catch (Exception e) {
+                if (V) {
+                    Slog.v(TAG, "DnsPinger.pingDns got socket exception: ", e);
+                }
+                curPing.result = SOCKET_EXCEPTION;
+            }
+        }
+
+        Iterator<ActivePing> iter = mActivePings.iterator();
+        while (iter.hasNext()) {
+            ActivePing curPing = iter.next();
+            if (curPing.result != null) {
+                sendResponse(curPing.internalId, curPing.result);
+                curPing.socket.close();
+                iter.remove();
+            } else if (SystemClock.elapsedRealtime() > curPing.start + curPing.timeout) {
+                sendResponse(curPing.internalId, TIMEOUT);
+                curPing.socket.close();
+                iter.remove();
+            }
+        }
+        if (!mActivePings.isEmpty()) {
+            sendMessageDelayed(obtainMessage(ACTION_LISTEN_FOR_RESPONSE, mEventCounter, 0),
+                    RECEIVE_POLL_INTERVAL_MS);
+        }
+    }
+
+    /** Closes every active ping without reporting a result and drops unsent ones. */
+    private void handleCancelAllPings() {
+        for (ActivePing activePing : mActivePings) {
+            activePing.socket.close();
+        }
+        mActivePings.clear();
+        removeMessages(ACTION_PING_DNS);
+        removeMessages(ACTION_LISTEN_FOR_RESPONSE);
     }

     /**
@@ -99,6 +238,36 @@
         return dnses.iterator().next();
     }

+    /**
+     * Send a ping.  The response will come via a {@link #DNS_PING_RESULT} to the handler
+     * specified at creation.
+     * @param dns address of dns server to ping
+     * @param timeout time in ms to wait for a reply before reporting {@link #TIMEOUT}
+     * @param delay time in ms to wait before sending the ping
+     * @return an ID field, which will also be included in the {@link #DNS_PING_RESULT} message.
+     */
+    public int pingDnsAsync(InetAddress dns, int timeout, int delay) {
+        int id = sCounter.incrementAndGet();
+        sendMessageDelayed(obtainMessage(ACTION_PING_DNS, id, timeout, dns), delay);
+        return id;
+    }
+
+    /**
+     * Cancel all active pings and any pings still waiting out their delay.
+     * No {@link #DNS_PING_RESULT} is sent for cancelled pings.
+     */
+    public void cancelPings() {
+        obtainMessage(ACTION_CANCEL_ALL_PINGS).sendToTarget();
+    }
+
+    /** Deliver a {@link #DNS_PING_RESULT} with the given ID and value to the target handler. */
+    private void sendResponse(int internalId, int responseVal) {
+        if (V) {
+            Slog.v(TAG, "Responding with id " + internalId + " and val " + responseVal);
+        }
+        mTarget.obtainMessage(DNS_PING_RESULT, internalId, responseVal).sendToTarget();
+    }
+
+
     private LinkProperties getCurrentLinkProperties() {
         if (mConnectivityManager == null) {
             mConnectivityManager = (ConnectivityManager) mContext.getSystemService(
@@ -123,106 +286,18 @@
         }
     }

-    /**
-     * @return time to response. Negative value on error.
-     */
-    public long pingDns(InetAddress dnsAddress, int timeout) {
-        DatagramSocket socket = null;
-        try {
-            socket = new DatagramSocket();
-
-            // Set some socket properties
-            socket.setSoTimeout(timeout);
-
-            // Try to bind but continue ping if bind fails
-            try {
-                socket.setNetworkInterface(NetworkInterface.getByName(
-                        getCurrentLinkProperties().getInterfaceName()));
-            } catch (Exception e) {
-                Slog.d(TAG,"pingDns::Error binding to socket", e);
-            }
-
-            byte[] buf = constructQuery();
-
-            // Send the DNS query
-
-            DatagramPacket packet = new DatagramPacket(buf,
-                    buf.length, dnsAddress, DNS_PORT);
-            long start = SystemClock.elapsedRealtime();
-            socket.send(packet);
-
-            // Wait for reply (blocks for the above timeout)
-            DatagramPacket replyPacket = new DatagramPacket(buf, buf.length);
-            socket.receive(replyPacket);
-
-            // If a timeout occurred, an exception would have been thrown. We
-            // got a reply!
-            return SystemClock.elapsedRealtime() - start;
-
-        } catch (SocketTimeoutException e) {
-            // Squelch this exception.
-            return -1;
-        } catch (Exception e) {
-            if (V) {
-                Slog.v(TAG, "DnsPinger.pingDns got socket exception: ", e);
-            }
-            return -2;
-        } finally {
-            if (socket != null) {
-                socket.close();
-            }
-        }
-
-    }
-
-    /**
-     * @return google.com DNS query packet
-     */
-    private static byte[] constructQuery() {
-        byte[] buf = new byte[DNS_QUERY_BASE_SIZE];
-
-        // [0-1] bytes are an ID, generate random ID for this query
-        buf[0] = (byte) sRandom.nextInt(256);
-        buf[1] = (byte) sRandom.nextInt(256);
-
-        // [2-3] bytes are for flags.
-        buf[2] = 0x01; // Recursion desired
-
-        // [4-5] bytes are for number of queries (QCOUNT)
-        buf[5] = 0x01;
-
-        // [6-7] [8-9] [10-11] are all counts of other fields we don't use
-
-        // [12-15] for www
-        writeString(buf, 12, "www");
-
-        // [16-22] for google
-        writeString(buf, 16, "google");
-
-        // [23-26] for com
-        writeString(buf, 23, "com");
-
-        // [27] is a null byte terminator byte for the url
-
-        // [28-29] bytes are for QTYPE, set to 1 = A (host address)
-        buf[29] = 0x01;
-
-        // [30-31] bytes are for QCLASS, set to 1 = IN (internet)
-        buf[31] = 0x01;
-
-        return buf;
-    }
-
-    /**
-     * Writes the string's length and its contents to the buffer
-     */
-    private static void writeString(byte[] buf, int startPos, String string) {
-        int pos = startPos;
-
-        // Write the length first
-        buf[pos++] = (byte) string.length();
-        for (int i = 0; i < string.length(); i++) {
-            buf[pos++] = (byte) string.charAt(i);
-        }
-    }
+    private static final byte[] mDnsQuery = new byte[] { // query template for www.google.com
+        0, 0, // [0-1] ID; overwritten with a random value on a copy for each ping
+        0, 0, // [2-3] flags.  Set byte[2] = 1 for recursion desired (RD) on.  Currently off.
+        0, 1, // [4-5] QDCOUNT: number of questions in the query (one)
+        0, 0, // [6-7] ANCOUNT: answer count, only meaningful in responses
+        0, 0, // [8-9] NSCOUNT: authority-record count, only meaningful in responses
+        0, 0, // [10-11] ARCOUNT: additional-record count, only meaningful in responses
+        3, 'w', 'w', 'w',                // QNAME label "www" (length-prefixed)
+        6, 'g', 'o', 'o', 'g', 'l', 'e', // QNAME label "google"
+        3, 'c', 'o', 'm',                // QNAME label "com"
+        0,    // zero-length root label, terminates QNAME
+        0, 1, // QTYPE, set to 1 = A (host address)
+        0, 1  // QCLASS, set to 1 = IN (internet)
+    };
 }