blob: 8e0033d62c2301f829d460e048a1ea9260e0a17b [file] [log] [blame]
Mike Yubab3daa2018-10-19 22:11:43 +08001/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Ken Chen5471dca2019-04-15 15:25:35 +080017#define LOG_TAG "resolv"
Mike Yubab3daa2018-10-19 22:11:43 +080018//#define LOG_NDEBUG 0
19
Bernie Innocentiec4219b2019-01-30 11:16:36 +090020#include "DnsTlsQueryMap.h"
Mike Yubab3daa2018-10-19 22:11:43 +080021
22#include "log/log.h"
23
24namespace android {
25namespace net {
26
Bernie Innocentiec4219b2019-01-30 11:16:36 +090027std::unique_ptr<DnsTlsQueryMap::QueryFuture> DnsTlsQueryMap::recordQuery(
28 const netdutils::Slice query) {
Mike Yubab3daa2018-10-19 22:11:43 +080029 std::lock_guard guard(mLock);
30
31 // Store the query so it can be matched to the response or reissued.
32 if (query.size() < 2) {
33 ALOGW("Query is too short");
34 return nullptr;
35 }
36 int32_t newId = getFreeId();
37 if (newId < 0) {
38 ALOGW("All query IDs are in use");
39 return nullptr;
40 }
41 Query q = { .newId = static_cast<uint16_t>(newId), .query = query };
42 std::map<uint16_t, QueryPromise>::iterator it;
43 bool inserted;
44 std::tie(it, inserted) = mQueries.emplace(newId, q);
45 if (!inserted) {
46 ALOGE("Failed to store pending query");
47 return nullptr;
48 }
49 return std::make_unique<QueryFuture>(q, it->second.result.get_future());
50}
51
52void DnsTlsQueryMap::expire(QueryPromise* p) {
53 Result r = { .code = Response::network_error };
54 p->result.set_value(r);
55}
56
57void DnsTlsQueryMap::markTried(uint16_t newId) {
58 std::lock_guard guard(mLock);
59 auto it = mQueries.find(newId);
60 if (it != mQueries.end()) {
61 it->second.tries++;
62 }
63}
64
65void DnsTlsQueryMap::cleanup() {
66 std::lock_guard guard(mLock);
67 for (auto it = mQueries.begin(); it != mQueries.end();) {
68 auto& p = it->second;
69 if (p.tries >= kMaxTries) {
70 expire(&p);
71 it = mQueries.erase(it);
72 } else {
73 ++it;
74 }
75 }
76}
77
78int32_t DnsTlsQueryMap::getFreeId() {
79 if (mQueries.empty()) {
80 return 0;
81 }
82 uint16_t maxId = mQueries.rbegin()->first;
83 if (maxId < UINT16_MAX) {
84 return maxId + 1;
85 }
86 if (mQueries.size() == UINT16_MAX + 1) {
87 // Map is full.
88 return -1;
89 }
90 // Linear scan.
91 uint16_t nextId = 0;
92 for (auto& pair : mQueries) {
93 uint16_t id = pair.first;
94 if (id != nextId) {
95 // Found a gap.
96 return nextId;
97 }
98 nextId = id + 1;
99 }
100 // Unreachable (but the compiler isn't smart enough to prove it).
101 return -1;
102}
103
104std::vector<DnsTlsQueryMap::Query> DnsTlsQueryMap::getAll() {
105 std::lock_guard guard(mLock);
106 std::vector<DnsTlsQueryMap::Query> queries;
107 for (auto& q : mQueries) {
108 queries.push_back(q.second.query);
109 }
110 return queries;
111}
112
113bool DnsTlsQueryMap::empty() {
114 std::lock_guard guard(mLock);
115 return mQueries.empty();
116}
117
118void DnsTlsQueryMap::clear() {
119 std::lock_guard guard(mLock);
120 for (auto& q : mQueries) {
121 expire(&q.second);
122 }
123 mQueries.clear();
124}
125
126void DnsTlsQueryMap::onResponse(std::vector<uint8_t> response) {
127 ALOGV("Got response of size %zu", response.size());
128 if (response.size() < 2) {
129 ALOGW("Response is too short");
130 return;
131 }
132 uint16_t id = response[0] << 8 | response[1];
133 std::lock_guard guard(mLock);
134 auto it = mQueries.find(id);
135 if (it == mQueries.end()) {
136 ALOGW("Discarding response: unknown ID %d", id);
137 return;
138 }
139 Result r = { .code = Response::success, .response = std::move(response) };
140 // Rewrite ID to match the query
141 const uint8_t* data = it->second.query.query.base();
142 r.response[0] = data[0];
143 r.response[1] = data[1];
144 ALOGV("Sending result to dispatcher");
145 it->second.result.set_value(std::move(r));
146 mQueries.erase(it);
147}
148
149} // end of namespace net
150} // end of namespace android