Restore the default interface MTU when disconnecting from Wi-Fi AP.

Bug: 113350007
Test: atest FrameworksNetTests NetworkStackTests
Test: atest NetworkStackIntegrationTests
Test: manual test

Change-Id: I709a504885033a330b946de402a261d341f78117
diff --git a/src/android/net/ip/IpClient.java b/src/android/net/ip/IpClient.java
index 24f985a..45bdff3 100644
--- a/src/android/net/ip/IpClient.java
+++ b/src/android/net/ip/IpClient.java
@@ -48,6 +48,7 @@
 import android.os.IBinder;
 import android.os.Message;
 import android.os.RemoteException;
+import android.os.ServiceSpecificException;
 import android.os.SystemClock;
 import android.text.TextUtils;
 import android.util.LocalLog;
@@ -850,10 +851,12 @@
         return shouldLog;
     }
 
+    private void logError(String fmt, Throwable e, Object... args) {
+        mLog.e(String.format(fmt, args), e);
+    }
+
     private void logError(String fmt, Object... args) {
-        final String msg = "ERROR " + String.format(fmt, args);
-        Log.e(mTag, msg);
-        mLog.log(msg);
+        logError(fmt, null, args);
     }
 
     // This needs to be called with care to ensure that our LinkProperties
@@ -1274,6 +1277,28 @@
         // TODO : implement this
     }
 
+    private void maybeRestoreInterfaceMtu() {
+        InterfaceParams params = mDependencies.getInterfaceParams(mInterfaceName);
+        if (params == null) {
+            Log.w(mTag, "interface: " + mInterfaceName + " is gone");
+            return;
+        }
+
+        if (params.index != mInterfaceParams.index) {
+            Log.w(mTag, "interface: " + mInterfaceName + " has a different index: " + params.index);
+            return;
+        }
+
+        if (params.defaultMtu != mInterfaceParams.defaultMtu) {
+            try {
+                mNetd.interfaceSetMtu(mInterfaceName, mInterfaceParams.defaultMtu);
+            } catch (RemoteException | ServiceSpecificException e) {
+                logError("Couldn't reset MTU on " + mInterfaceName + " from "
+                        + params.defaultMtu + " to " + mInterfaceParams.defaultMtu, e);
+            }
+        }
+    }
+
     class StoppedState extends State {
         @Override
         public void enter() {
@@ -1351,6 +1376,9 @@
                 // There's no DHCPv4 for which to wait; proceed to stopped.
                 deferMessage(obtainMessage(CMD_JUMP_STOPPING_TO_STOPPED));
             }
+
+            // Restore the interface MTU to initial value if it has changed.
+            maybeRestoreInterfaceMtu();
         }
 
         @Override
diff --git a/src/com/android/server/util/NetworkStackConstants.java b/src/com/android/server/util/NetworkStackConstants.java
index 804765e..3174a9b 100644
--- a/src/com/android/server/util/NetworkStackConstants.java
+++ b/src/com/android/server/util/NetworkStackConstants.java
@@ -54,6 +54,7 @@
     public static final int ETHER_TYPE_IPV4 = 0x0800;
     public static final int ETHER_TYPE_IPV6 = 0x86dd;
     public static final int ETHER_HEADER_LEN = 14;
+    public static final int ETHER_MTU = 1500;
 
     /**
      * ARP constants.
@@ -97,6 +98,7 @@
     public static final int IPV6_PROTOCOL_OFFSET = 6;
     public static final int IPV6_SRC_ADDR_OFFSET = 8;
     public static final int IPV6_DST_ADDR_OFFSET = 24;
+    public static final int IPV6_MIN_MTU = 1280;
 
     /**
      * ICMPv6 constants.
diff --git a/tests/integration/Android.bp b/tests/integration/Android.bp
index ec8257f..ec16467 100644
--- a/tests/integration/Android.bp
+++ b/tests/integration/Android.bp
@@ -23,6 +23,7 @@
         "androidx.annotation_annotation",
         "androidx.test.rules",
         "mockito-target-extended-minus-junit4",
+        "net-tests-utils",
         "NetworkStackBase",
         "testables",
     ],
diff --git a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
index 16e92ef..400bd51 100644
--- a/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
+++ b/tests/integration/src/android/net/ip/IpClientIntegrationTest.java
@@ -34,6 +34,7 @@
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.timeout;
@@ -77,6 +78,7 @@
 import com.android.server.NetworkObserverRegistry;
 import com.android.server.NetworkStackService.NetworkStackServiceManager;
 import com.android.server.connectivity.ipmemorystore.IpMemoryStoreService;
+import com.android.testutils.HandlerUtilsKt;
 
 import org.junit.After;
 import org.junit.Before;
@@ -91,6 +93,7 @@
 import java.io.FileOutputStream;
 import java.io.IOException;
 import java.net.Inet4Address;
+import java.net.NetworkInterface;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.Collections;
@@ -112,7 +115,7 @@
 
     @Mock private Context mContext;
     @Mock private ConnectivityManager mCm;
-    @Mock private INetd mNetd;
+    @Mock private INetd mMockNetd;
     @Mock private Resources mResources;
     @Mock private IIpClientCallbacks mCb;
     @Mock private AlarmManager mAlarm;
@@ -122,6 +125,7 @@
     @Mock private IpMemoryStoreService mIpMemoryStoreService;
 
     private String mIfaceName;
+    private INetd mNetd;
     private HandlerThread mPacketReaderThread;
     private TapPacketReader mPacketReader;
     private IpClient mIpc;
@@ -158,7 +162,8 @@
     private static final Inet4Address BROADCAST_ADDR = getBroadcastAddress(
             SERVER_ADDR, PREFIX_LENGTH);
     private static final String HOSTNAME = "testhostname";
-    private static final short MTU = 1500;
+    private static final int TEST_DEFAULT_MTU = 1500;
+    private static final int TEST_MIN_MTU = 1280;
 
     private static class TapPacketReader extends PacketReader {
         private final ParcelFileDescriptor mTapFd;
@@ -214,7 +219,7 @@
 
         @Override
         public INetd getNetd(Context context) {
-            return mNetd;
+            return mMockNetd;
         }
 
         @Override
@@ -287,7 +292,6 @@
             inst.getUiAutomation().dropShellPermissionIdentity();
         }
         mIfaceName = iface.getInterfaceName();
-
         mPacketReaderThread = new HandlerThread(IpClientIntegrationTest.class.getSimpleName());
         mPacketReaderThread.start();
 
@@ -300,12 +304,12 @@
         final Instrumentation inst = InstrumentationRegistry.getInstrumentation();
         final IBinder netdIBinder =
                 (IBinder) inst.getContext().getSystemService(Context.NETD_SERVICE);
-        final INetd netd = INetd.Stub.asInterface(netdIBinder);
+        mNetd = INetd.Stub.asInterface(netdIBinder);
         when(mContext.getSystemService(eq(Context.NETD_SERVICE))).thenReturn(netdIBinder);
-        assertNotNull(netd);
+        assertNotNull(mNetd);
 
         final NetworkObserverRegistry reg = new NetworkObserverRegistry();
-        reg.register(netd);
+        reg.register(mNetd);
         mIpc = new IpClient(mContext, mIfaceName, mCb, reg, mNetworkStackServiceManager,
                 mDependencies);
     }
@@ -345,7 +349,7 @@
     }
 
     private static ByteBuffer buildDhcpOfferPacket(final DhcpPacket packet,
-            final Integer leaseTimeSec) {
+            final Integer leaseTimeSec, final short mtu) {
         return DhcpPacket.buildOfferPacket(DhcpPacket.ENCAP_L2, packet.getTransactionId(),
                 false /* broadcast */, SERVER_ADDR, INADDR_ANY /* relayIp */,
                 CLIENT_ADDR /* yourIp */, packet.getClientMac(), leaseTimeSec,
@@ -353,11 +357,11 @@
                 Collections.singletonList(SERVER_ADDR) /* gateways */,
                 Collections.singletonList(SERVER_ADDR) /* dnsServers */,
                 SERVER_ADDR /* dhcpServerIdentifier */, null /* domainName */, HOSTNAME,
-                false /* metered */, MTU);
+                false /* metered */, mtu);
     }
 
     private static ByteBuffer buildDhcpAckPacket(final DhcpPacket packet,
-            final Integer leaseTimeSec) {
+            final Integer leaseTimeSec, final short mtu) {
         return DhcpPacket.buildAckPacket(DhcpPacket.ENCAP_L2, packet.getTransactionId(),
                 false /* broadcast */, SERVER_ADDR, INADDR_ANY /* relayIp */,
                 CLIENT_ADDR /* yourIp */, CLIENT_ADDR /* requestIp */, packet.getClientMac(),
@@ -365,7 +369,7 @@
                 Collections.singletonList(SERVER_ADDR) /* gateways */,
                 Collections.singletonList(SERVER_ADDR) /* dnsServers */,
                 SERVER_ADDR /* dhcpServerIdentifier */, null /* domainName */, HOSTNAME,
-                false /* metered */, MTU);
+                false /* metered */, mtu);
     }
 
     private static ByteBuffer buildDhcpNakPacket(final DhcpPacket packet) {
@@ -397,7 +401,7 @@
     }
 
     private void assertIpMemoryStoreNetworkAttributes(final Integer leaseTimeSec,
-            final long startTime) {
+            final long startTime, final int mtu) {
         final ArgumentCaptor<NetworkAttributes> networkAttributes =
                 ArgumentCaptor.forClass(NetworkAttributes.class);
 
@@ -416,7 +420,7 @@
             assertTrue(lowerBound < expiry);
         }
         assertEquals(Collections.singletonList(SERVER_ADDR), naValueCaptured.dnsAddresses);
-        assertEquals(new Integer((int) MTU), naValueCaptured.mtu);
+        assertEquals(new Integer(mtu), naValueCaptured.mtu);
     }
 
     private void assertIpMemoryNeverStoreNetworkAttributes() {
@@ -426,20 +430,20 @@
     // Helper method to complete DHCP 2-way or 4-way handshake
     private void performDhcpHandshake(final boolean isSuccessLease,
             final Integer leaseTimeSec, final boolean isDhcpLeaseCacheEnabled,
-            final boolean isDhcpRapidCommitEnabled) throws Exception {
+            final boolean isDhcpRapidCommitEnabled, final int mtu) throws Exception {
         startIpClientProvisioning(isDhcpLeaseCacheEnabled, isDhcpRapidCommitEnabled);
 
         DhcpPacket packet;
         while ((packet = getNextDhcpPacket()) != null) {
             if (packet instanceof DhcpDiscoverPacket) {
                 if (isDhcpRapidCommitEnabled) {
-                    sendResponse(buildDhcpAckPacket(packet, leaseTimeSec));
+                    sendResponse(buildDhcpAckPacket(packet, leaseTimeSec, (short) mtu));
                 } else {
-                    sendResponse(buildDhcpOfferPacket(packet, leaseTimeSec));
+                    sendResponse(buildDhcpOfferPacket(packet, leaseTimeSec, (short) mtu));
                 }
             } else if (packet instanceof DhcpRequestPacket) {
                 final ByteBuffer byteBuffer = isSuccessLease
-                        ? buildDhcpAckPacket(packet, leaseTimeSec)
+                        ? buildDhcpAckPacket(packet, leaseTimeSec, (short) mtu)
                         : buildDhcpNakPacket(packet);
                 sendResponse(byteBuffer);
             } else {
@@ -474,6 +478,28 @@
         return getNextDhcpPacket();
     }
 
+    private void doRestoreInitialMtuTest(final boolean shouldChangeMtu) throws Exception {
+        final long currentTime = System.currentTimeMillis();
+        int mtu = TEST_DEFAULT_MTU;
+
+        if (shouldChangeMtu) mtu = TEST_MIN_MTU;
+        performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
+                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */, mtu);
+        assertIpMemoryStoreNetworkAttributes(TEST_LEASE_DURATION_S, currentTime, mtu);
+
+        if (shouldChangeMtu) {
+            // Pretend that ConnectivityService set the MTU.
+            mNetd.interfaceSetMtu(mIfaceName, mtu);
+            assertEquals(NetworkInterface.getByName(mIfaceName).getMTU(), mtu);
+        }
+
+        mIpc.shutdown();
+        HandlerUtilsKt.waitForIdle(mIpc.getHandler(), TEST_TIMEOUT_MS);
+        // Verify that MTU indeed has been restored or not.
+        verify(mMockNetd, times(shouldChangeMtu ? 1 : 0)).interfaceSetMtu(mIfaceName,
+                TEST_DEFAULT_MTU);
+    }
+
     @Test
     public void testDhcpInit() throws Exception {
         startIpClientProvisioning(false /* isDhcpLeaseCacheEnabled */,
@@ -486,14 +512,16 @@
     public void testHandleSuccessDhcpLease() throws Exception {
         final long currentTime = System.currentTimeMillis();
         performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
-                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
-        assertIpMemoryStoreNetworkAttributes(TEST_LEASE_DURATION_S, currentTime);
+                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */,
+                TEST_DEFAULT_MTU);
+        assertIpMemoryStoreNetworkAttributes(TEST_LEASE_DURATION_S, currentTime, TEST_DEFAULT_MTU);
     }
 
     @Test
     public void testHandleFailureDhcpLease() throws Exception {
         performDhcpHandshake(false /* isSuccessLease */, TEST_LEASE_DURATION_S,
-                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
+                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */,
+                TEST_DEFAULT_MTU);
         assertIpMemoryNeverStoreNetworkAttributes();
     }
 
@@ -501,22 +529,25 @@
     public void testHandleInfiniteLease() throws Exception {
         final long currentTime = System.currentTimeMillis();
         performDhcpHandshake(true /* isSuccessLease */, INFINITE_LEASE,
-                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
-        assertIpMemoryStoreNetworkAttributes(INFINITE_LEASE, currentTime);
+                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */,
+                TEST_DEFAULT_MTU);
+        assertIpMemoryStoreNetworkAttributes(INFINITE_LEASE, currentTime, TEST_DEFAULT_MTU);
     }
 
     @Test
     public void testHandleNoLease() throws Exception {
         final long currentTime = System.currentTimeMillis();
         performDhcpHandshake(true /* isSuccessLease */, null /* no lease time */,
-                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
-        assertIpMemoryStoreNetworkAttributes(null, currentTime);
+                true /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */,
+                TEST_DEFAULT_MTU);
+        assertIpMemoryStoreNetworkAttributes(null, currentTime, TEST_DEFAULT_MTU);
     }
 
     @Test
     public void testHandleDisableInitRebootState() throws Exception {
         performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
-                false /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */);
+                false /* isDhcpLeaseCacheEnabled */, false /* isDhcpRapidCommitEnabled */,
+                TEST_DEFAULT_MTU);
         assertIpMemoryNeverStoreNetworkAttributes();
     }
 
@@ -526,8 +557,9 @@
         // TODO: remove @Ignore after supporting rapid commit option in DHCP server
         final long currentTime = System.currentTimeMillis();
         performDhcpHandshake(true /* isSuccessLease */, TEST_LEASE_DURATION_S,
-                true /* isDhcpLeaseCacheEnabled */, true /* isDhcpRapidCommitEnabled */);
-        assertIpMemoryStoreNetworkAttributes(TEST_LEASE_DURATION_S, currentTime);
+                true /* isDhcpLeaseCacheEnabled */, true /* isDhcpRapidCommitEnabled */,
+                TEST_DEFAULT_MTU);
+        assertIpMemoryStoreNetworkAttributes(TEST_LEASE_DURATION_S, currentTime, TEST_DEFAULT_MTU);
     }
 
     @Test
@@ -536,7 +568,7 @@
                 new NetworkAttributes.Builder()
                     .setAssignedV4Address(CLIENT_ADDR)
                     .setAssignedV4AddressExpiry(Long.MAX_VALUE) // lease is always valid
-                    .setMtu(new Integer(MTU))
+                    .setMtu(new Integer(TEST_DEFAULT_MTU))
                     .setGroupHint(TEST_GROUPHINT)
                     .setDnsAddresses(Collections.singletonList(SERVER_ADDR))
                     .build(), false /* timeout */);
@@ -549,7 +581,7 @@
                  new NetworkAttributes.Builder()
                     .setAssignedV4Address(CLIENT_ADDR)
                     .setAssignedV4AddressExpiry(EXPIRED_LEASE)
-                    .setMtu(new Integer(MTU))
+                    .setMtu(new Integer(TEST_DEFAULT_MTU))
                     .setGroupHint(TEST_GROUPHINT)
                     .setDnsAddresses(Collections.singletonList(SERVER_ADDR))
                     .build(), false /* timeout */);
@@ -568,7 +600,7 @@
                 new NetworkAttributes.Builder()
                     .setAssignedV4Address(CLIENT_ADDR)
                     .setAssignedV4AddressExpiry(System.currentTimeMillis() + 3_600_000)
-                    .setMtu(new Integer(MTU))
+                    .setMtu(new Integer(TEST_DEFAULT_MTU))
                     .setGroupHint(TEST_GROUPHINT)
                     .setDnsAddresses(Collections.singletonList(SERVER_ADDR))
                     .build(), true /* timeout */);
@@ -579,7 +611,7 @@
     public void testDhcpClientStartWithCachedLeaseWithoutIPAddress() throws Exception {
         final DhcpPacket packet = getReplyFromDhcpLease(
                 new NetworkAttributes.Builder()
-                    .setMtu(new Integer(MTU))
+                    .setMtu(new Integer(TEST_DEFAULT_MTU))
                     .setGroupHint(TEST_GROUPHINT)
                     .setDnsAddresses(Collections.singletonList(SERVER_ADDR))
                     .build(), false /* timeout */);
@@ -593,4 +625,23 @@
         final DhcpPacket packet = getNextDhcpPacket();
         assertTrue(DhcpDiscoverPacket.class.isInstance(packet));
     }
+
+    @Test
+    public void testRestoreInitialInterfaceMtu() throws Exception {
+        doRestoreInitialMtuTest(true /* shouldChangeMtu */);
+    }
+
+    @Test
+    public void testRestoreInitialInterfaceMtuWithoutChange() throws Exception {
+        doRestoreInitialMtuTest(false /* shouldChangeMtu */);
+    }
+
+    @Test
+    public void testRestoreInitialInterfaceMtuWithException() throws Exception {
+        doThrow(new RemoteException("NetdNativeService::interfaceSetMtu")).when(mMockNetd)
+                .interfaceSetMtu(mIfaceName, TEST_DEFAULT_MTU);
+
+        doRestoreInitialMtuTest(true /* shouldChangeMtu */);
+        assertEquals(NetworkInterface.getByName(mIfaceName).getMTU(), TEST_MIN_MTU);
+    }
 }