Implement ActivityManager to LMKD reconnect functionality

In case of LMKD crashing and restarting or for testing purposes when LMKD
is killed and restarted on purpose, ActivityManager needs to reconnect and
reset the oom_adj levels so that LMKD can function. Implement a separate
LmkdConnection class handling LMKD connection and communication. We use
FileDescriptorEventListener to detect LMKD connection loss and retry logic
is implemented to reconnect and reset LMKD parameters.

Bug: 124618999
Test: kill lmkd and verify successful reconnect

Change-Id: I1b63c960aa9580cdc4f8d3697cd02fa9c18a2935
Merged-In: I1b63c960aa9580cdc4f8d3697cd02fa9c18a2935
Signed-off-by: Suren Baghdasaryan <surenb@google.com>
diff --git a/services/core/java/com/android/server/am/LmkdConnection.java b/services/core/java/com/android/server/am/LmkdConnection.java
new file mode 100644
index 0000000..0a2f608
--- /dev/null
+++ b/services/core/java/com/android/server/am/LmkdConnection.java
@@ -0,0 +1,291 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.am;
+
+import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_ERROR;
+import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_INPUT;
+
+import static com.android.server.am.ActivityManagerDebugConfig.TAG_AM;
+import static com.android.server.am.ActivityManagerDebugConfig.TAG_WITH_CLASS_NAME;
+
+import android.net.LocalSocket;
+import android.net.LocalSocketAddress;
+import android.os.MessageQueue;
+import android.util.Slog;
+
+import com.android.internal.annotations.GuardedBy;
+
+import libcore.io.IoUtils;
+
+import java.io.FileDescriptor;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * Lmkd connection to communicate with lowmemorykiller daemon.
+ */
+public class LmkdConnection {
+    private static final String TAG = TAG_WITH_CLASS_NAME ? "LmkdConnection" : TAG_AM;
+
+    // lmkd reply max size in bytes
+    private static final int LMKD_REPLY_MAX_SIZE = 8;
+
+    // connection listener interface
+    interface LmkdConnectionListener {
+        boolean onConnect(OutputStream ostream);
+        void onDisconnect();
+        /**
+         * Check if received reply was expected (reply to an earlier request)
+         *
+         * @param replyBuf The buffer provided in exchange() to receive the reply.
+         *                 It can be used by exchange() caller to store reply-specific
+         *                 tags for later use in isReplyExpected() to verify if
+         *                 received packet is the expected reply.
+         * @param dataReceived The buffer holding received data
+         * @param receivedLen Size of the data received
+         */
+        boolean isReplyExpected(ByteBuffer replyBuf, ByteBuffer dataReceived,
+                int receivedLen);
+    }
+
+    private final MessageQueue mMsgQueue;
+
+    // lmkd connection listener
+    private final LmkdConnectionListener mListener;
+
+    // mutex to synchronize access to the socket
+    private final Object mLmkdSocketLock = new Object();
+
+    // socket to communicate with lmkd
+    @GuardedBy("mLmkdSocketLock")
+    private LocalSocket mLmkdSocket = null;
+
+    // socket I/O streams
+    @GuardedBy("mLmkdSocketLock")
+    private OutputStream mLmkdOutputStream = null;
+    @GuardedBy("mLmkdSocketLock")
+    private InputStream mLmkdInputStream = null;
+
+    // buffer to store incoming data
+    private final ByteBuffer mInputBuf =
+            ByteBuffer.allocate(LMKD_REPLY_MAX_SIZE);
+
+    // object to protect mReplyBuf and to wait/notify when reply is received
+    private final Object mReplyBufLock = new Object();
+
+    // reply buffer
+    @GuardedBy("mReplyBufLock")
+    private ByteBuffer mReplyBuf = null;
+
+    ////////////////////  END FIELDS  ////////////////////
+
+    LmkdConnection(MessageQueue msgQueue, LmkdConnectionListener listener) {
+        mMsgQueue = msgQueue;
+        mListener = listener;
+    }
+
+    boolean connect() {
+        synchronized (mLmkdSocketLock) {
+            if (mLmkdSocket != null) {
+                return true;
+            }
+            // temporary sockets and I/O streams
+            final LocalSocket socket = openSocket();
+
+            if (socket == null) {
+                Slog.w(TAG, "Failed to connect to lowmemorykiller, retry later");
+                return false;
+            }
+
+            final OutputStream ostream;
+            final InputStream istream;
+            try {
+                ostream = socket.getOutputStream();
+                istream = socket.getInputStream();
+            } catch (IOException ex) {
+                IoUtils.closeQuietly(socket);
+                return false;
+            }
+            // execute onConnect callback
+            if (mListener != null && !mListener.onConnect(ostream)) {
+                Slog.w(TAG, "Failed to communicate with lowmemorykiller, retry later");
+                IoUtils.closeQuietly(socket);
+                return false;
+            }
+            // connection established
+            mLmkdSocket = socket;
+            mLmkdOutputStream = ostream;
+            mLmkdInputStream = istream;
+            mMsgQueue.addOnFileDescriptorEventListener(mLmkdSocket.getFileDescriptor(),
+                    EVENT_INPUT | EVENT_ERROR,
+                    new MessageQueue.OnFileDescriptorEventListener() {
+                        public int onFileDescriptorEvents(FileDescriptor fd, int events) {
+                            return fileDescriptorEventHandler(fd, events);
+                        }
+                    }
+            );
+            mLmkdSocketLock.notifyAll();
+        }
+        return true;
+    }
+
+    private int fileDescriptorEventHandler(FileDescriptor fd, int events) {
+        if (mListener == null) {
+            return 0;
+        }
+        if ((events & EVENT_INPUT) != 0) {
+            processIncomingData();
+        }
+        if ((events & EVENT_ERROR) != 0) {
+            synchronized (mLmkdSocketLock) {
+                // stop listening on this socket
+                mMsgQueue.removeOnFileDescriptorEventListener(
+                        mLmkdSocket.getFileDescriptor());
+                IoUtils.closeQuietly(mLmkdSocket);
+                mLmkdSocket = null;
+            }
+            // wake up reply waiters if any
+            synchronized (mReplyBufLock) {
+                if (mReplyBuf != null) {
+                    mReplyBuf = null;
+                    mReplyBufLock.notifyAll();
+                }
+            }
+            // notify listener
+            mListener.onDisconnect();
+            return 0;
+        }
+        return (EVENT_INPUT | EVENT_ERROR);
+    }
+
+    private void processIncomingData() {
+        int len = read(mInputBuf);
+        if (len > 0) {
+            synchronized (mReplyBufLock) {
+                if (mReplyBuf != null) {
+                    if (mListener.isReplyExpected(mReplyBuf, mInputBuf, len)) {
+                        // copy into reply buffer
+                        mReplyBuf.put(mInputBuf.array(), 0, len);
+                        mReplyBuf.rewind();
+                        // wakeup the waiting thread
+                        mReplyBufLock.notifyAll();
+                    } else {
+                        // received asynchronous or unexpected packet
+                        // treat this as an error
+                        mReplyBuf = null;
+                        mReplyBufLock.notifyAll();
+                        Slog.e(TAG, "Received unexpected packet from lmkd");
+                    }
+                } else {
+                    // received asynchronous communication from lmkd
+                    // we don't support this yet
+                    Slog.w(TAG, "Received an asynchronous packet from lmkd");
+                }
+            }
+        }
+    }
+
+    boolean isConnected() {
+        synchronized (mLmkdSocketLock) {
+            return (mLmkdSocket != null);
+        }
+    }
+
+    boolean waitForConnection(long timeoutMs) {
+        synchronized (mLmkdSocketLock) {
+            if (mLmkdSocket != null) {
+                return true;
+            }
+            try {
+                mLmkdSocketLock.wait(timeoutMs);
+                return (mLmkdSocket != null);
+            } catch (InterruptedException e) {
+                return false;
+            }
+        }
+    }
+
+    private LocalSocket openSocket() {
+        final LocalSocket socket;
+
+        try {
+            socket = new LocalSocket(LocalSocket.SOCKET_SEQPACKET);
+            socket.connect(
+                    new LocalSocketAddress("lmkd",
+                        LocalSocketAddress.Namespace.RESERVED));
+        } catch (IOException ex) {
+            Slog.e(TAG, "Connection failed: " + ex.toString());
+            return null;
+        }
+        return socket;
+    }
+
+    private boolean write(ByteBuffer buf) {
+        synchronized (mLmkdSocketLock) {
+            try {
+                mLmkdOutputStream.write(buf.array(), 0, buf.position());
+            } catch (IOException ex) {
+                return false;
+            }
+            return true;
+        }
+    }
+
+    private int read(ByteBuffer buf) {
+        synchronized (mLmkdSocketLock) {
+            try {
+                return mLmkdInputStream.read(buf.array(), 0, buf.array().length);
+            } catch (IOException ex) {
+            }
+            return -1;
+        }
+    }
+
+    /**
+     * Exchange a request/reply packets with lmkd
+     *
+     * @param req The buffer holding the request data to be sent
+     * @param repl The buffer to receive the reply
+     */
+    public boolean exchange(ByteBuffer req, ByteBuffer repl) {
+        if (repl == null) {
+            return write(req);
+        }
+
+        boolean result = false;
+        // set reply buffer to user-defined one to fill it
+        synchronized (mReplyBufLock) {
+            mReplyBuf = repl;
+
+            if (write(req)) {
+                try {
+                    // wait for the reply
+                    mReplyBufLock.wait();
+                    result = (mReplyBuf != null);
+                } catch (InterruptedException ie) {
+                    result = false;
+                }
+            }
+
+            // reset reply buffer
+            mReplyBuf = null;
+        }
+        return result;
+    }
+}
diff --git a/services/core/java/com/android/server/am/ProcessList.java b/services/core/java/com/android/server/am/ProcessList.java
index f397a32..197829a 100644
--- a/services/core/java/com/android/server/am/ProcessList.java
+++ b/services/core/java/com/android/server/am/ProcessList.java
@@ -57,8 +57,6 @@
 import android.content.pm.IPackageManager;
 import android.content.res.Resources;
 import android.graphics.Point;
-import android.net.LocalSocket;
-import android.net.LocalSocketAddress;
 import android.os.AppZygote;
 import android.os.Binder;
 import android.os.Build;
@@ -103,11 +101,8 @@
 
 import dalvik.system.VMRuntime;
 
-import libcore.io.IoUtils;
-
 import java.io.File;
 import java.io.IOException;
-import java.io.InputStream;
 import java.io.OutputStream;
 import java.io.PrintWriter;
 import java.nio.ByteBuffer;
@@ -118,11 +113,6 @@
 
 /**
  * Activity manager code dealing with processes.
- *
- * Method naming convention:
- * <ul>
- * <li> Methods suffixed with "LS" should be called within the {@link #sLmkdSocketLock} lock.
- * </ul>
  */
 public final class ProcessList {
     static final String TAG = TAG_WITH_CLASS_NAME ? "ProcessList" : TAG_AM;
@@ -269,6 +259,9 @@
     static final byte LMK_PROCPURGE = 3;
     static final byte LMK_GETKILLCNT = 4;
 
+    // lmkd reconnect delay in msecs
+    private static final long LMDK_RECONNECT_DELAY_MS = 1000;
+
     ActivityManagerService mService = null;
 
     // To kill process groups asynchronously
@@ -303,16 +296,9 @@
 
     private boolean mHaveDisplaySize;
 
-    private static Object sLmkdSocketLock = new Object();
+    private static LmkdConnection sLmkdConnection = null;
 
-    @GuardedBy("sLmkdSocketLock")
-    private static LocalSocket sLmkdSocket;
-
-    @GuardedBy("sLmkdSocketLock")
-    private static OutputStream sLmkdOutputStream;
-
-    @GuardedBy("sLmkdSocketLock")
-    private static InputStream sLmkdInputStream;
+    private boolean mOomLevelsSet = false;
 
     /**
      * Temporary to avoid allocations.  Protected by main lock.
@@ -539,6 +525,7 @@
 
     final class KillHandler extends Handler {
         static final int KILL_PROCESS_GROUP_MSG = 4000;
+        static final int LMDK_RECONNECT_MSG = 4001;
 
         public KillHandler(Looper looper) {
             super(looper, null, true);
@@ -552,6 +539,15 @@
                     Process.killProcessGroup(msg.arg1 /* uid */, msg.arg2 /* pid */);
                     Trace.traceEnd(Trace.TRACE_TAG_ACTIVITY_MANAGER);
                     break;
+                case LMDK_RECONNECT_MSG:
+                    if (!sLmkdConnection.connect()) {
+                        Slog.i(TAG, "Failed to connect to lmkd, retry after "
+                                + LMDK_RECONNECT_DELAY_MS + " ms");
+                        // retry after LMDK_RECONNECT_DELAY_MS
+                        sKillHandler.sendMessageDelayed(sKillHandler.obtainMessage(
+                                KillHandler.LMDK_RECONNECT_MSG), LMDK_RECONNECT_DELAY_MS);
+                    }
+                    break;
 
                 default:
                     super.handleMessage(msg);
@@ -579,6 +575,30 @@
                     THREAD_PRIORITY_BACKGROUND, true /* allowIo */);
             sKillThread.start();
             sKillHandler = new KillHandler(sKillThread.getLooper());
+            sLmkdConnection = new LmkdConnection(sKillThread.getLooper().getQueue(),
+                    new LmkdConnection.LmkdConnectionListener() {
+                        @Override
+                        public boolean onConnect(OutputStream ostream) {
+                            Slog.i(TAG, "Connection with lmkd established");
+                            return onLmkdConnect(ostream);
+                        }
+                        @Override
+                        public void onDisconnect() {
+                            Slog.w(TAG, "Lost connection to lmkd");
+                            // start reconnection after delay to let lmkd restart
+                            sKillHandler.sendMessageDelayed(sKillHandler.obtainMessage(
+                                    KillHandler.LMDK_RECONNECT_MSG), LMDK_RECONNECT_DELAY_MS);
+                        }
+                        @Override
+                        public boolean isReplyExpected(ByteBuffer replyBuf,
+                                ByteBuffer dataReceived, int receivedLen) {
+                            // compare the preambule (currently one integer) to check if
+                            // this is the reply packet we are waiting for
+                            return (receivedLen == replyBuf.array().length
+                                    && dataReceived.getInt(0) == replyBuf.getInt(0));
+                        }
+                    }
+            );
         }
     }
 
@@ -684,6 +704,7 @@
 
             writeLmkd(buf, null);
             SystemProperties.set("sys.sysctl.extra_free_kbytes", Integer.toString(reserve));
+            mOomLevelsSet = true;
         }
         // GB: 2048,3072,4096,6144,7168,8192
         // HC: 8192,10240,12288,14336,16384,20480
@@ -1223,93 +1244,50 @@
         buf.putInt(LMK_GETKILLCNT);
         buf.putInt(min_oom_adj);
         buf.putInt(max_oom_adj);
-        if (writeLmkd(buf, repl)) {
-            int i = repl.getInt();
-            if (i != LMK_GETKILLCNT) {
-                Slog.e("ActivityManager", "Failed to get kill count, code mismatch");
-                return null;
-            }
+        // indicate what we are waiting for
+        repl.putInt(LMK_GETKILLCNT);
+        repl.rewind();
+        if (writeLmkd(buf, repl) && repl.getInt() == LMK_GETKILLCNT) {
             return new Integer(repl.getInt());
         }
         return null;
     }
 
-    @GuardedBy("sLmkdSocketLock")
-    private static boolean openLmkdSocketLS() {
+    boolean onLmkdConnect(OutputStream ostream) {
         try {
-            sLmkdSocket = new LocalSocket(LocalSocket.SOCKET_SEQPACKET);
-            sLmkdSocket.connect(
-                new LocalSocketAddress("lmkd",
-                        LocalSocketAddress.Namespace.RESERVED));
-            sLmkdOutputStream = sLmkdSocket.getOutputStream();
-            sLmkdInputStream = sLmkdSocket.getInputStream();
-        } catch (IOException ex) {
-            Slog.w(TAG, "lowmemorykiller daemon socket open failed");
-            sLmkdSocket = null;
-            return false;
-        }
-
-        return true;
-    }
-
-    // Never call directly, use writeLmkd() instead
-    @GuardedBy("sLmkdSocketLock")
-    private static boolean writeLmkdCommandLS(ByteBuffer buf) {
-        try {
-            sLmkdOutputStream.write(buf.array(), 0, buf.position());
-        } catch (IOException ex) {
-            Slog.w(TAG, "Error writing to lowmemorykiller socket");
-            IoUtils.closeQuietly(sLmkdSocket);
-            sLmkdSocket = null;
-            return false;
-        }
-        return true;
-    }
-
-    // Never call directly, use writeLmkd() instead
-    @GuardedBy("sLmkdSocketLock")
-    private static boolean readLmkdReplyLS(ByteBuffer buf) {
-        int len;
-        try {
-            len = sLmkdInputStream.read(buf.array(), 0, buf.array().length);
-            if (len == buf.array().length) {
-                return true;
+            // Purge any previously registered pids
+            ByteBuffer buf = ByteBuffer.allocate(4);
+            buf.putInt(LMK_PROCPURGE);
+            ostream.write(buf.array(), 0, buf.position());
+            if (mOomLevelsSet) {
+                // Reset oom_adj levels
+                buf = ByteBuffer.allocate(4 * (2 * mOomAdj.length + 1));
+                buf.putInt(LMK_TARGET);
+                for (int i = 0; i < mOomAdj.length; i++) {
+                    buf.putInt((mOomMinFree[i] * 1024) / PAGE_SIZE);
+                    buf.putInt(mOomAdj[i]);
+                }
+                ostream.write(buf.array(), 0, buf.position());
             }
         } catch (IOException ex) {
-            Slog.w(TAG, "Error reading from lowmemorykiller socket");
+            return false;
         }
-
-        IoUtils.closeQuietly(sLmkdSocket);
-        sLmkdSocket = null;
-        return false;
+        return true;
     }
 
     private static boolean writeLmkd(ByteBuffer buf, ByteBuffer repl) {
-        synchronized (sLmkdSocketLock) {
-            for (int i = 0; i < 3; i++) {
-                if (sLmkdSocket == null) {
-                    if (openLmkdSocketLS() == false) {
-                        try {
-                            Thread.sleep(1000);
-                        } catch (InterruptedException ie) {
-                        }
-                        continue;
-                    }
+        if (!sLmkdConnection.isConnected()) {
+            // try to connect immediately and then keep retrying
+            sKillHandler.sendMessage(
+                    sKillHandler.obtainMessage(KillHandler.LMDK_RECONNECT_MSG));
 
-                    // Purge any previously registered pids
-                    ByteBuffer purge_buf = ByteBuffer.allocate(4);
-                    purge_buf.putInt(LMK_PROCPURGE);
-                    if (writeLmkdCommandLS(purge_buf) == false) {
-                        // Write failed, skip the rest and retry
-                        continue;
-                    }
-                }
-                if (writeLmkdCommandLS(buf) && (repl == null || readLmkdReplyLS(repl))) {
-                    return true;
-                }
+            // wait for connection retrying 3 times (up to 3 seconds)
+            if (!sLmkdConnection.waitForConnection(3 * LMDK_RECONNECT_DELAY_MS)) {
+                return false;
             }
         }
-        return false;
+
+        return sLmkdConnection.exchange(buf, repl);
     }
 
     static void killProcessGroup(int uid, int pid) {