Merge "Create a NetworkStack test version module."
diff --git a/src/android/net/dhcp/DhcpAckPacket.java b/src/android/net/dhcp/DhcpAckPacket.java
index b2eb4e2..052af35 100644
--- a/src/android/net/dhcp/DhcpAckPacket.java
+++ b/src/android/net/dhcp/DhcpAckPacket.java
@@ -22,7 +22,7 @@
 /**
  * This class implements the DHCP-ACK packet.
  */
-class DhcpAckPacket extends DhcpPacket {
+public class DhcpAckPacket extends DhcpPacket {
 
     /**
      * The address of the server which sent this packet.
diff --git a/src/android/net/dhcp/DhcpDeclinePacket.java b/src/android/net/dhcp/DhcpDeclinePacket.java
index 7ecdea7..2f4f0f6 100644
--- a/src/android/net/dhcp/DhcpDeclinePacket.java
+++ b/src/android/net/dhcp/DhcpDeclinePacket.java
@@ -22,7 +22,7 @@
 /**
  * This class implements the DHCP-DECLINE packet.
  */
-class DhcpDeclinePacket extends DhcpPacket {
+public class DhcpDeclinePacket extends DhcpPacket {
     /**
      * Generates a DECLINE packet with the specified parameters.
      */
diff --git a/src/android/net/dhcp/DhcpInformPacket.java b/src/android/net/dhcp/DhcpInformPacket.java
index 7a83466..135b8f6 100644
--- a/src/android/net/dhcp/DhcpInformPacket.java
+++ b/src/android/net/dhcp/DhcpInformPacket.java
@@ -22,7 +22,7 @@
 /**
  * This class implements the (unused) DHCP-INFORM packet.
  */
-class DhcpInformPacket extends DhcpPacket {
+public class DhcpInformPacket extends DhcpPacket {
     /**
      * Generates an INFORM packet with the specified parameters.
      */
diff --git a/src/android/net/dhcp/DhcpNakPacket.java b/src/android/net/dhcp/DhcpNakPacket.java
index 1da0b73..98bd188 100644
--- a/src/android/net/dhcp/DhcpNakPacket.java
+++ b/src/android/net/dhcp/DhcpNakPacket.java
@@ -22,7 +22,7 @@
 /**
  * This class implements the DHCP-NAK packet.
  */
-class DhcpNakPacket extends DhcpPacket {
+public class DhcpNakPacket extends DhcpPacket {
     /**
      * Generates a NAK packet with the specified parameters.
      */
diff --git a/src/android/net/dhcp/DhcpOfferPacket.java b/src/android/net/dhcp/DhcpOfferPacket.java
index 0eba77e..aae08a7 100644
--- a/src/android/net/dhcp/DhcpOfferPacket.java
+++ b/src/android/net/dhcp/DhcpOfferPacket.java
@@ -22,7 +22,7 @@
 /**
  * This class implements the DHCP-OFFER packet.
  */
-class DhcpOfferPacket extends DhcpPacket {
+public class DhcpOfferPacket extends DhcpPacket {
     /**
      * The IP address of the server which sent this packet.
      */
diff --git a/src/android/net/dhcp/DhcpReleasePacket.java b/src/android/net/dhcp/DhcpReleasePacket.java
index 3958303..cef5567 100644
--- a/src/android/net/dhcp/DhcpReleasePacket.java
+++ b/src/android/net/dhcp/DhcpReleasePacket.java
@@ -22,7 +22,7 @@
 /**
  * Implements DHCP-RELEASE
  */
-class DhcpReleasePacket extends DhcpPacket {
+public class DhcpReleasePacket extends DhcpPacket {
 
     final Inet4Address mClientAddr;
 
diff --git a/src/android/net/dhcp/DhcpRequestPacket.java b/src/android/net/dhcp/DhcpRequestPacket.java
index 231d045..0672871 100644
--- a/src/android/net/dhcp/DhcpRequestPacket.java
+++ b/src/android/net/dhcp/DhcpRequestPacket.java
@@ -16,15 +16,13 @@
 
 package android.net.dhcp;
 
-import android.util.Log;
-
 import java.net.Inet4Address;
 import java.nio.ByteBuffer;
 
 /**
  * This class implements the DHCP-REQUEST packet.
  */
-class DhcpRequestPacket extends DhcpPacket {
+public class DhcpRequestPacket extends DhcpPacket {
     /**
      * Generates a REQUEST packet with the specified parameters.
      */
diff --git a/src/android/net/ip/IpNeighborMonitor.java b/src/android/net/ip/IpNeighborMonitor.java
index 6ae9a2b..803f2e6 100644
--- a/src/android/net/ip/IpNeighborMonitor.java
+++ b/src/android/net/ip/IpNeighborMonitor.java
@@ -185,12 +185,6 @@
                 break;
             }
 
-            final int srcPortId = nlMsg.getHeader().nlmsg_pid;
-            if (srcPortId !=  0) {
-                mLog.e("non-kernel source portId: " + Integer.toUnsignedLong(srcPortId));
-                break;
-            }
-
             if (nlMsg instanceof NetlinkErrorMessage) {
                 mLog.e("netlink error: " + nlMsg);
                 continue;
diff --git a/src/com/android/networkstack/util/DnsUtils.java b/src/com/android/networkstack/util/DnsUtils.java
index 2ea5ed8..759807b 100644
--- a/src/com/android/networkstack/util/DnsUtils.java
+++ b/src/com/android/networkstack/util/DnsUtils.java
@@ -47,6 +47,9 @@
 public class DnsUtils {
     // Decide what queries to make depending on what IP addresses are on the system.
     public static final int TYPE_ADDRCONFIG = -1;
+    // Suffix appended to a random one-time hostname to build the private DNS probe host.
+    // q.v. system/netd/server/dns/DnsTlsTransport.cpp
+    public static final String PRIVATE_DNS_PROBE_HOST_SUFFIX = "-dnsotls-ds.metric.gstatic.com";
     private static final String TAG = DnsUtils.class.getSimpleName();
     private static final boolean DBG = Log.isLoggable(TAG, Log.DEBUG);
 
diff --git a/src/com/android/server/connectivity/NetworkMonitor.java b/src/com/android/server/connectivity/NetworkMonitor.java
index d4b484d..6122d98 100644
--- a/src/com/android/server/connectivity/NetworkMonitor.java
+++ b/src/com/android/server/connectivity/NetworkMonitor.java
@@ -65,6 +65,7 @@
 import static android.net.util.NetworkStackUtils.NAMESPACE_CONNECTIVITY;
 import static android.net.util.NetworkStackUtils.isEmpty;
 
+import static com.android.networkstack.util.DnsUtils.PRIVATE_DNS_PROBE_HOST_SUFFIX;
 import static com.android.networkstack.util.DnsUtils.TYPE_ADDRCONFIG;
 
 import android.annotation.NonNull;
@@ -1083,10 +1084,8 @@
         }
 
         private boolean sendPrivateDnsProbe() {
-            // q.v. system/netd/server/dns/DnsTlsTransport.cpp
-            final String oneTimeHostnameSuffix = "-dnsotls-ds.metric.gstatic.com";
             final String host = UUID.randomUUID().toString().substring(0, 8)
-                    + oneTimeHostnameSuffix;
+                    + PRIVATE_DNS_PROBE_HOST_SUFFIX;
             final Stopwatch watch = new Stopwatch().start();
             boolean success = false;
             long time;
diff --git a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
index ae6afdb..16e92ef 100644
--- a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
+++ b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
@@ -16,23 +16,27 @@
 
 package android.net.ip;
 
+import static android.net.dhcp.DhcpClient.EXPIRED_LEASE;
 import static android.net.dhcp.DhcpPacket.DHCP_BOOTREQUEST;
 import static android.net.dhcp.DhcpPacket.DHCP_CLIENT;
 import static android.net.dhcp.DhcpPacket.DHCP_MAGIC_COOKIE;
-import static android.net.dhcp.DhcpPacket.DHCP_MESSAGE_TYPE;
-import static android.net.dhcp.DhcpPacket.DHCP_MESSAGE_TYPE_DISCOVER;
 import static android.net.dhcp.DhcpPacket.DHCP_SERVER;
 import static android.net.dhcp.DhcpPacket.ENCAP_L2;
-import static android.net.dhcp.DhcpPacket.ETHER_BROADCAST;
-
-import static com.android.server.util.NetworkStackConstants.IPV4_ADDR_ALL;
+import static android.net.dhcp.DhcpPacket.INFINITE_LEASE;
+import static android.net.ipmemorystore.Status.SUCCESS;
+import static android.net.shared.Inet4AddressUtils.getBroadcastAddress;
+import static android.net.shared.Inet4AddressUtils.getPrefixMaskAsInet4Address;
 
 import static junit.framework.Assert.fail;
 
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -43,20 +47,27 @@
 import android.content.Context;
 import android.content.res.Resources;
 import android.net.ConnectivityManager;
-import android.net.IIpMemoryStore;
 import android.net.INetd;
+import android.net.InetAddresses;
+import android.net.NetworkStackIpMemoryStore;
 import android.net.TestNetworkInterface;
 import android.net.TestNetworkManager;
+import android.net.dhcp.DhcpClient;
 import android.net.dhcp.DhcpDiscoverPacket;
 import android.net.dhcp.DhcpPacket;
 import android.net.dhcp.DhcpPacket.ParseException;
+import android.net.dhcp.DhcpRequestPacket;
+import android.net.ipmemorystore.NetworkAttributes;
+import android.net.ipmemorystore.OnNetworkAttributesRetrievedListener;
+import android.net.ipmemorystore.Status;
 import android.net.shared.ProvisioningConfiguration;
-import android.net.util.InterfaceParams;
+import android.net.util.NetworkStackUtils;
 import android.net.util.PacketReader;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.os.IBinder;
 import android.os.ParcelFileDescriptor;
+import android.os.RemoteException;
 
 import androidx.annotation.Nullable;
 import androidx.test.InstrumentationRegistry;
@@ -64,18 +75,25 @@
 import androidx.test.runner.AndroidJUnit4;
 
 import com.android.server.NetworkObserverRegistry;
-import com.android.server.NetworkStackService;
+import com.android.server.NetworkStackService.NetworkStackServiceManager;
+import com.android.server.connectivity.ipmemorystore.IpMemoryStoreService;
 
 import org.junit.After;
 import org.junit.Before;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
 import java.io.FileDescriptor;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.net.Inet4Address;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 
@@ -85,35 +103,36 @@
 @RunWith(AndroidJUnit4.class)
 @SmallTest
 public class IpClientIntegrationTest {
+    private static final int DATA_BUFFER_LEN = 4096;
+    private static final int PACKET_TIMEOUT_MS = 5_000;
+    private static final int TEST_TIMEOUT_MS = 400;
+    private static final String TEST_L2KEY = "some l2key";
+    private static final String TEST_GROUPHINT = "some grouphint";
+    private static final int TEST_LEASE_DURATION_S = 3_600; // 1 hour
+
     @Mock private Context mContext;
     @Mock private ConnectivityManager mCm;
     @Mock private INetd mNetd;
     @Mock private Resources mResources;
     @Mock private IIpClientCallbacks mCb;
     @Mock private AlarmManager mAlarm;
-    @Mock private IpClient.Dependencies mDependencies;
     @Mock private ContentResolver mContentResolver;
-    @Mock private NetworkStackService.NetworkStackServiceManager mNetworkStackServiceManager;
-    @Mock private IIpMemoryStore mIpMemoryStore;
-    @Mock private InterfaceParams mInterfaceParams;
+    @Mock private NetworkStackServiceManager mNetworkStackServiceManager;
+    @Mock private NetworkStackIpMemoryStore mIpMemoryStore;
+    @Mock private IpMemoryStoreService mIpMemoryStoreService;
 
     private String mIfaceName;
     private HandlerThread mPacketReaderThread;
     private TapPacketReader mPacketReader;
     private IpClient mIpc;
-
-    private static final int DATA_BUFFER_LEN = 4096;
-    private static final long PACKET_TIMEOUT_MS = 5_000;
+    private Dependencies mDependencies;
 
     // Ethernet header
     private static final int ETH_HEADER_LEN = 14;
-    private static final int ETH_DEST_ADDR_OFFSET = 0;
-    private static final int ETH_MAC_ADDR_LEN = 6;
 
     // IP header
     private static final int IPV4_HEADER_LEN = 20;
-    private static final int IPV4_DEST_ADDR_OFFSET = ETH_HEADER_LEN + 16;
-    private static final int IPV4_ADDR_LEN = 4;
+    private static final int IPV4_SRC_ADDR_OFFSET = ETH_HEADER_LEN + 12;
 
     // UDP header
     private static final int UDP_HEADER_LEN = 8;
@@ -127,10 +146,19 @@
     private static final int DHCP_TRANSACTION_ID_OFFSET = DHCP_HEADER_OFFSET + 4;
     private static final int DHCP_OPTION_MAGIC_COOKIE_OFFSET = DHCP_HEADER_OFFSET + 236;
     private static final int DHCP_OPTION_MESSAGE_TYPE_OFFSET = DHCP_OPTION_MAGIC_COOKIE_OFFSET + 4;
-    private static final int DHCP_OPTION_MESSAGE_TYPE_LEN_OFFSET =
-            DHCP_OPTION_MESSAGE_TYPE_OFFSET + 1;
-    private static final int DHCP_OPTION_MESSAGE_TYPE_VALUE_OFFSET =
-            DHCP_OPTION_MESSAGE_TYPE_OFFSET + 2;
+
+    private static final Inet4Address SERVER_ADDR =
+            (Inet4Address) InetAddresses.parseNumericAddress("192.168.1.100");
+    private static final Inet4Address CLIENT_ADDR =
+            (Inet4Address) InetAddresses.parseNumericAddress("192.168.1.2");
+    private static final Inet4Address INADDR_ANY =
+            (Inet4Address) InetAddresses.parseNumericAddress("0.0.0.0");
+    private static final int PREFIX_LENGTH = 24;
+    private static final Inet4Address NETMASK = getPrefixMaskAsInet4Address(PREFIX_LENGTH);
+    private static final Inet4Address BROADCAST_ADDR = getBroadcastAddress(
+            SERVER_ADDR, PREFIX_LENGTH);
+    private static final String HOSTNAME = "testhostname";
+    private static final short MTU = 1500;
 
     private static class TapPacketReader extends PacketReader {
         private final ParcelFileDescriptor mTapFd;
@@ -172,17 +200,61 @@
         }
     }
 
+    private class Dependencies extends IpClient.Dependencies { // wired to test mocks/flags
+        private boolean mIsDhcpLeaseCacheEnabled;
+        private boolean mIsDhcpRapidCommitEnabled;
+        // Value reported for the NetworkStackUtils.DHCP_INIT_REBOOT_ENABLED flag.
+        public void setDhcpLeaseCacheEnabled(final boolean enable) {
+            mIsDhcpLeaseCacheEnabled = enable;
+        }
+        // Value reported for the NetworkStackUtils.DHCP_RAPID_COMMIT_ENABLED flag.
+        public void setDhcpRapidCommitEnabled(final boolean enable) {
+            mIsDhcpRapidCommitEnabled = enable;
+        }
+        // Hand IpClient the test's mocked INetd.
+        @Override
+        public INetd getNetd(Context context) {
+            return mNetd;
+        }
+        // Hand IpClient the mocked IpMemoryStore so tests can stub and verify lease storage.
+        @Override
+        public NetworkStackIpMemoryStore getIpMemoryStore(Context context,
+                NetworkStackServiceManager nssManager) {
+            return mIpMemoryStore;
+        }
+        // DHCP flags come from the setters above; querying any other flag fails the test.
+        @Override
+        public DhcpClient.Dependencies getDhcpClientDependencies(
+                NetworkStackIpMemoryStore ipMemoryStore) {
+            return new DhcpClient.Dependencies(ipMemoryStore) {
+                @Override
+                public boolean getBooleanDeviceConfig(final String nameSpace,
+                        final String flagName) {
+                    switch (flagName) {
+                        case NetworkStackUtils.DHCP_RAPID_COMMIT_ENABLED:
+                            return mIsDhcpRapidCommitEnabled;
+                        case NetworkStackUtils.DHCP_INIT_REBOOT_ENABLED:
+                            return mIsDhcpLeaseCacheEnabled;
+                        default:
+                            fail("Invalid experiment flag: " + flagName);
+                            return false;
+                    }
+                }
+            };
+        }
+    }
+
     @Before
     public void setUp() throws Exception {
         MockitoAnnotations.initMocks(this);
 
+        mDependencies = new Dependencies();
         when(mContext.getSystemService(eq(Context.ALARM_SERVICE))).thenReturn(mAlarm);
         when(mContext.getSystemService(eq(ConnectivityManager.class))).thenReturn(mCm);
         when(mContext.getResources()).thenReturn(mResources);
-        when(mDependencies.getNetd(any())).thenReturn(mNetd);
         when(mContext.getContentResolver()).thenReturn(mContentResolver);
-        when(mDependencies.getInterfaceParams(any())).thenReturn(mInterfaceParams);
-        when(mNetworkStackServiceManager.getIpMemoryStoreService()).thenReturn(mIpMemoryStore);
+        when(mNetworkStackServiceManager.getIpMemoryStoreService())
+                .thenReturn(mIpMemoryStoreService);
 
         setUpTapInterface();
         setUpIpClient();
@@ -196,6 +268,7 @@
         if (mPacketReaderThread != null) {
             mPacketReaderThread.quitSafely();
         }
+        mIpc.shutdown();
     }
 
     private void setUpTapInterface() {
@@ -233,7 +306,8 @@
 
         final NetworkObserverRegistry reg = new NetworkObserverRegistry();
         reg.register(netd);
-        mIpc = new IpClient(mContext, mIfaceName, mCb, reg, mNetworkStackServiceManager);
+        mIpc = new IpClient(mContext, mIfaceName, mCb, reg, mNetworkStackServiceManager,
+                mDependencies);
     }
 
     private boolean packetContainsExpectedField(final byte[] packet, final int offset,
@@ -270,43 +344,253 @@
         return true;
     }
 
-    private void verifyDhcpDiscoverPacketReceived(final byte[] packet)
-            throws ParseException {
-        assertTrue(packetContainsExpectedField(packet, ETH_DEST_ADDR_OFFSET, ETHER_BROADCAST));
-        assertTrue(packetContainsExpectedField(packet, IPV4_DEST_ADDR_OFFSET,
-                IPV4_ADDR_ALL.getAddress()));
+    private static ByteBuffer buildDhcpOfferPacket(final DhcpPacket packet, // client DISCOVER
+            final Integer leaseTimeSec) { // null = no lease time (see testHandleNoLease)
+        return DhcpPacket.buildOfferPacket(DhcpPacket.ENCAP_L2, packet.getTransactionId(),
+                false /* broadcast */, SERVER_ADDR, INADDR_ANY /* relayIp */,
+                CLIENT_ADDR /* yourIp */, packet.getClientMac(), leaseTimeSec,
+                NETMASK /* netMask */, BROADCAST_ADDR /* bcAddr */,
+                Collections.singletonList(SERVER_ADDR) /* gateways */,
+                Collections.singletonList(SERVER_ADDR) /* dnsServers */,
+                SERVER_ADDR /* dhcpServerIdentifier */, null /* domainName */, HOSTNAME,
+                false /* metered */, MTU);
+    }
 
-        // check if received dhcp packet includes DHCP Message Type option and expected
-        // type/length/value.
-        assertTrue(packet[DHCP_OPTION_MESSAGE_TYPE_OFFSET] == DHCP_MESSAGE_TYPE);
-        assertTrue(packet[DHCP_OPTION_MESSAGE_TYPE_OFFSET + 1] == 1);
-        assertTrue(packet[DHCP_OPTION_MESSAGE_TYPE_OFFSET + 2] == DHCP_MESSAGE_TYPE_DISCOVER);
-        final DhcpPacket dhcpPacket = DhcpPacket.decodeFullPacket(
-                packet, packet.length, ENCAP_L2);
-        assertTrue(dhcpPacket instanceof DhcpDiscoverPacket);
+    private static ByteBuffer buildDhcpAckPacket(final DhcpPacket packet, // DISCOVER or REQUEST
+            final Integer leaseTimeSec) { // null = no lease time (see testHandleNoLease)
+        return DhcpPacket.buildAckPacket(DhcpPacket.ENCAP_L2, packet.getTransactionId(),
+                false /* broadcast */, SERVER_ADDR, INADDR_ANY /* relayIp */,
+                CLIENT_ADDR /* yourIp */, CLIENT_ADDR /* requestIp */, packet.getClientMac(),
+                leaseTimeSec, NETMASK /* netMask */, BROADCAST_ADDR /* bcAddr */,
+                Collections.singletonList(SERVER_ADDR) /* gateways */,
+                Collections.singletonList(SERVER_ADDR) /* dnsServers */,
+                SERVER_ADDR /* dhcpServerIdentifier */, null /* domainName */, HOSTNAME,
+                false /* metered */, MTU);
+    }
+
+    private static ByteBuffer buildDhcpNakPacket(final DhcpPacket packet) { // answers a REQUEST
+        return DhcpPacket.buildNakPacket(DhcpPacket.ENCAP_L2, packet.getTransactionId(),
+                SERVER_ADDR /* serverIp */, INADDR_ANY /* relayIp */, packet.getClientMac(),
+                false /* broadcast */, "duplicated request IP address");
+    }
+
+    private void sendResponse(final ByteBuffer packet) throws IOException { // [pos, limit)
+        try (FileOutputStream out = new FileOutputStream(mPacketReader.createFd())) {
+            out.write(packet.array(), packet.arrayOffset() + packet.position(), packet.remaining());
+        }
+    }
+
+    private void startIpClientProvisioning(final boolean isDhcpLeaseCacheEnabled,
+            final boolean isDhcpRapidCommitEnabled) throws RemoteException {
+        ProvisioningConfiguration config = new ProvisioningConfiguration.Builder()
+                .withoutIpReachabilityMonitor()
+                .withoutIPv6() // IPv4 (DHCP) provisioning only
+                .build();
+        // Set the DHCP experiment flags and the L2 key before provisioning starts.
+        mDependencies.setDhcpLeaseCacheEnabled(isDhcpLeaseCacheEnabled);
+        mDependencies.setDhcpRapidCommitEnabled(isDhcpRapidCommitEnabled);
+        mIpc.setL2KeyAndGroupHint(TEST_L2KEY, TEST_GROUPHINT);
+        mIpc.startProvisioning(config);
+        verify(mCb, times(1)).setNeighborDiscoveryOffload(true);
+        verify(mCb, timeout(TEST_TIMEOUT_MS).times(1)).setFallbackMulticastFilter(false);
+        verify(mCb, never()).onProvisioningFailure(any());
+    }
+
+    // Verifies the NetworkAttributes stored in the mocked IpMemoryStore under TEST_L2KEY.
+    private void assertIpMemoryStoreNetworkAttributes(final Integer leaseTimeSec,
+            final long startTime) {
+        final ArgumentCaptor<NetworkAttributes> networkAttributes =
+                ArgumentCaptor.forClass(NetworkAttributes.class);
+
+        verify(mIpMemoryStore, timeout(TEST_TIMEOUT_MS))
+                .storeNetworkAttributes(eq(TEST_L2KEY), networkAttributes.capture(), any());
+        final NetworkAttributes naValueCaptured = networkAttributes.getValue();
+        assertEquals(CLIENT_ADDR, naValueCaptured.assignedV4Address);
+        if (leaseTimeSec == null || leaseTimeSec.intValue() == DhcpPacket.INFINITE_LEASE) {
+            assertEquals(Long.MAX_VALUE, naValueCaptured.assignedV4AddressExpiry.longValue());
+        } else {
+            // Expiry should lie in [start + lease, now + lease]; derive it from leaseTimeSec.
+            final long leaseMs = leaseTimeSec * 1000L;
+            final long expiry = naValueCaptured.assignedV4AddressExpiry;
+            assertTrue("early expiry " + expiry, expiry >= startTime + leaseMs);
+            assertTrue("late expiry " + expiry, expiry <= System.currentTimeMillis() + leaseMs);
+        }
+        assertEquals(Collections.singletonList(SERVER_ADDR), naValueCaptured.dnsAddresses);
+        assertEquals(Integer.valueOf(MTU), naValueCaptured.mtu);
+    }
+
+    private void assertIpMemoryNeverStoreNetworkAttributes() { // only checks calls made so far
+        verify(mIpMemoryStore, never()).storeNetworkAttributes(any(), any(), any());
+    }
+
+    // Plays the DHCP server: 2-way handshake with rapid commit, 4-way (OFFER/REQUEST) otherwise.
+    private void performDhcpHandshake(final boolean isSuccessLease,
+            final Integer leaseTimeSec, final boolean isDhcpLeaseCacheEnabled,
+            final boolean isDhcpRapidCommitEnabled) throws Exception {
+        startIpClientProvisioning(isDhcpLeaseCacheEnabled, isDhcpRapidCommitEnabled);
+        // Answer each client packet; a REQUEST gets an ACK or, if !isSuccessLease, a NAK.
+        DhcpPacket packet;
+        while ((packet = getNextDhcpPacket()) != null) {
+            if (packet instanceof DhcpDiscoverPacket) {
+                if (isDhcpRapidCommitEnabled) {
+                    sendResponse(buildDhcpAckPacket(packet, leaseTimeSec));
+                } else {
+                    sendResponse(buildDhcpOfferPacket(packet, leaseTimeSec));
+                }
+            } else if (packet instanceof DhcpRequestPacket) {
+                final ByteBuffer byteBuffer = isSuccessLease
+                        ? buildDhcpAckPacket(packet, leaseTimeSec)
+                        : buildDhcpNakPacket(packet);
+                sendResponse(byteBuffer);
+            } else {
+                fail("invalid DHCP packet");
+            }
+            // Without rapid commit an OFFER was just sent, so loop to answer the REQUEST.
+            if (isDhcpRapidCommitEnabled || !(packet instanceof DhcpDiscoverPacket)) return;
+        }
+        fail("No DHCPREQUEST received on interface");
+    }
+
+    private DhcpPacket getNextDhcpPacket() throws ParseException { // next DHCP frame on tap
+        byte[] packet;
+        while ((packet = mPacketReader.popPacket(PACKET_TIMEOUT_MS)) != null) {
+            if (!isDhcpPacket(packet)) continue; // skip non-DHCP traffic
+            return DhcpPacket.decodeFullPacket(packet, packet.length, ENCAP_L2);
+        }
+        fail("No expected DHCP packet received on interface within timeout");
+        return null; // not reached: fail() throws
+    }
+
+    private DhcpPacket getReplyFromDhcpLease(final NetworkAttributes na, boolean timeout)
+            throws Exception { // returns the client's first DHCP packet
+        doAnswer(invocation -> { // serve na as the cached lease for TEST_L2KEY
+            if (timeout) return null; // never invoke the listener (store timeout)
+            ((OnNetworkAttributesRetrievedListener) invocation.getArgument(1))
+                    .onNetworkAttributesRetrieved(new Status(SUCCESS), TEST_L2KEY, na);
+            return null;
+        }).when(mIpMemoryStore).retrieveNetworkAttributes(eq(TEST_L2KEY), any());
+        startIpClientProvisioning(true /* isDhcpLeaseCacheEnabled */,
+                false /* isDhcpRapidCommitEnabled */);
+        return getNextDhcpPacket();
     }
 
     @Test
     public void testDhcpInit() throws Exception {
-        ProvisioningConfiguration config = new ProvisioningConfiguration.Builder()
-                .withoutIpReachabilityMonitor()
-                .build();
+        startIpClientProvisioning(false /* isDhcpLeaseCacheEnabled */,
+                false /* isDhcpRapidCommitEnabled */);
+        final DhcpPacket packet = getNextDhcpPacket();
+        assertTrue(packet instanceof DhcpDiscoverPacket); // fresh start: DISCOVER first
+    }
 
-        mIpc.startProvisioning(config);
-        verify(mCb, times(1)).setNeighborDiscoveryOffload(true);
+    @Test
+    public void testHandleSuccessDhcpLease() throws Exception {
+        final long startTime = System.currentTimeMillis();
+        performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
+                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
+        assertIpMemoryStoreNetworkAttributes(TEST_LEASE_DURATION_S, startTime);
+    }
 
-        byte[] packet;
-        while ((packet = mPacketReader.popPacket(PACKET_TIMEOUT_MS)) != null) {
-            try {
-                if (!isDhcpPacket(packet)) continue;
-                verifyDhcpDiscoverPacketReceived(packet);
-                mIpc.shutdown();
-                return;
-            } catch (DhcpPacket.ParseException e) {
-                fail("parse exception: " + e);
-            }
-        }
+    @Test
+    public void testHandleFailureDhcpLease() throws Exception {
+        performDhcpHandshake(false /* isSuccessLease */, TEST_LEASE_DURATION_S,
+                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
+        assertIpMemoryNeverStoreNetworkAttributes(); // NAK'd lease must not be cached
+    }
 
-        fail("No DHCPDISCOVER received on interface");
+    @Test
+    public void testHandleInfiniteLease() throws Exception {
+        final long startTime = System.currentTimeMillis();
+        performDhcpHandshake(true /* isSuccessLease */, INFINITE_LEASE,
+                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
+        assertIpMemoryStoreNetworkAttributes(INFINITE_LEASE, startTime);
+    }
+
+    @Test
+    public void testHandleNoLease() throws Exception {
+        final long startTime = System.currentTimeMillis();
+        performDhcpHandshake(true /* isSuccessLease */, null /* no lease time */,
+                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
+        assertIpMemoryStoreNetworkAttributes(null, startTime);
+    }
+
+    @Test
+    public void testHandleDisableInitRebootState() throws Exception {
+        performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
+                false /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
+        assertIpMemoryNeverStoreNetworkAttributes(); // cache disabled: nothing stored
+    }
+
+    @Ignore("Needs rapid commit support in the DHCP server")
+    @Test
+    public void testHandleRapidCommitOption() throws Exception {
+        // TODO: remove @Ignore after supporting rapid commit option in DHCP server
+        final long startTime = System.currentTimeMillis();
+        performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
+                true /* isDhcpLeaseCacheEnabled */, true /* isDhcpRapidCommitEnabled */);
+        assertIpMemoryStoreNetworkAttributes(TEST_LEASE_DURATION_S, startTime);
+    }
+
+    @Test
+    public void testDhcpClientStartWithCachedInfiniteLease() throws Exception {
+        final DhcpPacket packet = getReplyFromDhcpLease(
+                new NetworkAttributes.Builder()
+                    .setAssignedV4Address(CLIENT_ADDR)
+                    .setAssignedV4AddressExpiry(Long.MAX_VALUE) // lease is always valid
+                    .setMtu(Integer.valueOf(MTU))
+                    .setGroupHint(TEST_GROUPHINT)
+                    .setDnsAddresses(Collections.singletonList(SERVER_ADDR))
+                    .build(), false /* timeout */);
+        assertTrue(packet instanceof DhcpRequestPacket); // valid cached lease: REQUEST first
+    }
+
+    @Test
+    public void testDhcpClientStartWithCachedExpiredLease() throws Exception {
+        final DhcpPacket packet = getReplyFromDhcpLease(
+                new NetworkAttributes.Builder()
+                    .setAssignedV4Address(CLIENT_ADDR)
+                    .setAssignedV4AddressExpiry(EXPIRED_LEASE)
+                    .setMtu(Integer.valueOf(MTU))
+                    .setGroupHint(TEST_GROUPHINT)
+                    .setDnsAddresses(Collections.singletonList(SERVER_ADDR))
+                    .build(), false /* timeout */);
+        assertTrue(packet instanceof DhcpDiscoverPacket); // expired lease is not reused
+    }
+
+    @Test
+    public void testDhcpClientStartWithNullRetrieveNetworkAttributes() throws Exception {
+        final DhcpPacket packet = getReplyFromDhcpLease(null /* na */, false /* timeout */);
+        assertTrue(packet instanceof DhcpDiscoverPacket); // no cached attributes
+    }
+
+    @Test
+    public void testDhcpClientStartWithTimeoutRetrieveNetworkAttributes() throws Exception {
+        final DhcpPacket packet = getReplyFromDhcpLease(
+                new NetworkAttributes.Builder()
+                    .setAssignedV4Address(CLIENT_ADDR)
+                    .setAssignedV4AddressExpiry(System.currentTimeMillis() + 3_600_000)
+                    .setMtu(Integer.valueOf(MTU))
+                    .setGroupHint(TEST_GROUPHINT)
+                    .setDnsAddresses(Collections.singletonList(SERVER_ADDR))
+                    .build(), true /* timeout */);
+        assertTrue(packet instanceof DhcpDiscoverPacket); // store never called back
+    }
+
+    @Test
+    public void testDhcpClientStartWithCachedLeaseWithoutIPAddress() throws Exception {
+        final DhcpPacket packet = getReplyFromDhcpLease(
+                new NetworkAttributes.Builder()
+                    .setMtu(Integer.valueOf(MTU))
+                    .setGroupHint(TEST_GROUPHINT)
+                    .setDnsAddresses(Collections.singletonList(SERVER_ADDR))
+                    .build(), false /* timeout */);
+        assertTrue(packet instanceof DhcpDiscoverPacket); // no cached address to request
+    }
+
+    @Test
+    public void testDhcpClientRapidCommitEnabled() throws Exception {
+        startIpClientProvisioning(true /* isDhcpLeaseCacheEnabled */,
+                true /* isDhcpRapidCommitEnabled */);
+        final DhcpPacket packet = getNextDhcpPacket();
+        assertTrue(packet instanceof DhcpDiscoverPacket); // still begins with DISCOVER
     }
 }
diff --git a/tests/lib/src/com/android/testutils/MiscAsserts.kt b/tests/lib/src/com/android/testutils/MiscAsserts.kt
index 63aedd6..5019dcd 100644
--- a/tests/lib/src/com/android/testutils/MiscAsserts.kt
+++ b/tests/lib/src/com/android/testutils/MiscAsserts.kt
@@ -17,6 +17,7 @@
 package com.android.testutils
 
 import android.util.Log
+import com.android.testutils.ExceptionUtils.ThrowingRunnable
 import java.lang.reflect.Modifier
 import kotlin.system.measureTimeMillis
 import kotlin.test.assertEquals
@@ -36,10 +37,14 @@
 
 // Bridge method to help write this in Java. If you're writing Kotlin, consider using native
 // kotlin.test.assertFailsWith instead, as that method is reified and inlined.
-fun <T : Exception> assertThrows(expected: Class<T>, block: ExceptionUtils.ThrowingRunnable): T {
+fun <T : Exception> assertThrows(expected: Class<T>, block: ThrowingRunnable): T {
     return assertFailsWith(expected.kotlin) { block.run() }
 }
 
+fun <T : Exception> assertThrows(msg: String, expected: Class<T>, block: ThrowingRunnable): T {
+    return assertFailsWith(expected.kotlin, msg) { block.run() }
+}
+
 fun <T> assertEqualBothWays(o1: T, o2: T) {
     assertTrue(o1 == o2)
     assertTrue(o2 == o1)
diff --git a/tests/unit/src/android/net/dhcp/DhcpPacketTest.java b/tests/unit/src/android/net/dhcp/DhcpPacketTest.java
index 4429add..fcaf655 100644
--- a/tests/unit/src/android/net/dhcp/DhcpPacketTest.java
+++ b/tests/unit/src/android/net/dhcp/DhcpPacketTest.java
@@ -45,8 +45,8 @@
 
 import android.annotation.Nullable;
 import android.net.DhcpResults;
+import android.net.InetAddresses;
 import android.net.LinkAddress;
-import android.net.NetworkUtils;
 import android.net.metrics.DhcpErrorEvent;
 
 import androidx.test.filters.SmallTest;
@@ -86,7 +86,7 @@
     private static final byte[] CLIENT_MAC = new byte[] { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05 };
 
     private static final Inet4Address v4Address(String addrString) throws IllegalArgumentException {
-        return (Inet4Address) NetworkUtils.numericToInetAddress(addrString);
+        return (Inet4Address) InetAddresses.parseNumericAddress(addrString);
     }
 
     @Before
diff --git a/tests/unit/src/android/net/testutils/TrackRecordTest.kt b/tests/unit/src/android/net/testutils/TrackRecordTest.kt
index ff28933..f9d3558 100644
--- a/tests/unit/src/android/net/testutils/TrackRecordTest.kt
+++ b/tests/unit/src/android/net/testutils/TrackRecordTest.kt
@@ -181,7 +181,7 @@
         var delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 0)) }
         assertTrue(delay >= SHORT_TIMEOUT, "Delay $delay < $SHORT_TIMEOUT")
         delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 0) { it < 10 }) }
-        assertTrue(delay > SHORT_TIMEOUT)
+        assertTrue(delay >= SHORT_TIMEOUT)
     }
 
     @Test
@@ -209,7 +209,7 @@
         assertTrue(delay >= SHORT_TIMEOUT, "Delay $delay < $SHORT_TIMEOUT")
         // Polling for an element that doesn't match what is already there
         delay = measureTimeMillis { assertNull(record.poll(SHORT_TIMEOUT, 0) { it < 10 }) }
-        assertTrue(delay > SHORT_TIMEOUT)
+        assertTrue(delay >= SHORT_TIMEOUT)
     }
 
     // Just make sure the interpreter actually throws an exception when the spec
diff --git a/tests/unit/src/android/networkstack/util/DnsUtilsTest.kt b/tests/unit/src/android/networkstack/util/DnsUtilsTest.kt
new file mode 100644
index 0000000..815fc60
--- /dev/null
+++ b/tests/unit/src/android/networkstack/util/DnsUtilsTest.kt
@@ -0,0 +1,92 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.networkstack.util
+
+import android.net.DnsResolver
+import android.net.DnsResolver.FLAG_EMPTY
+import android.net.DnsResolver.TYPE_A
+import android.net.DnsResolver.TYPE_AAAA
+import android.net.Network
+import androidx.test.filters.SmallTest
+import androidx.test.runner.AndroidJUnit4
+import com.android.networkstack.util.DnsUtils
+import com.android.networkstack.util.DnsUtils.TYPE_ADDRCONFIG
+import com.android.server.connectivity.NetworkMonitor.DnsLogFunc
+import java.net.InetAddress
+import java.net.UnknownHostException
+import kotlin.test.assertFailsWith
+import org.junit.Assert.assertArrayEquals
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mock
+import org.mockito.MockitoAnnotations
+
+const val DEFAULT_TIMEOUT_MS = 1000
+const val SHORT_TIMEOUT_MS = 200
+
+@RunWith(AndroidJUnit4::class)
+@SmallTest
+class DnsUtilsTest {
+    val fakeNetwork: Network = Network(1234)
+    @Mock lateinit var mockLogger: DnsLogFunc
+    @Mock lateinit var mockResolver: DnsResolver
+    lateinit var fakeDns: FakeDns
+
+    @Before
+    fun setup() {
+        MockitoAnnotations.initMocks(this)
+        fakeDns = FakeDns(mockResolver).apply { startMocking() }
+    }
+
+    // Compares the host-address strings of [actual] with [expect], element by element.
+    private fun assertIpAddressArrayEquals(expect: Array<String>, actual: Array<InetAddress>) =
+            assertArrayEquals("Array of IP addresses differs", expect,
+                    actual.map { it.hostAddress }.toTypedArray())
+
+    @Test
+    fun testGetAllByNameWithTypeSuccess() {
+        // Cover each query type: AAAA only, A only, and TYPE_ADDRCONFIG with v4 and v6.
+        verifyGetAllByName("www.google.com", arrayOf("2001:db8::1"), TYPE_AAAA)
+        verifyGetAllByName("www.google.com", arrayOf("192.168.0.1"), TYPE_A)
+        verifyGetAllByName("www.android.com", arrayOf("192.168.0.2", "2001:db8::2"),
+                TYPE_ADDRCONFIG)
+    }
+
+    private fun verifyGetAllByName(name: String, expected: Array<String>, type: Int) {
+        fakeDns.setAnswer(name, expected, type)
+        val actual = DnsUtils.getAllByName(mockResolver, fakeNetwork, name, type, FLAG_EMPTY,
+                DEFAULT_TIMEOUT_MS, mockLogger)
+        assertIpAddressArrayEquals(expected, actual)
+    }
+
+    @Test
+    fun testGetAllByNameWithTypeNoResult() {
+        for (type in intArrayOf(TYPE_A, TYPE_AAAA, TYPE_ADDRCONFIG)) {
+            verifyGetAllByNameFails("www.android.com", type)
+        }
+    }
+
+    private fun verifyGetAllByNameFails(name: String, type: Int) {
+        assertFailsWith<UnknownHostException> {
+            DnsUtils.getAllByName(mockResolver, fakeNetwork, name, type, FLAG_EMPTY,
+                    SHORT_TIMEOUT_MS, mockLogger)
+        }
+    }
+
+    // TODO: Add more tests. Verify timeout, logger and error.
+}
\ No newline at end of file
diff --git a/tests/unit/src/android/networkstack/util/FakeDns.kt b/tests/unit/src/android/networkstack/util/FakeDns.kt
new file mode 100644
index 0000000..f0d44d0
--- /dev/null
+++ b/tests/unit/src/android/networkstack/util/FakeDns.kt
@@ -0,0 +1,94 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.networkstack.util
+
+import android.net.DnsResolver
+import android.net.InetAddresses
+import android.os.Looper
+import android.os.Handler
+import com.android.internal.annotations.GuardedBy
+import com.android.networkstack.util.DnsUtils.TYPE_ADDRCONFIG
+import java.net.InetAddress
+import java.util.concurrent.Executor
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.Mockito.any
+import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.doAnswer
+
+// TODO: Integrate with NetworkMonitorTest.
+class FakeDns(val mockResolver: DnsResolver) {
+    class DnsEntry(val hostname: String, val type: Int, val addresses: List<InetAddress>) {
+        fun match(host: String, type: Int) = hostname == host && this.type == type
+    }
+
+    @GuardedBy("answers")
+    val answers = ArrayList<DnsEntry>()
+
+    /** Returns the entry registered for exactly [hostname] and [type], or null if none. */
+    fun getAnswer(hostname: String, type: Int): DnsEntry? =
+            synchronized(answers) { answers.firstOrNull { it.match(hostname, type) } }
+
+    fun setAnswer(hostname: String, answer: Array<String>, type: Int) = synchronized(answers) {
+        val ans = DnsEntry(hostname, type, generateAnswer(answer))
+        // Replace the existing entry for this (hostname, type), or append a new one.
+        when (val index = answers.indexOfFirst { it.match(hostname, type) }) {
+            -1 -> answers.add(ans)
+            else -> answers[index] = ans
+        }
+    }
+
+    private fun generateAnswer(answer: Array<String>) =
+            answer.filterNotNull().map { InetAddresses.parseNumericAddress(it) }
+
+    fun startMocking() {
+        // Mock DnsResolver.query() w/o type; these lookups are keyed as TYPE_ADDRCONFIG.
+        doAnswer {
+            mockAnswer(it, posHos = 1, posType = -1, posExecutor = 3, posCallback = 5)
+        }.`when`(mockResolver).query(any() /* network */, any() /* domain */, anyInt() /* flags */,
+                any() /* executor */, any() /* cancellationSignal */, any() /*callback*/)
+        // Mock DnsResolver.query() w/ type
+        doAnswer {
+            mockAnswer(it, posHos = 1, posType = 2, posExecutor = 4, posCallback = 6)
+        }.`when`(mockResolver).query(any() /* network */, any() /* domain */, anyInt() /* nsType */,
+                anyInt() /* flags */, any() /* executor */, any() /* cancellationSignal */,
+                any() /*callback*/)
+    }
+
+    private fun mockAnswer(
+        it: InvocationOnMock,
+        posHos: Int,
+        posType: Int,
+        posExecutor: Int,
+        posCallback: Int
+    ) {
+        val hostname = it.arguments[posHos] as String
+        val executor = it.arguments[posExecutor] as Executor
+        val callback = it.arguments[posCallback] as DnsResolver.Callback<List<InetAddress>>
+        val type = if (posType != -1) it.arguments[posType] as Int else TYPE_ADDRCONFIG
+        val addresses = getAnswer(hostname, type)?.addresses
+        // No record: the callback is never invoked, so callers must rely on their own timeout.
+        if (addresses.isNullOrEmpty()) return
+        Handler(Looper.getMainLooper()).post {
+            executor.execute { callback.onAnswer(addresses, 0) }
+        }
+    }
+
+    /** Clears all entries. */
+    fun clearAll() = synchronized(answers) {
+        answers.clear()
+    }
+}
\ No newline at end of file
diff --git a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
index 8f0974d..e8bc8d2 100644
--- a/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
+++ b/tests/unit/src/com/android/server/connectivity/NetworkMonitorTest.java
@@ -17,6 +17,8 @@
 package com.android.server.connectivity;
 
 import static android.net.CaptivePortal.APP_RETURN_DISMISSED;
+import static android.net.DnsResolver.TYPE_A;
+import static android.net.DnsResolver.TYPE_AAAA;
 import static android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_DNS;
 import static android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_FALLBACK;
 import static android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_HTTP;
@@ -34,6 +36,8 @@
 import static android.net.util.NetworkStackUtils.CAPTIVE_PORTAL_OTHER_FALLBACK_URLS;
 import static android.net.util.NetworkStackUtils.CAPTIVE_PORTAL_USE_HTTPS;
 
+import static com.android.networkstack.util.DnsUtils.PRIVATE_DNS_PROBE_HOST_SUFFIX;
+
 import static junit.framework.Assert.assertEquals;
 import static junit.framework.Assert.assertFalse;
 
@@ -63,7 +67,6 @@
 import android.net.ConnectivityManager;
 import android.net.DnsResolver;
 import android.net.INetworkMonitorCallbacks;
-import android.net.InetAddresses;
 import android.net.LinkProperties;
 import android.net.Network;
 import android.net.NetworkCapabilities;
@@ -89,6 +92,7 @@
 import androidx.test.filters.SmallTest;
 import androidx.test.runner.AndroidJUnit4;
 
+import com.android.internal.util.CollectionUtils;
 import com.android.networkstack.R;
 import com.android.networkstack.metrics.DataStallDetectionStats;
 import com.android.networkstack.metrics.DataStallStatsUtils;
@@ -112,6 +116,7 @@
 import java.net.URL;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Random;
@@ -196,7 +201,23 @@
      * Network#getAllByName and by DnsResolver#query.
      */
     class FakeDns {
-        private final ArrayMap<String, List<InetAddress>> mAnswers = new ArrayMap<>();
+        /** Canned DNS answer: the addresses returned for a hostname (suffix) and query type. */
+        class DnsEntry {
+            final String mHostname;
+            final int mType;
+            final List<InetAddress> mAddresses;
+            DnsEntry(String host, int type, List<InetAddress> addr) {
+                mHostname = host;
+                mType = type;
+                mAddresses = addr;
+            }
+            // Matches when the queried hostname ends with the entry hostname and the types are
+            // equal; suffix matching lets random private DNS probe hostnames be answered.
+            private boolean matches(String hostname, int type) {
+                return hostname.endsWith(mHostname) && type == mType;
+            }
+        }
+        private final ArrayList<DnsEntry> mAnswers = new ArrayList<DnsEntry>();
         private boolean mNonBypassPrivateDnsWorking = true;
 
         /** Whether DNS queries on mNonBypassPrivateDnsWorking should succeed. */
@@ -209,41 +230,57 @@
             mAnswers.clear();
         }
 
-        /** Returns the answer for a given name on the given mock network. */
-        private synchronized List<InetAddress> getAnswer(Object mock, String hostname) {
+        /** Returns the answer for a given name and type on the given mock network, or null. */
+        private synchronized List<InetAddress> getAnswer(Object mock, String hostname, int type) {
             if (mock == mNetwork && !mNonBypassPrivateDnsWorking) {
                 return null;
             }
-            if (mAnswers.containsKey(hostname)) {
-                return mAnswers.get(hostname);
-            }
-            return mAnswers.get("*");
+
+            // First entry whose hostname is a suffix of the query and whose type matches.
+            DnsEntry entry = CollectionUtils.find(mAnswers, e -> e.matches(hostname, type));
+            return entry == null ? null : entry.mAddresses;
         }
 
-        /** Sets the answer for a given name. */
-        private synchronized void setAnswer(String hostname, String[] answer)
+        /** Sets the answer for a given name and type; a null answer is stored as empty. */
+        private synchronized void setAnswer(String hostname, String[] answer, int type)
                 throws UnknownHostException {
-            if (answer == null) {
-                mAnswers.remove(hostname);
-            } else {
-                List<InetAddress> answerList = new ArrayList<>();
-                for (String addr : answer) {
-                    answerList.add(InetAddresses.parseNumericAddress(addr));
-                }
-                mAnswers.put(hostname, answerList);
-            }
+            DnsEntry record = new DnsEntry(hostname, type, generateAnswer(answer));
+            // Remove every entry this (hostname, type) query would match, i.e. same type and a
+            // hostname that is a suffix of this one, so only the new record matches it.
+            mAnswers.removeIf(entry -> entry.matches(hostname, type));
+            mAnswers.add(record);
+        }
+
+        private List<InetAddress> generateAnswer(String[] answer) {
+            if (answer == null) return new ArrayList<>();
+            return CollectionUtils.map(Arrays.asList(answer),
+                    addr -> InetAddress.parseNumericAddress(addr));
         }
 
         /** Simulates a getAllByName call for the specified name on the specified mock network. */
         private InetAddress[] getAllByName(Object mock, String hostname)
                 throws UnknownHostException {
-            List<InetAddress> answer = getAnswer(mock, hostname);
+            List<InetAddress> answer = queryAllTypes(mock, hostname);
             if (answer == null || answer.size() == 0) {
                 throw new UnknownHostException(hostname);
             }
             return answer.toArray(new InetAddress[0]);
         }
 
+        // Emulates an untyped query: the A answers followed by the AAAA answers, if any.
+        private List<InetAddress> queryAllTypes(Object mock, String hostname) {
+            final List<InetAddress> merged = new ArrayList<>();
+            addAllIfNotNull(merged, getAnswer(mock, hostname, TYPE_A));
+            addAllIfNotNull(merged, getAnswer(mock, hostname, TYPE_AAAA));
+            return merged;
+        }
+
+        /** Appends all of {@code src} to {@code dest}; a null {@code src} is ignored. */
+        private void addAllIfNotNull(List<InetAddress> dest, List<InetAddress> src) {
+            if (src == null) return;
+            dest.addAll(src);
+        }
+
         /** Starts mocking DNS queries. */
         private void startMocking() throws UnknownHostException {
             // Queries on mNetwork using getAllByName.
@@ -254,24 +291,28 @@
             // Queries on mCleartextDnsNetwork using DnsResolver#query.
             doAnswer(invocation -> {
                 return mockQuery(invocation, 1 /* posHostname */, 3 /* posExecutor */,
-                        5 /* posCallback */);
+                        5 /* posCallback */, -1 /* posType */);
             }).when(mDnsResolver).query(any(), any(), anyInt(), any(), any(), any());
 
             // Queries on mCleartextDnsNetwork using DnsResolver#query with QueryType.
             doAnswer(invocation -> {
                 return mockQuery(invocation, 1 /* posHostname */, 4 /* posExecutor */,
-                        6 /* posCallback */);
+                        6 /* posCallback */, 2 /* posType */);
             }).when(mDnsResolver).query(any(), any(), anyInt(), anyInt(), any(), any(), any());
         }
 
         // Mocking queries on DnsResolver#query.
         private Answer mockQuery(InvocationOnMock invocation, int posHostname, int posExecutor,
-                int posCallback) {
+                int posCallback, int posType) {
             String hostname = (String) invocation.getArgument(posHostname);
             Executor executor = (Executor) invocation.getArgument(posExecutor);
             DnsResolver.Callback<List<InetAddress>> callback = invocation.getArgument(posCallback);
+            List<InetAddress> answer;
 
-            List<InetAddress> answer = getAnswer(invocation.getMock(), hostname);
+            answer = posType != -1
+                    ? getAnswer(invocation.getMock(), hostname, invocation.getArgument(posType)) :
+                    queryAllTypes(invocation.getMock(), hostname);
+
             if (answer != null && answer.size() > 0) {
                 new Handler(Looper.getMainLooper()).post(() -> {
                     executor.execute(() -> callback.onAnswer(answer, 0));
@@ -341,7 +382,13 @@
 
         mFakeDns = new FakeDns();
         mFakeDns.startMocking();
-        mFakeDns.setAnswer("*", new String[]{"2001:db8::1", "192.0.2.2"});
+        // Set answers for the private DNS probe suffix. sendPrivateDnsProbe() in NetworkMonitor
+        // sends each probe to a one-time hostname, [randomly generated UUID] + HOST_SUFFIX, so
+        // the full hostname cannot be pre-set in the answer list. Instead, register the suffix
+        // and rely on FakeDns's suffix matching to return the intended answer for whichever
+        // host is queried.
+        mFakeDns.setAnswer(PRIVATE_DNS_PROBE_HOST_SUFFIX, new String[]{"192.0.2.2"}, TYPE_A);
+        mFakeDns.setAnswer(PRIVATE_DNS_PROBE_HOST_SUFFIX, new String[]{"2001:db8::1"}, TYPE_AAAA);
 
         when(mContext.registerReceiver(any(BroadcastReceiver.class), any())).then((invocation) -> {
             mRegisteredReceivers.add(invocation.getArgument(0));
@@ -743,20 +790,39 @@
     public void testPrivateDnsSuccess() throws Exception {
         setStatus(mHttpsConnection, 204);
         setStatus(mHttpConnection, 204);
-        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::53"});
 
+        // Validation succeeds when the private DNS hostname only has an AAAA record.
+        mFakeDns.setAnswer("dns6.google", new String[]{"2001:db8::53"}, TYPE_AAAA);
         WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
-        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
+        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns6.google",
+                new InetAddress[0]));
         wnm.notifyNetworkConnected(TEST_LINK_PROPERTIES, NOT_METERED_CAPABILITIES);
         verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
                 .notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS),
                         eq(null));
+
+        // Validation succeeds when the private DNS hostname only has an A record.
+        resetCallbacks();
+        mFakeDns.setAnswer("dns4.google", new String[]{"192.0.2.1"}, TYPE_A);
+        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns4.google",
+                new InetAddress[0]));
+        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+                .notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS),
+                        eq(null));
+
+        // Validation succeeds when the private DNS hostname has both A and AAAA records.
+        resetCallbacks();
+        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::54"}, TYPE_AAAA);
+        mFakeDns.setAnswer("dns.google", new String[]{"192.0.2.3"}, TYPE_A);
+        wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.google", new InetAddress[0]));
+        verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).times(1))
+                .notifyNetworkTested(eq(VALIDATION_RESULT_VALID | NETWORK_VALIDATION_PROBE_PRIVDNS),
+                        eq(null));
     }
 
     @Test
     public void testPrivateDnsResolutionRetryUpdate() throws Exception {
-        // Set a private DNS hostname that doesn't resolve and expect validation to fail.
-        mFakeDns.setAnswer("dns.google", new String[0]);
+        // Set no record in FakeDns and expect validation to fail.
         setStatus(mHttpsConnection, 204);
         setStatus(mHttpConnection, 204);
 
@@ -770,7 +836,7 @@
 
         // Fix DNS and retry, expect validation to succeed.
         resetCallbacks();
-        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::1"});
+        mFakeDns.setAnswer("dns.google", new String[]{"2001:db8::1"}, TYPE_AAAA);
 
         wnm.forceReevaluation(Process.myUid());
         verify(mCallbacks, timeout(HANDLER_TIMEOUT_MS).atLeastOnce())
@@ -779,7 +845,8 @@
 
         // Change configuration to an invalid DNS name, expect validation to fail.
         resetCallbacks();
-        mFakeDns.setAnswer("dns.bad", new String[0]);
+        mFakeDns.setAnswer("dns.bad", new String[0], TYPE_A);
+        mFakeDns.setAnswer("dns.bad", new String[0], TYPE_AAAA);
         wnm.notifyPrivateDnsSettingsChanged(new PrivateDnsConfig("dns.bad", new InetAddress[0]));
         // Strict mode hostname resolve fail. Expect only notification for evaluation fail. No probe
         // notification.
@@ -900,30 +967,34 @@
     public void testSendDnsProbeWithTimeout() throws Exception {
         WrappedNetworkMonitor wnm = makeNotMeteredNetworkMonitor();
         final int shortTimeoutMs = 200;
-
-        // Clear the wildcard DNS response created in setUp.
-        mFakeDns.setAnswer("*", null);
-
+        // v6 only.
         String[] expected = new String[]{"2001:db8::"};
-        mFakeDns.setAnswer("www.google.com", expected);
+        mFakeDns.setAnswer("www.google.com", expected, TYPE_AAAA);
         InetAddress[] actual = wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
         assertIpAddressArrayEquals(expected, actual);
-
-        expected = new String[]{"2001:db8::", "192.0.2.1"};
-        mFakeDns.setAnswer("www.googleapis.com", expected);
+        // v4 only.
+        expected = new String[]{"192.0.2.1"};
+        mFakeDns.setAnswer("www.android.com", expected, TYPE_A);
+        actual = wnm.sendDnsProbeWithTimeout("www.android.com", shortTimeoutMs);
+        assertIpAddressArrayEquals(expected, actual);
+        // Both v4 & v6.
+        expected = new String[]{"192.0.2.1", "2001:db8::"};
+        mFakeDns.setAnswer("www.googleapis.com", new String[]{"192.0.2.1"}, TYPE_A);
+        mFakeDns.setAnswer("www.googleapis.com", new String[]{"2001:db8::"}, TYPE_AAAA);
         actual = wnm.sendDnsProbeWithTimeout("www.googleapis.com", shortTimeoutMs);
         assertIpAddressArrayEquals(expected, actual);
-
-        mFakeDns.setAnswer("www.google.com", new String[0]);
+        // Clear DNS response.
+        mFakeDns.setAnswer("www.android.com", new String[0], TYPE_A);
         try {
-            wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
+            actual = wnm.sendDnsProbeWithTimeout("www.android.com", shortTimeoutMs);
             fail("No DNS results, expected UnknownHostException");
         } catch (UnknownHostException e) {
         }
 
-        mFakeDns.setAnswer("www.google.com", null);
+        mFakeDns.setAnswer("www.android.com", null, TYPE_A);
+        mFakeDns.setAnswer("www.android.com", null, TYPE_AAAA);
         try {
-            wnm.sendDnsProbeWithTimeout("www.google.com", shortTimeoutMs);
+            wnm.sendDnsProbeWithTimeout("www.android.com", shortTimeoutMs);
             fail("DNS query timed out, expected UnknownHostException");
         } catch (UnknownHostException e) {
         }