Merge "Support RDNS on cache and uid/pid tagging"
am: effead990e

Change-Id: I0328380764f1e375886c4661f43708d59db7a370
diff --git a/DnsProxyListener.cpp b/DnsProxyListener.cpp
index ce6081d..74eb793 100644
--- a/DnsProxyListener.cpp
+++ b/DnsProxyListener.cpp
@@ -147,7 +147,7 @@
     return false;
 }
 
-void maybeFixupNetContext(android_net_context* ctx) {
+void maybeFixupNetContext(android_net_context* ctx, pid_t pid) {
     if (requestingUseLocalNameservers(ctx->flags) && !hasPermissionToBypassPrivateDns(ctx->uid)) {
         // Not permitted; clear the flag.
         ctx->flags &= ~NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS;
@@ -161,6 +161,7 @@
             ctx->flags |= NET_CONTEXT_FLAG_USE_DNS_OVER_TLS | NET_CONTEXT_FLAG_USE_EDNS;
         }
     }
+    ctx->pid = pid;
 }
 
 void addIpAddrWithinLimit(std::vector<std::string>* ip_addrs, const sockaddr* addr,
@@ -669,7 +670,7 @@
 
     addrinfo* result = nullptr;
     Stopwatch s;
-    maybeFixupNetContext(&mNetContext);
+    maybeFixupNetContext(&mNetContext, mClient->getPid());
     const uid_t uid = mClient->getUid();
     int32_t rv = 0;
     NetworkDnsEventReported event;
@@ -856,7 +857,7 @@
                << mNetContext.dns_mark << " " << mNetContext.uid << " " << mNetContext.flags << "}";
 
     Stopwatch s;
-    maybeFixupNetContext(&mNetContext);
+    maybeFixupNetContext(&mNetContext, mClient->getPid());
 
     // Decode
     std::vector<uint8_t> msg(MAXPACKET, 0);
@@ -1073,7 +1074,7 @@
 
 void DnsProxyListener::GetHostByNameHandler::run() {
     Stopwatch s;
-    maybeFixupNetContext(&mNetContext);
+    maybeFixupNetContext(&mNetContext, mClient->getPid());
     const uid_t uid = mClient->getUid();
     hostent* hp = nullptr;
     hostent hbuf;
@@ -1236,7 +1237,7 @@
 
 void DnsProxyListener::GetHostByAddrHandler::run() {
     Stopwatch s;
-    maybeFixupNetContext(&mNetContext);
+    maybeFixupNetContext(&mNetContext, mClient->getPid());
     const uid_t uid = mClient->getUid();
     hostent* hp = nullptr;
     hostent hbuf;
diff --git a/DnsTlsSocket.cpp b/DnsTlsSocket.cpp
index bb43ed9..6a221fe 100644
--- a/DnsTlsSocket.cpp
+++ b/DnsTlsSocket.cpp
@@ -38,6 +38,7 @@
 #include <netdutils/SocketOption.h>
 #include <netdutils/ThreadUtil.h>
 
+#include "netd_resolv/resolv.h"
 #include "private/android_filesystem_config.h"  // AID_DNS
 #include "resolv_private.h"
 
@@ -95,7 +96,7 @@
         return Status(errno);
     }
 
-    resolv_tag_socket(mSslFd.get(), AID_DNS);
+    resolv_tag_socket(mSslFd.get(), AID_DNS, NET_CONTEXT_INVALID_PID);
 
     const socklen_t len = sizeof(mMark);
     if (setsockopt(mSslFd.get(), SOL_SOCKET, SO_MARK, &mMark, len) == -1) {
diff --git a/getaddrinfo.cpp b/getaddrinfo.cpp
index 46137de..a9550b0 100644
--- a/getaddrinfo.cpp
+++ b/getaddrinfo.cpp
@@ -266,6 +266,7 @@
             .dns_netid = NETID_UNSET,
             .dns_mark = MARK_UNSET,
             .uid = NET_CONTEXT_INVALID_UID,
+            .pid = NET_CONTEXT_INVALID_PID,
     };
     NetworkDnsEventReported event;
     return android_getaddrinfofornetcontext(hostname, servname, &hints, &netcontext, result,
diff --git a/include/netd_resolv/resolv.h b/include/netd_resolv/resolv.h
index afd63f5..22eadeb 100644
--- a/include/netd_resolv/resolv.h
+++ b/include/netd_resolv/resolv.h
@@ -30,6 +30,7 @@
 
 #include "params.h"
 
+#include <arpa/nameser.h>
 #include <netinet/in.h>
 
 /*
@@ -43,6 +44,9 @@
  */
 #define MARK_UNSET 0u
 
+#define NET_CONTEXT_INVALID_UID ((uid_t)-1)
+#define NET_CONTEXT_INVALID_PID ((pid_t)-1)
+
 /*
  * A struct to capture context relevant to network operations.
  *
@@ -59,11 +63,12 @@
     unsigned app_mark;
     unsigned dns_netid;
     unsigned dns_mark;
-    uid_t uid;
+    uid_t uid = NET_CONTEXT_INVALID_UID;
     unsigned flags;
+    // pid of the application that sent the DNS query; defaults to NET_CONTEXT_INVALID_PID.
+    pid_t pid = NET_CONTEXT_INVALID_PID;
 };
 
-#define NET_CONTEXT_INVALID_UID ((uid_t) -1)
 #define NET_CONTEXT_FLAG_USE_LOCAL_NAMESERVERS 0x00000001
 #define NET_CONTEXT_FLAG_USE_EDNS 0x00000002
 #define NET_CONTEXT_FLAG_USE_DNS_OVER_TLS 0x00000004
@@ -74,7 +79,7 @@
 typedef void (*get_network_context_callback)(unsigned netid, uid_t uid,
                                              android_net_context* netcontext);
 typedef void (*log_callback)(const char* msg);
-typedef int (*tagSocketCallback)(int sockFd, uint32_t tag, uid_t uid);
+typedef int (*tagSocketCallback)(int sockFd, uint32_t tag, uid_t uid, pid_t pid);
 
 /*
  * Some functions needed by the resolver (e.g. checkCallingPermission()) live in
@@ -95,3 +100,10 @@
 
 // Set callbacks and bring DnsResolver up.
 LIBNETD_RESOLV_PUBLIC bool resolv_init(const ResolverNetdCallbacks* callbacks);
+
+// Performs a reverse-DNS (RDNS) lookup in the local resolver cache of |netId|. |domain_name_size|
+// is the size of the |domain_name| buffer; NS_MAXDNAME is recommended. Returns false if no
+// hostname is found for |ip_address| or if |domain_name_size| > NS_MAXDNAME.
+LIBNETD_RESOLV_PUBLIC bool resolv_gethostbyaddr_from_local_cache(unsigned netId, char domain_name[],
+                                                                 unsigned domain_name_size,
+                                                                 char* ip_address);
diff --git a/include/netd_resolv/resolv_stub.h b/include/netd_resolv/resolv_stub.h
index 41f56fa..d1939d1 100644
--- a/include/netd_resolv/resolv_stub.h
+++ b/include/netd_resolv/resolv_stub.h
@@ -37,6 +37,9 @@
     bool (*resolv_has_nameservers)(unsigned netid);
 
     bool (*resolv_init)(const ResolverNetdCallbacks& callbacks);
+
+    bool (*resolv_gethostbyaddr_from_local_cache)(unsigned netId, char domain_name[],
+                                                  unsigned domain_name_size, char* ip_address);
 } RESOLV_STUB;
 
 int resolv_stub_init();
diff --git a/libnetd_resolv.map.txt b/libnetd_resolv.map.txt
index be193db..4b0fb95 100644
--- a/libnetd_resolv.map.txt
+++ b/libnetd_resolv.map.txt
@@ -22,6 +22,7 @@
   global:
     resolv_has_nameservers;
     resolv_init;
+    resolv_gethostbyaddr_from_local_cache;
   local:
     *;
 };
diff --git a/res_cache.cpp b/res_cache.cpp
index 56b15e7..25ccea8 100644
--- a/res_cache.cpp
+++ b/res_cache.cpp
@@ -1289,6 +1289,109 @@
     return 0;
 }
 
+// Copies the first non-empty question name of the parsed message |handle| into |out|.
+// Returns false if no question parses, or if every name would be truncated by |out_size|:
+// a truncated name is not a real hostname, so it is never reported.
+static bool copy_first_question_name(ns_msg* handle, char* out, size_t out_size) {
+    const int question_count = ns_msg_count(*handle, ns_s_qd);
+    for (int i = 0; i < question_count; i++) {
+        ns_rr question = {};
+        if (ns_parserr(handle, ns_s_qd, i, &question) != 0) {
+            continue;
+        }
+        if (strlcpy(out, ns_rr_name(question), out_size) >= out_size) {
+            continue;  // Truncated.
+        }
+        if (out[0] != '\0') {
+            return true;
+        }
+    }
+    return false;
+}
+
+// Reverse-DNS lookup in the resolver cache of |netid|.
+//
+// Walks the cache entries of |netid| in MRU-list order (starting at mru_list.mru_next) looking
+// for an A or AAAA answer record whose address equals |ip_address| (an IPv4 or IPv6 literal).
+// On the first match, copies that cached message's question name into |domain_name|.
+//
+// Returns false, leaving |domain_name| unspecified, when:
+//  - |domain_name| or |ip_address| is null, or |domain_name_size| is 0 or > NS_MAXDNAME;
+//  - |ip_address| is not a valid IPv4/IPv6 literal;
+//  - |netid| has no cache, or no cached A/AAAA record matches;
+//  - no matching message yields a non-empty, untruncated question name.
+bool resolv_gethostbyaddr_from_local_cache(unsigned netid, char domain_name[],
+                                           unsigned domain_name_size, char* ip_address) {
+    if (domain_name_size > NS_MAXDNAME) {
+        LOG(ERROR) << __func__ << ": invalid domain_name_size " << domain_name_size;
+        return false;
+    }
+    if (domain_name == nullptr || domain_name_size == 0 || ip_address == nullptr) {
+        LOG(ERROR) << __func__ << ": invalid argument";
+        return false;
+    }
+
+    // Parse the target once, outside the lock, and compare raw RDATA below. This avoids an
+    // inet_ntop() per record and lets non-canonical spellings (e.g. upper-case IPv6) match.
+    in_addr addr4 = {};
+    in6_addr addr6 = {};
+    ns_type wanted_type;
+    const void* wanted_rdata;
+    size_t wanted_rdlen;
+    if (inet_pton(AF_INET, ip_address, &addr4) == 1) {
+        wanted_type = ns_t_a;
+        wanted_rdata = &addr4;
+        wanted_rdlen = sizeof(addr4);
+    } else if (inet_pton(AF_INET6, ip_address, &addr6) == 1) {
+        wanted_type = ns_t_aaaa;
+        wanted_rdata = &addr6;
+        wanted_rdlen = sizeof(addr6);
+    } else {
+        return false;
+    }
+
+    std::lock_guard guard(cache_mutex);
+
+    Cache* cache = find_named_cache_locked(netid);
+    if (cache == nullptr) {
+        return false;
+    }
+
+    for (Entry* node = cache->mru_list.mru_next; node != nullptr && node != &cache->mru_list;
+         node = node->mru_next) {
+        if (node->answer == nullptr) {
+            continue;
+        }
+
+        ns_msg handle = {};
+        if (ns_initparse(node->answer, node->answerlen, &handle) < 0) {
+            continue;
+        }
+
+        const int answer_count = ns_msg_count(handle, ns_s_an);
+        for (int n = 0; n < answer_count; n++) {
+            ns_rr rr = {};
+            if (ns_parserr(&handle, ns_s_an, n, &rr) != 0) {
+                continue;
+            }
+            // Check RDLENGTH before memcmp() so a short/malformed record cannot be over-read.
+            if (ns_rr_type(rr) != wanted_type ||
+                static_cast<size_t>(ns_rr_rdlen(rr)) != wanted_rdlen ||
+                memcmp(ns_rr_rdata(rr), wanted_rdata, wanted_rdlen) != 0) {
+                continue;
+            }
+            // Every answer in a message shares its question section, so another matching
+            // answer in this same message cannot yield a different name.
+            if (copy_first_question_name(&handle, domain_name, domain_name_size)) {
+                return true;
+            }
+            break;
+        }
+    }
+
+    return false;
+}
+
 // Head of the list of caches.
 static struct resolv_cache_info res_cache_list GUARDED_BY(cache_mutex);
 
diff --git a/res_init.cpp b/res_init.cpp
index a1619d9..e15127f 100644
--- a/res_init.cpp
+++ b/res_init.cpp
@@ -98,6 +98,7 @@
 
     statp->netid = netcontext->dns_netid;
     statp->uid = netcontext->uid;
+    statp->pid = netcontext->pid;
     statp->id = arc4random_uniform(65536);
     statp->_mark = netcontext->dns_mark;
     statp->netcontext_flags = netcontext->flags;
diff --git a/res_send.cpp b/res_send.cpp
index 1dc5f57..2fba22c 100644
--- a/res_send.cpp
+++ b/res_send.cpp
@@ -715,7 +715,7 @@
                     return -1;
             }
         }
-        resolv_tag_socket(statp->_vcsock, statp->uid);
+        resolv_tag_socket(statp->_vcsock, statp->uid, statp->pid);
         if (statp->_mark != MARK_UNSET) {
             if (setsockopt(statp->_vcsock, SOL_SOCKET, SO_MARK, &statp->_mark,
                            sizeof(statp->_mark)) < 0) {
@@ -958,7 +958,7 @@
             }
         }
 
-        resolv_tag_socket(statp->nssocks[ns], statp->uid);
+        resolv_tag_socket(statp->nssocks[ns], statp->uid, statp->pid);
         if (statp->_mark != MARK_UNSET) {
             if (setsockopt(statp->nssocks[ns], SOL_SOCKET, SO_MARK, &(statp->_mark),
                            sizeof(statp->_mark)) < 0) {
diff --git a/resolv_private.h b/resolv_private.h
index ea7cdc2..1babd84 100644
--- a/resolv_private.h
+++ b/resolv_private.h
@@ -90,6 +90,7 @@
 struct ResState {
     unsigned netid;                           // NetId: cache key and socket mark
     uid_t uid;                                // uid of the app that sent the DNS lookup
+    pid_t pid;                                // pid of the app that sent the DNS lookup
     int nscount;                              // number of name srvers
     uint16_t id;                              // current message id
     std::vector<std::string> search_domains;  // domains to search
@@ -182,9 +183,9 @@
 
 android::net::IpVersion ipFamilyToIPVersion(int ipFamily);
 
-inline void resolv_tag_socket(int sock, uid_t uid) {
+inline void resolv_tag_socket(int sock, uid_t uid, pid_t pid) {
     if (android::net::gResNetdCallbacks.tagSocket != nullptr) {
-        if (int err = android::net::gResNetdCallbacks.tagSocket(sock, TAG_SYSTEM_DNS, uid)) {
+        if (int err = android::net::gResNetdCallbacks.tagSocket(sock, TAG_SYSTEM_DNS, uid, pid)) {
             LOG(WARNING) << "Failed to tag socket: " << strerror(-err);
         }
     }