Refactor as a prerequisite for DNS-over-TLS pipelining
This change should have no effect on behavior; it splits functionality
into separate classes in a way that will enable pipelining.
It also adds unit tests for the newly separated functionality.
Test: Unit and integration tests pass.
Bug: 63448521
Change-Id: I08948be304b7a3e4ba10f754ef58bd41db6824c4
diff --git a/tests/dns_tls_test.cpp b/tests/dns_tls_test.cpp
new file mode 100644
index 0000000..5ab95bd
--- /dev/null
+++ b/tests/dns_tls_test.cpp
@@ -0,0 +1,265 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "dns_tls_test"
+
+#include <gtest/gtest.h>
+
+#include "dns/DnsTlsDispatcher.h"
+#include "dns/DnsTlsServer.h"
+#include "dns/DnsTlsSessionCache.h"
+#include "dns/DnsTlsSocket.h"
+#include "dns/DnsTlsTransport.h"
+#include "dns/IDnsTlsSocket.h"
+#include "dns/IDnsTlsSocketFactory.h"
+
+#include <chrono>
+#include <arpa/inet.h>
+#include <android-base/macros.h>
+#include <netdutils/Slice.h>
+
+#include "log/log.h"
+
+namespace android {
+namespace net {
+
+using netdutils::Slice;
+using netdutils::makeSlice;
+
+// Byte buffer type used for queries, answers and fingerprints in these tests.
+using bytevec = std::vector<uint8_t>;
+
+// Parses a numeric IPv4 or IPv6 address string plus |port| into *parsed.
+// The whole sockaddr_storage is zeroed first, so fields this function does not
+// set (sin_zero, sin6_flowinfo, sin6_scope_id) never hold indeterminate bytes.
+// The results are copied into DnsTlsServer objects that are later compared as
+// keys, so stray bytes could make otherwise-equal servers compare unequal.
+// On parse failure an error is logged and *parsed is left all-zero
+// (family AF_UNSPEC).
+static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
+    *parsed = sockaddr_storage{};
+    sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
+    if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
+        // IPv4 parse succeeded, so it's IPv4.
+        sin->sin_family = AF_INET;
+        sin->sin_port = htons(port);
+        return;
+    }
+    sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
+    if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1) {
+        // IPv6 parse succeeded, so it's IPv6.
+        sin6->sin6_family = AF_INET6;
+        sin6->sin6_port = htons(port);
+        return;
+    }
+    ALOGE("Failed to parse server address: %s", server);
+}
+
+// Placeholder fingerprint and hostname that BaseTest attaches to SERVER1.
+// const gives them internal linkage and prevents accidental mutation by a test.
+const bytevec FINGERPRINT1 = { 1 };
+
+const std::string SERVERNAME1 = "dns.example.com";
+
+// BaseTest just provides constants that are useful for the tests:
+// two IPv4 server addresses on port 853 (V4ADDR1, V4ADDR2), and SERVER1,
+// a DnsTlsServer built from V4ADDR1 with FINGERPRINT1 and SERVERNAME1 attached.
+class BaseTest : public ::testing::Test {
+protected:
+    BaseTest() {
+        parseServer("192.0.2.1", 853, &V4ADDR1);
+        parseServer("192.0.2.2", 853, &V4ADDR2);
+
+        SERVER1 = DnsTlsServer(V4ADDR1);
+        SERVER1.fingerprints.insert(FINGERPRINT1);
+        SERVER1.name = SERVERNAME1;
+    }
+
+    // Brace-initialized so that every field starts at zero. Without this,
+    // fields a parser does not write (e.g. sin_zero) would be indeterminate,
+    // and these addresses end up inside DnsTlsServer values used as keys.
+    sockaddr_storage V4ADDR1{};
+    sockaddr_storage V4ADDR2{};
+
+    DnsTlsServer SERVER1;
+};
+
+// Builds a fake query of |size| bytes. The first two bytes hold |id|, high
+// byte first; the remaining bytes are filled with values derived from |id|,
+// so the contents vary with the ID. When |size| < 2, only the ID bytes that
+// fit are written.
+bytevec make_query(uint16_t id, size_t size) {
+    bytevec vec(size);
+    // Guard the header writes: indexing past the end of a short vector is UB.
+    if (size > 0) {
+        vec[0] = static_cast<uint8_t>(id >> 8);
+    }
+    if (size > 1) {
+        vec[1] = static_cast<uint8_t>(id);
+    }
+    // Arbitrarily fill the query body with unique data.
+    for (size_t i = 2; i < size; ++i) {
+        vec[i] = static_cast<uint8_t>(id + i);
+    }
+    return vec;
+}
+
+// Query constants shared by the tests: the socket mark handed to the
+// transport/dispatcher, plus the ID and byte length of the canonical QUERY.
+constexpr unsigned MARK = 123;
+constexpr uint16_t ID = 52;
+constexpr uint16_t SIZE = 22;
+const bytevec QUERY = make_query(ID, SIZE);
+
+// Socket factory that ignores the server, mark and session cache, and hands
+// out a freshly constructed fake socket of type T on every call.
+template <class T>
+class FakeSocketFactory : public IDnsTlsSocketFactory {
+public:
+    FakeSocketFactory() = default;
+
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& /* server */,
+            unsigned /* mark */,
+            DnsTlsSessionCache* /* cache */) override {
+        auto socket = std::make_unique<T>();
+        return socket;
+    }
+};
+
+// Builds the fake response for |query|: a two-byte |id| prefix (high byte
+// first) followed by a verbatim copy of the query bytes.
+bytevec make_echo(uint16_t id, const Slice query) {
+    bytevec response(query.size() + 2);
+    response[0] = static_cast<uint8_t>(id >> 8);
+    response[1] = static_cast<uint8_t>(id);
+    // Echo the query as the fake response. memcpy() requires valid pointers
+    // even for a zero-length copy, and an empty Slice may carry a null base(),
+    // so only copy when there is something to copy.
+    if (query.size() > 0) {
+        memcpy(response.data() + 2, query.base(), query.size());
+    }
+    return response;
+}
+
+// Simplest possible fake server: every query is answered immediately and
+// successfully, with make_echo() building the response from the query bytes.
+class FakeSocketEcho : public IDnsTlsSocket {
+public:
+    FakeSocketEcho() = default;
+
+    DnsTlsServer::Result query(uint16_t id, const Slice query) override {
+        DnsTlsServer::Result result = {
+            .code = DnsTlsServer::Response::success,
+            .response = make_echo(id, query),
+        };
+        return result;
+    }
+};
+
+// Fixture for DnsTlsTransport tests; it only inherits BaseTest's constants.
+class TransportTest : public BaseTest {};
+
+// A single query through the transport succeeds and yields QUERY back.
+TEST_F(TransportTest, Query) {
+    FakeSocketFactory<FakeSocketEcho> echoFactory;
+    DnsTlsTransport transport(SERVER1, MARK, &echoFactory);
+    const auto result = transport.query(makeSlice(QUERY));
+
+    EXPECT_EQ(DnsTlsTransport::Response::success, result.code);
+    EXPECT_EQ(QUERY, result.response);
+}
+
+// Sends many queries one after another over a single transport. The count
+// exceeds 65536, the number of distinct 16-bit query IDs, presumably so that
+// any per-query ID assignment inside the transport has to wrap around.
+TEST_F(TransportTest, SerialQueries) {
+    FakeSocketFactory<FakeSocketEcho> factory;
+    DnsTlsTransport transport(SERVER1, MARK, &factory);
+    // Send more than 65536 queries serially.
+    for (int i = 0; i < 100000; ++i) {
+        auto r = transport.query(makeSlice(QUERY));
+
+        // ASSERT rather than EXPECT: once one query fails, the rest very likely
+        // fail too, and up to 100000 repeated failures would bury the first
+        // useful message. The index identifies where things went wrong.
+        ASSERT_EQ(DnsTlsTransport::Response::success, r.code) << "query #" << i;
+        ASSERT_EQ(QUERY, r.response) << "query #" << i;
+    }
+}
+
+// Socket factory whose every connection attempt fails: returning null from
+// the factory indicates a connection failure.
+class NullSocketFactory : public IDnsTlsSocketFactory {
+public:
+    NullSocketFactory() = default;
+
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& /* server */,
+            unsigned /* mark */,
+            DnsTlsSessionCache* /* cache */) override {
+        return std::unique_ptr<IDnsTlsSocket>();
+    }
+};
+
+// When the factory cannot produce a socket, the transport reports a network
+// error and returns an empty response.
+TEST_F(TransportTest, ConnectFail) {
+    NullSocketFactory failingFactory;
+    DnsTlsTransport transport(SERVER1, MARK, &failingFactory);
+    const auto result = transport.query(makeSlice(QUERY));
+
+    EXPECT_EQ(DnsTlsTransport::Response::network_error, result.code);
+    EXPECT_TRUE(result.response.empty());
+}
+
+// Dispatcher tests
+class DispatcherTest : public BaseTest {};
+
+// A query routed through the dispatcher is echoed into the caller's answer
+// buffer, and resplen reports how many bytes were written.
+TEST_F(DispatcherTest, Query) {
+    bytevec ans(4096);
+    int resplen = 0;
+
+    auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
+    DnsTlsDispatcher dispatcher(std::move(factory));
+    auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
+                              makeSlice(ans), &resplen);
+
+    // Stop on failure before touching |ans|: resizing with an unexpected
+    // resplen (a negative int converts to a huge size_t) would crash or mask
+    // the real error.
+    ASSERT_EQ(DnsTlsTransport::Response::success, r);
+    ASSERT_EQ(int(QUERY.size()), resplen);
+    ans.resize(resplen);
+    EXPECT_EQ(QUERY, ans);
+}
+
+// An answer buffer one byte shorter than the echoed response must be
+// rejected with limit_error.
+TEST_F(DispatcherTest, AnswerTooLarge) {
+    int resplen = 0;
+    bytevec tooSmall(SIZE - 1);  // Too small to hold the answer
+
+    DnsTlsDispatcher dispatcher(std::make_unique<FakeSocketFactory<FakeSocketEcho>>());
+    const auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
+                                    makeSlice(tooSmall), &resplen);
+
+    EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
+}
+
+// Like FakeSocketFactory, but records the (mark, server) pair of every socket
+// it creates in |keys|. createDnsTlsSocket() may be called from several
+// threads, so it is serialized with mLock. |keys| is read without locking,
+// so only inspect it after every caller has finished.
+template <class T>
+class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
+public:
+    TrackingFakeSocketFactory() = default;
+
+    std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
+            const DnsTlsServer& server,
+            unsigned mark,
+            DnsTlsSessionCache* /* cache */) override {
+        std::lock_guard<std::mutex> lock(mLock);
+        keys.emplace(mark, server);
+        auto socket = std::make_unique<T>();
+        return socket;
+    }
+
+    std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
+
+private:
+    std::mutex mLock;
+};
+
+// Issues one query per (mark, server) combination, each from its own thread.
+// Checks that every query succeeds, and that the dispatcher asked the factory
+// for exactly one socket per combination.
+TEST_F(DispatcherTest, Dispatching) {
+    auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketEcho>>();
+    auto* weak_factory = factory.get();  // Valid as long as dispatcher is in scope.
+    DnsTlsDispatcher dispatcher(std::move(factory));
+
+    // Populate a vector of two servers and two socket marks, four combinations
+    // in total.
+    std::vector<std::pair<unsigned, DnsTlsServer>> keys;
+    keys.emplace_back(MARK, SERVER1);
+    keys.emplace_back(MARK + 1, SERVER1);
+    keys.emplace_back(MARK, V4ADDR2);
+    keys.emplace_back(MARK + 1, V4ADDR2);
+
+    // Do one query on each server. They should all succeed.
+    // Each thread uses its index in |keys| as the query ID, so every thread
+    // sends (and expects back) different bytes.
+    std::vector<std::thread> threads;
+    threads.reserve(keys.size());
+    for (size_t i = 0; i < keys.size(); ++i) {
+        const auto& key = keys[i];
+        const uint16_t id = static_cast<uint16_t>(i);
+        // |key| is captured by copy, so each thread owns its own pair.
+        threads.emplace_back([key, id] (DnsTlsDispatcher* dispatcher) {
+            auto q = make_query(id, SIZE);
+            bytevec ans(4096);
+            int resplen = 0;
+            unsigned mark = key.first;
+            const DnsTlsServer& server = key.second;
+            auto r = dispatcher->query(server, mark, makeSlice(q),
+                                       makeSlice(ans), &resplen);
+            // Bail out of this thread before resizing |ans| with an
+            // unexpected length.
+            ASSERT_EQ(DnsTlsTransport::Response::success, r);
+            ASSERT_EQ(int(q.size()), resplen);
+            ans.resize(resplen);
+            EXPECT_EQ(q, ans);
+        }, &dispatcher);
+    }
+    for (auto& thread : threads) {
+        thread.join();
+    }
+    // We expect that the factory created one socket for each key.
+    EXPECT_EQ(keys.size(), weak_factory->keys.size());
+    for (const auto& key : keys) {
+        EXPECT_EQ(1U, weak_factory->keys.count(key));
+    }
+}
+
+} // end of namespace net
+} // end of namespace android