Switch LocalSocket to android::base::{Send,Receive}FileDescriptorVector.
The previous implementation allocated an array of size
CMSG_SPACE(count) to store CMSG_LEN(count * sizeof(int)) elements, which
leads to bad things happening for values of count greater than 1 on
32-bit, and 2 on 64-bit.
Test: atest android.net.LocalSocketTest
Test: atest android.net.cts.LocalSocketTest
Change-Id: I0a9502c3358d8fa92d2d20e344c6270d6baedc07
diff --git a/core/jni/android_net_LocalSocketImpl.cpp b/core/jni/android_net_LocalSocketImpl.cpp
index a1f2377..1163b86 100644
--- a/core/jni/android_net_LocalSocketImpl.cpp
+++ b/core/jni/android_net_LocalSocketImpl.cpp
@@ -33,14 +33,16 @@
#include <unistd.h>
#include <sys/ioctl.h>
+#include <android-base/cmsg.h>
+#include <android-base/macros.h>
#include <cutils/sockets.h>
#include <netinet/tcp.h>
#include <nativehelper/ScopedUtfChars.h>
-namespace android {
+using android::base::ReceiveFileDescriptorVector;
+using android::base::SendFileDescriptorVector;
-template <typename T>
-void UNUSED(T t) {}
+namespace android {
static jfieldID field_inboundFileDescriptors;
static jfieldID field_outboundFileDescriptors;
@@ -118,67 +120,6 @@
}
/**
- * Processes ancillary data, handling only
- * SCM_RIGHTS. Creates appropriate objects and sets appropriate
- * fields in the LocalSocketImpl object. Returns 0 on success
- * or -1 if an exception was thrown.
- */
-static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
-{
- struct cmsghdr *cmsgptr;
-
- for (cmsgptr = CMSG_FIRSTHDR(pMsg);
- cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
-
- if (cmsgptr->cmsg_level != SOL_SOCKET) {
- continue;
- }
-
- if (cmsgptr->cmsg_type == SCM_RIGHTS) {
- int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
- jobjectArray fdArray;
- int count
- = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
-
- if (count < 0) {
- jniThrowException(env, "java/io/IOException",
- "invalid cmsg length");
- return -1;
- }
-
- fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
-
- if (fdArray == NULL) {
- return -1;
- }
-
- for (int i = 0; i < count; i++) {
- jobject fdObject
- = jniCreateFileDescriptor(env, pDescriptors[i]);
-
- if (env->ExceptionCheck()) {
- return -1;
- }
-
- env->SetObjectArrayElement(fdArray, i, fdObject);
-
- if (env->ExceptionCheck()) {
- return -1;
- }
- }
-
- env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
-
- if (env->ExceptionCheck()) {
- return -1;
- }
- }
- }
-
- return 0;
-}
-
-/**
* Reads data from a socket into buf, processing any ancillary data
* and adding it to thisJ.
*
@@ -189,47 +130,48 @@
void *buffer, size_t len)
{
ssize_t ret;
- struct msghdr msg;
- struct iovec iv;
- unsigned char *buf = (unsigned char *)buffer;
- // Enough buffer for a pile of fd's. We throw an exception if
- // this buffer is too small.
- struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
+ std::vector<android::base::unique_fd> received_fds;
- memset(&msg, 0, sizeof(msg));
- memset(&iv, 0, sizeof(iv));
-
- iv.iov_base = buf;
- iv.iov_len = len;
-
- msg.msg_iov = &iv;
- msg.msg_iovlen = 1;
- msg.msg_control = cmsgbuf;
- msg.msg_controllen = sizeof(cmsgbuf);
-
- ret = TEMP_FAILURE_RETRY(recvmsg(fd, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
-
- if (ret < 0 && errno == EPIPE) {
- // Treat this as an end of stream
- return 0;
- }
+ ret = ReceiveFileDescriptorVector(fd, buffer, len, 64, &received_fds);
if (ret < 0) {
+ if (errno == EPIPE) {
+ // Treat this as an end of stream
+ return 0;
+ }
+
jniThrowIOException(env, errno);
return -1;
}
- if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
- // To us, any of the above flags are a fatal error
+ if (received_fds.size() > 0) {
+ jobjectArray fdArray = env->NewObjectArray(received_fds.size(), class_FileDescriptor, NULL);
- jniThrowException(env, "java/io/IOException",
- "Unexpected error or truncation during recvmsg()");
+ if (fdArray == NULL) {
+ // NewObjectArray has thrown.
+ return -1;
+ }
- return -1;
- }
+ for (size_t i = 0; i < received_fds.size(); i++) {
+ jobject fdObject = jniCreateFileDescriptor(env, received_fds[i].get());
- if (ret >= 0) {
- socket_process_cmsg(env, thisJ, &msg);
+ if (env->ExceptionCheck()) {
+ return -1;
+ }
+
+ env->SetObjectArrayElement(fdArray, i, fdObject);
+
+ if (env->ExceptionCheck()) {
+ return -1;
+ }
+ }
+
+ for (auto &fd : received_fds) {
+ // The fds are stored in java.io.FileDescriptors now.
+ static_cast<void>(fd.release());
+ }
+
+ env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
}
return ret;
@@ -243,7 +185,6 @@
static int socket_write_all(JNIEnv *env, jobject object, int fd,
void *buf, size_t len)
{
- ssize_t ret;
struct msghdr msg;
unsigned char *buffer = (unsigned char *)buf;
memset(&msg, 0, sizeof(msg));
@@ -256,14 +197,11 @@
return -1;
}
- struct cmsghdr *cmsg;
int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
- int fds[countFds];
- char msgbuf[CMSG_SPACE(countFds)];
+ std::vector<int> fds;
// Add any pending outbound file descriptors to the message
if (outboundFds != NULL) {
-
if (env->ExceptionCheck()) {
return -1;
}
@@ -274,47 +212,25 @@
return -1;
}
- fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
+ fds.push_back(jniGetFDFromFileDescriptor(env, fdObject));
if (env->ExceptionCheck()) {
return -1;
}
}
-
- // See "man cmsg" really
- msg.msg_control = msgbuf;
- msg.msg_controllen = sizeof msgbuf;
- cmsg = CMSG_FIRSTHDR(&msg);
- cmsg->cmsg_level = SOL_SOCKET;
- cmsg->cmsg_type = SCM_RIGHTS;
- cmsg->cmsg_len = CMSG_LEN(sizeof fds);
- memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
}
- // We only write our msg_control during the first write
- while (len > 0) {
- struct iovec iv;
- memset(&iv, 0, sizeof(iv));
+ ssize_t rc = SendFileDescriptorVector(fd, buffer, len, fds);
- iv.iov_base = buffer;
- iv.iov_len = len;
-
- msg.msg_iov = &iv;
- msg.msg_iovlen = 1;
-
- do {
- ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
- } while (ret < 0 && errno == EINTR);
-
- if (ret < 0) {
+ while (rc != len) {
+ if (rc == -1) {
jniThrowIOException(env, errno);
return -1;
}
- buffer += ret;
- len -= ret;
+ buffer += rc;
+ len -= rc;
- // Wipes out any msg_control too
- memset(&msg, 0, sizeof(msg));
+ rc = send(fd, buffer, len, MSG_NOSIGNAL);
}
return 0;