Add mDNS device discovery for adb client
am: 3122cdf468

Change-Id: Ib5f80b3ccc9db7c6fe6f5c989e82083d734ce4df
diff --git a/Android.mk b/Android.mk
index 3ad9f0f..b5b03ea 100644
--- a/Android.mk
+++ b/Android.mk
@@ -144,6 +144,7 @@
 LOCAL_SRC_FILES := \
     $(LIBADB_SRC_FILES) \
     adb_auth_host.cpp \
+    transport_mdns.cpp
 
 LOCAL_SRC_FILES_darwin := $(LIBADB_darwin_SRC_FILES)
 LOCAL_SRC_FILES_linux := $(LIBADB_linux_SRC_FILES)
@@ -153,7 +154,7 @@
 
 # Even though we're building a static library (and thus there's no link step for
 # this to take effect), this adds the includes to our path.
-LOCAL_STATIC_LIBRARIES := libcrypto_utils libcrypto libbase
+LOCAL_STATIC_LIBRARIES := libcrypto_utils libcrypto libbase libmdnssd
 LOCAL_STATIC_LIBRARIES_linux := libusb
 LOCAL_STATIC_LIBRARIES_darwin := libusb
 
@@ -224,6 +225,7 @@
     libcutils \
     libdiagnose_usb \
     libgmock_host \
+    libmdnssd \
 
 LOCAL_STATIC_LIBRARIES_linux := libusb
 LOCAL_STATIC_LIBRARIES_darwin := libusb
@@ -291,6 +293,7 @@
     libcrypto \
     libdiagnose_usb \
     liblog \
+    libmdnssd \
 
 # Don't use libcutils on Windows.
 LOCAL_STATIC_LIBRARIES_darwin := libcutils
diff --git a/adb_mdns.h b/adb_mdns.h
new file mode 100644
index 0000000..2e544d7
--- /dev/null
+++ b/adb_mdns.h
@@ -0,0 +1,25 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _ADB_MDNS_H_
+#define _ADB_MDNS_H_
+
+// DNS-SD service type shared by the daemon (daemon/mdns.cpp) and the host's
+// mDNS browser (transport_mdns.cpp). constexpr gives the array internal
+// linkage, so every includer gets its own copy rather than a duplicate symbol.
+constexpr char kADBServiceType[] = "_adb._tcp";
+
+#endif
diff --git a/client/main.cpp b/client/main.cpp
index 97a54fd..606203c 100644
--- a/client/main.cpp
+++ b/client/main.cpp
@@ -117,6 +117,8 @@
 
     init_transport_registration();
 
+    init_mdns_transport_discovery();
+
     usb_init();
     local_init(DEFAULT_ADB_LOCAL_TRANSPORT_PORT);
 
diff --git a/daemon/mdns.cpp b/daemon/mdns.cpp
index a8622ae..7811143 100644
--- a/daemon/mdns.cpp
+++ b/daemon/mdns.cpp
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include "adb_mdns.h"
 #include "sysdeps.h"
 
 #include <chrono>
diff --git a/services.cpp b/services.cpp
index a48d855..47f0a03 100644
--- a/services.cpp
+++ b/services.cpp
@@ -377,45 +377,6 @@
     D("wait_for_state is done");
 }
 
-static void connect_device(const std::string& address, std::string* response) {
-    if (address.empty()) {
-        *response = "empty address";
-        return;
-    }
-
-    std::string serial;
-    std::string host;
-    int port = DEFAULT_ADB_LOCAL_TRANSPORT_PORT;
-    if (!android::base::ParseNetAddress(address, &host, &port, &serial, response)) {
-        return;
-    }
-
-    std::string error;
-    int fd = network_connect(host.c_str(), port, SOCK_STREAM, 10, &error);
-    if (fd == -1) {
-        *response = android::base::StringPrintf("unable to connect to %s: %s",
-                                                serial.c_str(), error.c_str());
-        return;
-    }
-
-    D("client: connected %s remote on fd %d", serial.c_str(), fd);
-    close_on_exec(fd);
-    disable_tcp_nagle(fd);
-
-    // Send a TCP keepalive ping to the device every second so we can detect disconnects.
-    if (!set_tcp_keepalive(fd, 1)) {
-        D("warning: failed to configure TCP keepalives (%s)", strerror(errno));
-    }
-
-    int ret = register_socket_transport(fd, serial.c_str(), port, 0);
-    if (ret < 0) {
-        adb_close(fd);
-        *response = android::base::StringPrintf("already connected to %s", serial.c_str());
-    } else {
-        *response = android::base::StringPrintf("connected to %s", serial.c_str());
-    }
-}
-
 void connect_emulator(const std::string& port_spec, std::string* response) {
     std::vector<std::string> pieces = android::base::Split(port_spec, ",");
     if (pieces.size() != 2) {
diff --git a/transport.h b/transport.h
index 490e513..4d97fc7 100644
--- a/transport.h
+++ b/transport.h
@@ -187,6 +187,7 @@
 void update_transports(void);
 
 void init_transport_registration(void);
+void init_mdns_transport_discovery(void);
 std::string list_transports(bool long_listing);
 atransport* find_transport(const char* serial);
 void kick_all_tcp_devices();
@@ -194,6 +195,9 @@
 void register_usb_transport(usb_handle* h, const char* serial,
                             const char* devpath, unsigned writeable);
 
+/* Connect to a network address and register it as a device */
+void connect_device(const std::string& address, std::string* response);
+
 /* cause new transports to be init'd and added to the list */
 int register_socket_transport(int s, const char* serial, int port, int local);
 
diff --git a/transport_local.cpp b/transport_local.cpp
index c17f869..4198a52 100644
--- a/transport_local.cpp
+++ b/transport_local.cpp
@@ -30,6 +30,7 @@
 #include <thread>
 #include <vector>
 
+#include <android-base/parsenetaddress.h>
 #include <android-base/stringprintf.h>
 #include <cutils/sockets.h>
 
@@ -101,6 +102,52 @@
     return local_connect_arbitrary_ports(port-1, port, &dummy) == 0;
 }
 
+// Connects to the network address |address| (e.g. "host", "host:port" or
+// "[ipv6]:port"; the port defaults to DEFAULT_ADB_LOCAL_TRANSPORT_PORT) and
+// registers the connected socket as a device transport. |response| receives
+// a human-readable result or error message.
+// NOTE(review): network_connect() is given a timeout of 10 (presumably
+// seconds), so this call can block for that long.
+void connect_device(const std::string& address, std::string* response) {
+    if (address.empty()) {
+        *response = "empty address";
+        return;
+    }
+
+    std::string serial;
+    std::string host;
+    int port = DEFAULT_ADB_LOCAL_TRANSPORT_PORT;
+    if (!android::base::ParseNetAddress(address, &host, &port, &serial, response)) {
+        return;
+    }
+
+    std::string error;
+    int fd = network_connect(host.c_str(), port, SOCK_STREAM, 10, &error);
+    if (fd == -1) {
+        *response = android::base::StringPrintf("unable to connect to %s: %s",
+                                                serial.c_str(), error.c_str());
+        return;
+    }
+
+    D("client: connected %s remote on fd %d", serial.c_str(), fd);
+    close_on_exec(fd);
+    disable_tcp_nagle(fd);
+
+    // Send a TCP keepalive ping to the device every second so we can detect disconnects.
+    if (!set_tcp_keepalive(fd, 1)) {
+        D("warning: failed to configure TCP keepalives (%s)", strerror(errno));
+    }
+
+    int ret = register_socket_transport(fd, serial.c_str(), port, 0);
+    if (ret < 0) {
+        // Registration failed, so the fd is still ours to close.
+        adb_close(fd);
+        *response = android::base::StringPrintf("already connected to %s", serial.c_str());
+    } else {
+        *response = android::base::StringPrintf("connected to %s", serial.c_str());
+    }
+}
+
 int local_connect_arbitrary_ports(int console_port, int adb_port, std::string* error) {
     int fd = -1;
 
diff --git a/transport_mdns.cpp b/transport_mdns.cpp
new file mode 100644
index 0000000..b63fc83
--- /dev/null
+++ b/transport_mdns.cpp
@@ -0,0 +1,326 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define TRACE_TAG TRANSPORT
+
+#include "transport.h"
+
+#if defined(_WIN32)
+#include <winsock2.h>
+#include <ws2tcpip.h>
+#else
+#include <arpa/inet.h>
+#endif
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include <android-base/stringprintf.h>
+#include <dns_sd.h>
+
+#include "adb_mdns.h"
+#include "adb_trace.h"
+#include "fdevent.h"
+#include "sysdeps.h"
+
+// Browse handle for kADBServiceType, and the fdevent that watches its socket.
+// Both are set up by init_mdns_transport_discovery().
+static DNSServiceRef service_ref;
+static fdevent service_ref_fde;
+
+static void register_service_ip(DNSServiceRef sdRef,
+                                DNSServiceFlags flags,
+                                uint32_t interfaceIndex,
+                                DNSServiceErrorType errorCode,
+                                const char* hostname,
+                                const sockaddr* address,
+                                uint32_t ttl,
+                                void* context);
+
+// fdevent callback: when a dns_sd socket becomes readable, let the library
+// read the reply and invoke the callback registered for that DNSServiceRef.
+// |data| points at the DNSServiceRef to pump.
+static void pump_service_ref(int /*fd*/, unsigned ev, void* data) {
+    DNSServiceRef* ref = reinterpret_cast<DNSServiceRef*>(data);
+
+    // The reply callback may delete the object that owns *ref, so |ref| must
+    // not be used again after DNSServiceProcessResult() returns.
+    if (ev & FDE_READ) {
+        DNSServiceProcessResult(*ref);
+    }
+}
+
+// Base for an in-flight dns_sd operation whose socket is polled by the
+// fdevent loop. Subclasses create sdRef_ and call Initialize() on success;
+// the destructor tears down both the fdevent and the DNSServiceRef.
+class AsyncServiceRef {
+  public:
+    AsyncServiceRef() = default;
+
+    // fde_ holds a pointer to sdRef_, so a copy would leave a dangling
+    // registration behind.
+    AsyncServiceRef(const AsyncServiceRef&) = delete;
+    AsyncServiceRef& operator=(const AsyncServiceRef&) = delete;
+
+    bool Initialized() const {
+        return initialized_;
+    }
+
+    virtual ~AsyncServiceRef() {
+        if (!initialized_) {
+            return;
+        }
+
+        // Stop watching the socket before the dns_sd library releases it.
+        // NOTE(review): if fdevent_remove() also closes the fd, the fd ends up
+        // closed twice; consider giving the fdevent an adb_dup()ed fd.
+        fdevent_remove(&fde_);
+        DNSServiceRefDeallocate(sdRef_);
+    }
+
+  protected:
+    DNSServiceRef sdRef_ = nullptr;
+
+    // Hooks sdRef_'s socket into the fdevent loop. Call only after the
+    // DNSService* call that creates sdRef_ has succeeded.
+    void Initialize() {
+        fdevent_install(&fde_, DNSServiceRefSockFD(sdRef_),
+                        pump_service_ref, &sdRef_);
+        fdevent_set(&fde_, FDE_READ);
+        initialized_ = true;
+    }
+
+  private:
+    // Must start false: subclasses only call Initialize() on success, and
+    // callers test Initialized() to decide whether to keep the object.
+    bool initialized_ = false;
+    fdevent fde_;
+};
+
+// Looks up the IP address of a resolved adb service and connects to it. The
+// GetAddrInfo callback (register_service_ip) owns the object and deletes it
+// after the first reply.
+class ResolvedService : public AsyncServiceRef {
+  public:
+    virtual ~ResolvedService() = default;
+
+    ResolvedService(std::string name, uint32_t interfaceIndex,
+                    const char* hosttarget, uint16_t port) :
+            name_(std::move(name)),
+            port_(port) {
+        DNSServiceErrorType ret =
+            DNSServiceGetAddrInfo(
+                &sdRef_, 0, interfaceIndex,
+                kDNSServiceProtocol_IPv6|kDNSServiceProtocol_IPv4, hosttarget,
+                register_service_ip, reinterpret_cast<void*>(this));
+
+        if (ret != kDNSServiceErr_NoError) {
+            D("Got %d from DNSServiceGetAddrInfo.", ret);
+        } else {
+            Initialize();
+        }
+    }
+
+    // Formats |address| as "ip:port" ("[ip]:port" for IPv6) and hands it to
+    // connect_device(), logging the outcome.
+    void Connect(const sockaddr* address) {
+        char ip_addr[INET6_ADDRSTRLEN];
+        const void* ip_addr_data;
+        const char* addr_format;
+
+        if (address->sa_family == AF_INET) {
+            ip_addr_data =
+                &reinterpret_cast<const sockaddr_in*>(address)->sin_addr;
+            addr_format = "%s:%hu";
+        } else if (address->sa_family == AF_INET6) {
+            ip_addr_data =
+                &reinterpret_cast<const sockaddr_in6*>(address)->sin6_addr;
+            addr_format = "[%s]:%hu";
+        } else { // Should be impossible
+            D("mDNS resolved non-IP address.");
+            return;
+        }
+
+        if (!inet_ntop(address->sa_family, ip_addr_data, ip_addr,
+                       sizeof(ip_addr))) {
+            D("Could not convert IP address to string.");
+            return;
+        }
+
+        // NOTE(review): connect_device() blocks in network_connect() (timeout
+        // 10), stalling the fdevent loop that delivered this reply.
+        std::string response;
+        connect_device(android::base::StringPrintf(addr_format, ip_addr, port_),
+                       &response);
+        D("Connect to %s (%s:%hu) : %s", name_.c_str(), ip_addr, port_,
+          response.c_str());
+    }
+
+  private:
+    std::string name_;
+    const uint16_t port_;
+};
+
+// DNSServiceGetAddrInfo callback. Takes ownership of the ResolvedService in
+// |context| and connects to the first address reported.
+static void register_service_ip(DNSServiceRef /*sdRef*/,
+                                DNSServiceFlags /*flags*/,
+                                uint32_t /*interfaceIndex*/,
+                                DNSServiceErrorType errorCode,
+                                const char* /*hostname*/,
+                                const sockaddr* address,
+                                uint32_t /*ttl*/,
+                                void* context) {
+    std::unique_ptr<ResolvedService> data(
+        reinterpret_cast<ResolvedService*>(context));
+
+    // dns_sd leaves the other reply parameters undefined on error, so never
+    // dereference |address| unless the lookup succeeded.
+    if (errorCode != kDNSServiceErr_NoError || address == nullptr) {
+        D("Got error %d looking up service address.", errorCode);
+        return;
+    }
+
+    data->Connect(address);
+}
+
+static void register_resolved_mdns_service(DNSServiceRef sdRef,
+                                           DNSServiceFlags flags,
+                                           uint32_t interfaceIndex,
+                                           DNSServiceErrorType errorCode,
+                                           const char* fullname,
+                                           const char* hosttarget,
+                                           uint16_t port,
+                                           uint16_t txtLen,
+                                           const unsigned char* txtRecord,
+                                           void* context);
+
+// Resolves a browsed service instance to a host name and port. The resolve
+// callback (register_resolved_mdns_service) takes ownership of the object.
+class DiscoveredService : public AsyncServiceRef {
+  public:
+    DiscoveredService(uint32_t interfaceIndex, const char* serviceName,
+                      const char* regtype, const char* domain)
+        : serviceName_(serviceName) {
+        DNSServiceErrorType ret =
+            DNSServiceResolve(&sdRef_, 0, interfaceIndex, serviceName, regtype,
+                              domain, register_resolved_mdns_service,
+                              reinterpret_cast<void*>(this));
+
+        if (ret != kDNSServiceErr_NoError) {
+            D("Got %d from DNSServiceResolve.", ret);
+        } else {
+            Initialize();
+        }
+    }
+
+    const char* ServiceName() const {
+        return serviceName_.c_str();
+    }
+
+  private:
+    std::string serviceName_;
+};
+
+// DNSServiceResolve callback. Takes ownership of the DiscoveredService in
+// |context| and starts an address lookup for the resolved host and port.
+static void register_resolved_mdns_service(DNSServiceRef /*sdRef*/,
+                                           DNSServiceFlags flags,
+                                           uint32_t interfaceIndex,
+                                           DNSServiceErrorType errorCode,
+                                           const char* /*fullname*/,
+                                           const char* hosttarget,
+                                           uint16_t port,
+                                           uint16_t /*txtLen*/,
+                                           const unsigned char* /*txtRecord*/,
+                                           void* context) {
+    std::unique_ptr<DiscoveredService> discovered(
+        reinterpret_cast<DiscoveredService*>(context));
+
+    if (errorCode != kDNSServiceErr_NoError) {
+        D("Got error %d resolving service.", errorCode);
+        return;
+    }
+
+    // |port| is delivered in network byte order.
+    auto resolved = std::make_unique<ResolvedService>(
+        discovered->ServiceName(), interfaceIndex, hosttarget, ntohs(port));
+
+    // Once the lookup is running, register_service_ip() owns and frees it.
+    if (resolved->Initialized()) {
+        resolved.release();
+    }
+
+    // More replies are pending for this resolve; keep it alive for them.
+    if (flags & kDNSServiceFlagsMoreComing) {
+        discovered.release();
+    }
+}
+
+// DNSServiceBrowse callback, invoked as adb service instances appear on or
+// disappear from the network.
+static void register_mdns_transport(DNSServiceRef sdRef,
+                                    DNSServiceFlags flags,
+                                    uint32_t interfaceIndex,
+                                    DNSServiceErrorType errorCode,
+                                    const char* serviceName,
+                                    const char* regtype,
+                                    const char* domain,
+                                    void* /*context*/) {
+    if (errorCode != kDNSServiceErr_NoError) {
+        // The browse is dead: stop watching its socket, then free the ref.
+        D("Got error %d during mDNS browse.", errorCode);
+        fdevent_remove(&service_ref_fde);
+        DNSServiceRefDeallocate(sdRef);
+        return;
+    }
+
+    // Without kDNSServiceFlagsAdd this reply announces a removal; resolving an
+    // instance that just went away is pointless and the resolve may never
+    // complete, leaking the DiscoveredService.
+    if (!(flags & kDNSServiceFlagsAdd)) {
+        D("mDNS service %s went away.", serviceName);
+        return;
+    }
+
+    auto discovered = std::make_unique<DiscoveredService>(
+        interfaceIndex, serviceName, regtype, domain);
+
+    // Once the resolve is running, register_resolved_mdns_service() owns it.
+    if (discovered->Initialized()) {
+        discovered.release();
+    }
+}
+
+// Starts browsing for kADBServiceType on all interfaces (index 0) in the
+// default domains (nullptr). Failures to start are only logged.
+void init_mdns_transport_discovery(void) {
+    DNSServiceErrorType errorCode =
+        DNSServiceBrowse(&service_ref, 0, 0, kADBServiceType, nullptr,
+                         register_mdns_transport, nullptr);
+
+    if (errorCode != kDNSServiceErr_NoError) {
+        D("Got %d initiating mDNS browse.", errorCode);
+        return;
+    }
+
+    fdevent_install(&service_ref_fde,
+                    DNSServiceRefSockFD(service_ref),
+                    pump_service_ref,
+                    &service_ref);
+    fdevent_set(&service_ref_fde, FDE_READ);
+}