Merge "Support variety of IPv6 addresses on RDNS-on-cache"
diff --git a/include/netd_resolv/resolv.h b/include/netd_resolv/resolv.h
index 22eadeb..6e00df6 100644
--- a/include/netd_resolv/resolv.h
+++ b/include/netd_resolv/resolv.h
@@ -104,6 +104,11 @@
-// Function that performs RDNS in local cache. The |domain_name_size| is the size of domain_name
-// buffer, which is recommended to NS_MAXDNAME. Function return false if hostname not found or
-// domain_name_size > NS_MAXDNAME.
-LIBNETD_RESOLV_PUBLIC bool resolv_gethostbyaddr_from_local_cache(unsigned netId, char domain_name[],
-                                                                 unsigned domain_name_size,
-                                                                 char* ip_address);
+// Performs reverse DNS (RDNS) lookup of |ip_address| in the resolver cache of |netId|.
+// |ip_address| is a textual address of family |af| (AF_INET or AF_INET6) in any form that
+// inet_pton(3) accepts, e.g. abbreviated, unabbreviated or IPv4-mixed IPv6 notation; it is
+// compared in binary form with cached A (AF_INET) or AAAA (AF_INET6) answers. On a match, the
+// question name of that cached response is copied (truncated if necessary) into |domain_name|, a
+// buffer of |domain_name_size| bytes (NS_MAXDNAME is recommended), and true is returned.
+// Returns false if no matching answer is cached, domain_name_size > NS_MAXDNAME, |ip_address| is
+// null, empty or not a valid address of family |af|, or |af| is not AF_INET/AF_INET6.
+LIBNETD_RESOLV_PUBLIC bool resolv_gethostbyaddr_from_cache(unsigned netId, char domain_name[],
+                                                           size_t domain_name_size,
+                                                           const char* ip_address, int af);
diff --git a/include/netd_resolv/resolv_stub.h b/include/netd_resolv/resolv_stub.h
index d1939d1..4166ed5 100644
--- a/include/netd_resolv/resolv_stub.h
+++ b/include/netd_resolv/resolv_stub.h
@@ -38,8 +38,10 @@
 
     bool (*resolv_init)(const ResolverNetdCallbacks& callbacks);
 
-    bool (*resolv_gethostbyaddr_from_local_cache)(unsigned netId, char domain_name[],
-                                                  unsigned domain_name_size, char* ip_address);
+    // Same signature as resolv_gethostbyaddr_from_cache() in resolv.h.
+    bool (*resolv_gethostbyaddr_from_cache)(
+            unsigned netId, char domain_name[], size_t domain_name_size, const char* ip_address,
+            int af);
 } RESOLV_STUB;
 
 int resolv_stub_init();
diff --git a/libnetd_resolv.map.txt b/libnetd_resolv.map.txt
index 4b0fb95..06e67aa 100644
--- a/libnetd_resolv.map.txt
+++ b/libnetd_resolv.map.txt
@@ -22,7 +22,7 @@
   global:
     resolv_has_nameservers;
     resolv_init;
-    resolv_gethostbyaddr_from_local_cache;
+    resolv_gethostbyaddr_from_cache;
   local:
     *;
 };
diff --git a/res_cache.cpp b/res_cache.cpp
index de11363..68affe0 100644
--- a/res_cache.cpp
+++ b/res_cache.cpp
@@ -1293,10 +1293,16 @@
     return 0;
 }
 
-bool resolv_gethostbyaddr_from_local_cache(unsigned netid, char domain_name[],
-                                           unsigned domain_name_size, char* ip_address) {
+bool resolv_gethostbyaddr_from_cache(unsigned netid, char domain_name[], size_t domain_name_size,
+                                     const char* ip_address, int af) {
     if (domain_name_size > NS_MAXDNAME) {
-        LOG(ERROR) << __func__ << ": invalid domain_name_size " << domain_name_size;
+        LOG(WARNING) << __func__ << ": invalid domain_name_size " << domain_name_size;
+        return false;
+    } else if (ip_address == nullptr || ip_address[0] == '\0') {
+        LOG(WARNING) << __func__ << ": invalid ip_address";
+        return false;
+    } else if (af != AF_INET && af != AF_INET6) {
+        LOG(WARNING) << __func__ << ": unsupported AF";
         return false;
     }
 
@@ -1305,14 +1311,18 @@
 
     ns_rr rr;
     ns_msg handle;
-
-    int query_count;
     ns_rr rr_query;
 
-    // For ntop.
-    char* resolved_ip = nullptr;
-    char buf4[INET_ADDRSTRLEN];
-    char buf6[INET6_ADDRSTRLEN];
+    // Parse |ip_address| once, before taking the lock: cached records are compared in binary
+    // form, so every spelling inet_pton() accepts (abbreviated, unabbreviated, IPv4-mixed IPv6)
+    // matches. The buffer is large enough for either an in_addr or an in6_addr.
+    unsigned char wanted_addr[sizeof(struct in6_addr)] = {};
+    if (inet_pton(af, ip_address, wanted_addr) != 1) {
+        LOG(WARNING) << __func__ << ": inet_pton() fail";
+        return false;
+    }
+    const ns_type wanted_type = (af == AF_INET) ? ns_t_a : ns_t_aaaa;
+    const size_t wanted_len = (af == AF_INET) ? sizeof(struct in_addr) : sizeof(struct in6_addr);
 
     std::lock_guard guard(cache_mutex);
 
@@ -1340,29 +1350,24 @@
                 continue;
             }
 
-            resolved_ip = nullptr;
-            if (ns_rr_type(rr) == ns_t_a) {
-                inet_ntop(AF_INET, ns_rr_rdata(rr), buf4, sizeof(buf4));
-                resolved_ip = buf4;
-            } else if (ns_rr_type(rr) == ns_t_aaaa) {
-                inet_ntop(AF_INET6, ns_rr_rdata(rr), buf6, sizeof(buf6));
-                resolved_ip = buf6;
-            } else {
+            // Only a record of the requested type whose RDATA is exactly one address can match.
+            // Checking the length first guarantees memcmp() reads only valid RDATA bytes; short
+            // (including empty) or over-long malformed records are skipped instead.
+            if (ns_rr_type(rr) != wanted_type || ns_rr_rdlen(rr) != wanted_len) {
                 continue;
             }
 
-            if ((resolved_ip != nullptr) && (resolved_ip[0] != '\0')) {
-                if (strcmp(resolved_ip, ip_address) == 0) {
-                    query_count = ns_msg_count(handle, ns_s_qd);
-                    for (int i = 0; i < query_count; i++) {
-                        memset(&rr_query, 0, sizeof(rr_query));
-                        if (ns_parserr(&handle, ns_s_qd, i, &rr_query)) {
-                            continue;
-                        }
-                        strlcpy(domain_name, ns_rr_name(rr_query), domain_name_size);
-                        if (domain_name[0] != '\0') {
-                            return true;
-                        }
+            if (memcmp(ns_rr_rdata(rr), wanted_addr, wanted_len) == 0) {
+                // Report the first non-empty question name of the matching cached response.
+                const int query_count = ns_msg_count(handle, ns_s_qd);
+                for (int i = 0; i < query_count; i++) {
+                    memset(&rr_query, 0, sizeof(rr_query));
+                    if (ns_parserr(&handle, ns_s_qd, i, &rr_query)) {
+                        continue;
+                    }
+                    strlcpy(domain_name, ns_rr_name(rr_query), domain_name_size);
+                    if (domain_name[0] != '\0') {
+                        return true;
                     }
                 }
             }
diff --git a/resolv_cache_unit_test.cpp b/resolv_cache_unit_test.cpp
index 1b05ac0..9f4d0af 100644
--- a/resolv_cache_unit_test.cpp
+++ b/resolv_cache_unit_test.cpp
@@ -708,6 +708,96 @@
     expectCacheStats("GetStats", TEST_NETID, cacheStats);
 }
 
+TEST_F(ResolvCacheTest, GetHostByAddrFromCache_InvalidArgs) {
+    char domain_name[NS_MAXDNAME] = {};
+    const char query_v4[] = "1.2.3.5";
+
+    // invalid buffer size, with an otherwise valid query
+    EXPECT_FALSE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME + 1, query_v4,
+                                                 AF_INET));
+    EXPECT_STREQ("", domain_name);
+
+    // null query
+    EXPECT_FALSE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, nullptr,
+                                                 AF_INET));
+    EXPECT_STREQ("", domain_name);
+
+    // empty query
+    EXPECT_FALSE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, "",
+                                                 AF_INET));
+    EXPECT_STREQ("", domain_name);
+
+    // unsupported AF
+    EXPECT_FALSE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, query_v4,
+                                                 AF_UNSPEC));
+    EXPECT_STREQ("", domain_name);
+}
+
+TEST_F(ResolvCacheTest, GetHostByAddrFromCache) {
+    char domain_name[NS_MAXDNAME] = {};
+    const char query_v4[] = "1.2.3.5";
+    const char query_v6[] = "2001:db8::102:304";
+    const char query_v6_unabbreviated[] = "2001:0db8:0000:0000:0000:0000:0102:0304";
+    const char query_v6_mixed[] = "2001:db8::1.2.3.4";
+    const char answer[] = "existent.in.cache";
+
+    // cache does not exist
+    EXPECT_FALSE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, query_v4,
+                                                 AF_INET));
+    EXPECT_STREQ("", domain_name);
+
+    // cache is empty
+    EXPECT_EQ(0, cacheCreate(TEST_NETID));
+    EXPECT_FALSE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, query_v4,
+                                                 AF_INET));
+    EXPECT_STREQ("", domain_name);
+
+    // no v4 match in cache
+    CacheEntry ce = makeCacheEntry(QUERY, "any.data", ns_c_in, ns_t_a, "1.2.3.4");
+    EXPECT_EQ(0, cacheAdd(TEST_NETID, ce));
+    EXPECT_FALSE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, query_v4,
+                                                 AF_INET));
+    EXPECT_STREQ("", domain_name);
+
+    // v4 match
+    ce = makeCacheEntry(QUERY, answer, ns_c_in, ns_t_a, query_v4);
+    EXPECT_EQ(0, cacheAdd(TEST_NETID, ce));
+    EXPECT_TRUE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, query_v4,
+                                                AF_INET));
+    EXPECT_STREQ(answer, domain_name);
+
+    // no v6 match in cache
+    memset(domain_name, 0, NS_MAXDNAME);
+    EXPECT_FALSE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, query_v6,
+                                                 AF_INET6));
+    EXPECT_STREQ("", domain_name);
+
+    // v6 match
+    ce = makeCacheEntry(QUERY, answer, ns_c_in, ns_t_aaaa, query_v6);
+    EXPECT_EQ(0, cacheAdd(TEST_NETID, ce));
+    EXPECT_TRUE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, query_v6,
+                                                AF_INET6));
+    EXPECT_STREQ(answer, domain_name);
+
+    // v6 match with unabbreviated address format
+    memset(domain_name, 0, NS_MAXDNAME);
+    EXPECT_TRUE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME,
+                                                query_v6_unabbreviated, AF_INET6));
+    EXPECT_STREQ(answer, domain_name);
+
+    // v6 with mixed address format
+    memset(domain_name, 0, NS_MAXDNAME);
+    EXPECT_TRUE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME,
+                                                query_v6_mixed, AF_INET6));
+    EXPECT_STREQ(answer, domain_name);
+
+    // v4 address is not a valid query for AF_INET6
+    memset(domain_name, 0, NS_MAXDNAME);
+    EXPECT_FALSE(resolv_gethostbyaddr_from_cache(TEST_NETID, domain_name, NS_MAXDNAME, query_v4,
+                                                 AF_INET6));
+    EXPECT_STREQ("", domain_name);
+}
+
 namespace {
 
 constexpr int EAI_OK = 0;