Merge "libbinder: RPC binder - incl. protocol version"
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index 200d923..62ea187 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -110,6 +110,10 @@
     return mMaxThreads;
 }
 
+void RpcServer::setProtocolVersion(uint32_t version) { // caps the negotiated version
+    mProtocolVersion = version; // NOTE(review): not under mLock; presumably set before serving
+}
+
 void RpcServer::setRootObject(const sp<IBinder>& binder) {
     std::lock_guard<std::mutex> _l(mLock);
     mRootObjectWeak = mRootObject = binder;
@@ -245,13 +249,37 @@
     RpcConnectionHeader header;
     status_t status = server->mShutdownTrigger->interruptableReadFully(clientFd.get(), &header,
                                                                        sizeof(header));
-    bool idValid = status == OK;
-    if (!idValid) {
+    if (status != OK) {
         ALOGE("Failed to read ID for client connecting to RPC server: %s",
               statusToString(status).c_str());
         // still need to cleanup before we can return
     }
-    bool incoming = header.options & RPC_CONNECTION_OPTION_INCOMING;
+
+    bool incoming = false;
+    uint32_t protocolVersion = 0;
+    RpcAddress sessionId = RpcAddress::zero();
+    bool requestingNewSession = false;
+
+    if (status == OK) { // header fields are only meaningful if the read succeeded
+        incoming = header.options & RPC_CONNECTION_OPTION_INCOMING;
+        protocolVersion = std::min(header.version, // client's max, capped by this server
+                                   server->mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION));
+        sessionId = RpcAddress::fromRawEmbedded(&header.sessionId);
+        requestingNewSession = sessionId.isZero(); // zero id: client asks for a new session
+
+        if (requestingNewSession) { // NOTE(review): also runs for INCOMING, rejected later
+            RpcNewSessionResponse response{ // tells the client the negotiated version
+                    .version = protocolVersion,
+            };
+
+            status = server->mShutdownTrigger->interruptableWriteFully(clientFd.get(), &response,
+                                                                       sizeof(response));
+            if (status != OK) {
+                ALOGE("Failed to send new session response: %s", statusToString(status).c_str());
+                // still need to cleanup before we can return
+            }
+        }
+    }
 
     std::thread thisThread;
     sp<RpcSession> session;
@@ -269,19 +297,16 @@
         };
         server->mConnectingThreads.erase(threadId);
 
-        if (!idValid || server->mShutdownTrigger->isTriggered()) {
+        if (status != OK || server->mShutdownTrigger->isTriggered()) {
             return;
         }
 
-        RpcAddress sessionId = RpcAddress::fromRawEmbedded(&header.sessionId);
-
-        if (sessionId.isZero()) {
+        if (requestingNewSession) { // header carried a zero session id
             if (incoming) {
                 ALOGE("Cannot create a new session with an incoming connection, would leak");
                 return;
             }
 
-            sessionId = RpcAddress::zero();
             size_t tries = 0;
             do {
                 // don't block if there is some entropy issue
@@ -295,6 +320,7 @@
 
             session = RpcSession::make();
             session->setMaxThreads(server->mMaxThreads);
+            if (!session->setProtocolVersion(protocolVersion)) return; // unknown version
             if (!session->setForServer(server,
                                        sp<RpcServer::EventListener>::fromExisting(
                                                static_cast<RpcServer::EventListener*>(
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 1c37651..90ce4d6 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -77,6 +77,25 @@
     return mMaxThreads;
 }
 
+bool RpcSession::setProtocolVersion(uint32_t version) { // false if version is unknown
+    if (version >= RPC_WIRE_PROTOCOL_VERSION_NEXT && // NEXT == 0 today: only EXPERIMENTAL passes
+        version != RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL) {
+        ALOGE("Cannot start RPC session with version %u which is unknown (current protocol version "
+              "is %u).",
+              version, RPC_WIRE_PROTOCOL_VERSION);
+        return false;
+    }
+
+    std::lock_guard<std::mutex> _l(mMutex);
+    mProtocolVersion = version;
+    return true;
+}
+
+std::optional<uint32_t> RpcSession::getProtocolVersion() { // presumably empty until set
+    std::lock_guard<std::mutex> _l(mMutex);
+    return mProtocolVersion;
+}
+
 bool RpcSession::setupUnixDomainClient(const char* path) {
     return setupSocketClient(UnixSocketAddress(path));
 }
@@ -424,6 +443,22 @@
 
     if (!setupOneSocketConnection(addr, RpcAddress::zero(), false /*incoming*/)) return false;
 
+    // The connection above carried a zero session id (a new-session request), so the server
+    // replies on it with an RpcNewSessionResponse holding the negotiated protocol version.
+    {
+        ExclusiveConnection connection;
+        status_t status = ExclusiveConnection::find(sp<RpcSession>::fromExisting(this),
+                                                    ConnectionUse::CLIENT, &connection);
+        if (status != OK) return false;
+
+        // readNewSessionResponse only writes 'version' on success, so check the status first.
+        uint32_t version = 0;
+        status = state()->readNewSessionResponse(connection.get(),
+                                                 sp<RpcSession>::fromExisting(this), &version);
+        if (status != OK) return false;
+        if (!setProtocolVersion(version)) return false;
+    }
+
     // TODO(b/189955605): we should add additional sessions dynamically
     // instead of all at once.
     // TODO(b/186470974): first risk of blocking
@@ -484,7 +519,12 @@
             return false;
         }
 
-        RpcConnectionHeader header{.options = 0};
+        // Advertise the highest version we accept; the server negotiates min(this, its cap).
+        // NOTE(review): mProtocolVersion is read without mMutex here; confirm this cannot race.
+        RpcConnectionHeader header{
+                .version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION),
+                .options = 0,
+        };
         memcpy(&header.sessionId, &id.viewRawEmbedded(), sizeof(RpcWireAddress));
 
         if (incoming) header.options |= RPC_CONNECTION_OPTION_INCOMING;
diff --git a/libs/binder/RpcState.cpp b/libs/binder/RpcState.cpp
index 332c75f..f3406bb 100644
--- a/libs/binder/RpcState.cpp
+++ b/libs/binder/RpcState.cpp
@@ -315,6 +315,18 @@
     return OK;
 }
 
+status_t RpcState::readNewSessionResponse(const sp<RpcSession::RpcConnection>& connection,
+                                          const sp<RpcSession>& session, uint32_t* version) {
+    RpcNewSessionResponse response; // filled by rpcRec
+    if (status_t status =
+                rpcRec(connection, session, "new session response", &response, sizeof(response));
+        status != OK) {
+        return status;
+    }
+    *version = response.version; // only written on success; callers must check the status
+    return OK;
+}
+
 status_t RpcState::sendConnectionInit(const sp<RpcSession::RpcConnection>& connection,
                                       const sp<RpcSession>& session) {
     RpcOutgoingConnectionInit init{
index 5ac0b97..1446eec 100644
--- a/libs/binder/RpcState.h
+++ b/libs/binder/RpcState.h
@@ -60,6 +60,10 @@
     RpcState();
     ~RpcState();
 
+    // Client side of a new session: reads the server's RpcNewSessionResponse. '*version' is
+    // only written when OK is returned.
+    status_t readNewSessionResponse(const sp<RpcSession::RpcConnection>& connection,
+                                    const sp<RpcSession>& session, uint32_t* version);
     status_t sendConnectionInit(const sp<RpcSession::RpcConnection>& connection,
                                 const sp<RpcSession>& session);
     status_t readConnectionInit(const sp<RpcSession::RpcConnection>& connection,
diff --git a/libs/binder/RpcWireFormat.h b/libs/binder/RpcWireFormat.h
index 2a44c7a..0f8efd2 100644
--- a/libs/binder/RpcWireFormat.h
+++ b/libs/binder/RpcWireFormat.h
@@ -37,9 +37,20 @@
  * either as part of a new session or an existing session
  */
 struct RpcConnectionHeader {
+    uint32_t version; // maximum supported by caller
+    uint8_t reserved0[4];
     RpcWireAddress sessionId;
     uint8_t options;
-    uint8_t reserved[7];
+    uint8_t reserved1[7];
+};
+
+/**
+ * In response to an RpcConnectionHeader which corresponds to a new session,
+ * the server sends this back to the client.
+ */
+struct RpcNewSessionResponse {
+    uint32_t version; // min(caller's maximum, callee's maximum): version to use
+    uint8_t reserved[4];
 };
 
 #define RPC_CONNECTION_INIT_OKAY "cci"
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index a8094dd..40ff78c 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -105,6 +105,13 @@
     size_t getMaxThreads();
 
     /**
+     * By default, the newest protocol version supported by both the client and
+     * this server is used. Setting a version here caps the negotiated version,
+     * so newer protocol versions are never used. Expected to be useful for testing.
+     */
+    void setProtocolVersion(uint32_t version); // not validated until a session uses it
+
+    /**
      * The root object can be retrieved by any client, without any
      * authentication. TODO(b/183988761)
      *
@@ -164,6 +171,7 @@
 
     bool mAgreedExperimental = false;
     size_t mMaxThreads = 1;
+    std::optional<uint32_t> mProtocolVersion; // if unset, RPC_WIRE_PROTOCOL_VERSION is the cap
     base::unique_fd mServer; // socket we are accepting sessions on
 
     std::mutex mLock; // for below
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index 2101df8..1f7c029 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -37,6 +37,10 @@
 class RpcSocketAddress;
 class RpcState;
 
+constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION_NEXT = 0; // versions >= NEXT are rejected...
+constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL = 0xF0000000; // ...except this one
+constexpr uint32_t RPC_WIRE_PROTOCOL_VERSION = RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL; // default
+
 /**
  * This represents a session (group of connections) between a client
  * and a server. Multiple connections are needed for multiple parallel "binder"
@@ -60,6 +64,13 @@
     size_t getMaxThreads();
 
     /**
+     * By default, the minimum of the supported versions of the client and the
+     * server will be used. Usually, this API should only be used for debugging.
+     */
+    [[nodiscard]] bool setProtocolVersion(uint32_t version); // false (no change) if unknown
+    std::optional<uint32_t> getProtocolVersion(); // presumably nullopt until set/negotiated
+
+    /**
      * This should be called once per thread, matching 'join' in the remote
      * process.
      */
@@ -291,6 +302,7 @@
     std::mutex mMutex; // for all below
 
     size_t mMaxThreads = 0;
+    std::optional<uint32_t> mProtocolVersion; // unset: RPC_WIRE_PROTOCOL_VERSION is advertised
 
     std::condition_variable mAvailableConnectionCv; // for mWaitingThreads
     size_t mWaitingThreads = 0;
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 40ebd9c..d5786bc 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -47,6 +47,12 @@
 
 namespace android {
 
+// The current version must be the one just before NEXT, unless it is the experimental version
+// (which setProtocolVersion accepts even though it is >= NEXT).
+static_assert(RPC_WIRE_PROTOCOL_VERSION + 1 == RPC_WIRE_PROTOCOL_VERSION_NEXT ||
+              RPC_WIRE_PROTOCOL_VERSION == RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL,
+              "RPC_WIRE_PROTOCOL_VERSION must be NEXT - 1 or EXPERIMENTAL");
+
 TEST(BinderRpcParcel, EntireParcelFormatted) {
     Parcel p;
     p.writeInt32(3);
@@ -67,6 +73,25 @@
     ASSERT_EQ(sinkFd, retrieved.get());
 }
 
+// Versions at or beyond NEXT are not defined yet, so they must be rejected and must not
+// change the session's protocol version.
+TEST(BinderRpc, CannotUseNextWireVersion) {
+    auto session = RpcSession::make();
+    EXPECT_FALSE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_NEXT));
+    EXPECT_FALSE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_NEXT + 1));
+    EXPECT_FALSE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_NEXT + 2));
+    EXPECT_FALSE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_NEXT + 15));
+    EXPECT_FALSE(session->getProtocolVersion().has_value());
+}
+
+// The experimental version is accepted even though it is numerically >= NEXT.
+TEST(BinderRpc, CanUseExperimentalWireVersion) {
+    auto session = RpcSession::make();
+    EXPECT_TRUE(session->setProtocolVersion(RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL));
+    ASSERT_TRUE(session->getProtocolVersion().has_value());
+    EXPECT_EQ(RPC_WIRE_PROTOCOL_VERSION_EXPERIMENTAL, *session->getProtocolVersion());
+}
+
 using android::binder::Status;
 
 #define EXPECT_OK(status)                 \