Refactor DoH implementation
Current doh implementation is designed to be totally managed by
DnsResolver, which means we might need bunch of C++ glue code to make
things work.
This goal of this refactoring is to minimize the required C++ glue code
and put most control logic into Rust side.
Test: atest
Bug: 155855709
Change-Id: I9a048f5fe72c4c25ae1b95ddf839f244eda34097
diff --git a/Android.bp b/Android.bp
index 2aedaa7..7401cb9 100644
--- a/Android.bp
+++ b/Android.bp
@@ -330,7 +330,8 @@
rlibs: [
"libandroid_logger",
"libanyhow",
- "liblazy_static",
+ "libbase64_rust",
+ "libfutures",
"liblibc",
"liblog_rust",
"libquiche",
@@ -365,7 +366,8 @@
rustlibs: [
"libandroid_logger",
"libanyhow",
- "liblazy_static",
+ "libbase64_rust",
+ "libfutures",
"liblibc",
"liblog_rust",
"libquiche_static",
@@ -386,7 +388,8 @@
rlibs: [
"libandroid_logger",
"libanyhow",
- "liblazy_static",
+ "libbase64_rust",
+ "libfutures",
"liblibc",
"liblog_rust",
"libquiche_static",
diff --git a/doh.h b/doh.h
index d04feb4..c9e1f1b 100644
--- a/doh.h
+++ b/doh.h
@@ -20,28 +20,63 @@
#pragma once
-/* Generated with cbindgen:0.15.0 */
+/* Generated with cbindgen:0.17.0 */
#include <stdint.h>
#include <sys/types.h>
-/// Context for a running DoH engine and associated thread.
-struct DohServer;
+/// The return code of doh_query means that there is no answer.
+static const ssize_t RESULT_INTERNAL_ERROR = -1;
+
+/// The return code of doh_query means that query can't be sent.
+static const ssize_t RESULT_CAN_NOT_SEND = -2;
+
+/// The return code of doh_query to indicate that the query timed out.
+static const ssize_t RESULT_TIMEOUT = -255;
+
+/// Context for a running DoH engine.
+struct DohDispatcher;
+
+using ValidationCallback = void (*)(uint32_t net_id, bool success, const char* ip_addr,
+ const char* host);
extern "C" {
-/// Performs static initialization fo the DoH engine.
-const char* doh_init();
-
+/// Performs static initialization for the DoH engine.
/// Creates and returns a DoH engine instance.
-/// The returned object must be freed with doh_delete().
-DohServer* doh_new(const char* url, const char* ip_addr, uint32_t mark, const char* cert_path);
+DohDispatcher* doh_dispatcher_new(ValidationCallback ptr);
-/// Deletes a DoH engine created by doh_new().
-void doh_delete(DohServer* doh);
+/// Deletes a DoH engine created by doh_dispatcher_new().
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+void doh_dispatcher_delete(DohDispatcher* doh);
-/// Sends a DNS query and waits for the response.
-ssize_t doh_query(DohServer* doh, uint8_t* query, size_t query_len, uint8_t* response,
- size_t response_len);
+/// Probes and stores the DoH server with the given configurations.
+/// Use the negative errno-style codes as the return value to represent the result.
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+/// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
+int32_t doh_net_new(DohDispatcher* doh, uint32_t net_id, const char* url, const char* domain,
+ const char* ip_addr, uint32_t sk_mark, const char* cert_path,
+ uint64_t timeout_ms);
+
+/// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
+/// The return code should be either one of the public constant RESULT_* to indicate the error or
+/// the size of the answer.
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+/// `dns_query` must point to a buffer at least `dns_query_len` in size.
+/// `response` must point to a buffer at least `response_len` in size.
+ssize_t doh_query(DohDispatcher* doh, uint32_t net_id, uint8_t* dns_query, size_t dns_query_len,
+ uint8_t* response, size_t response_len, uint64_t timeout_ms);
+
+/// Clears the DoH servers associated with the given |netid|.
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+void doh_net_delete(DohDispatcher* doh, uint32_t net_id);
} // extern "C"
diff --git a/doh.rs b/doh.rs
index bd423e8..6e68b7f 100644
--- a/doh.rs
+++ b/doh.rs
@@ -17,14 +17,19 @@
//! DoH backend for the Android DnsResolver module.
use anyhow::{anyhow, Context, Result};
-use lazy_static::lazy_static;
-use libc::{c_char, size_t, ssize_t};
+use futures::future::join_all;
+use futures::stream::FuturesUnordered;
+use futures::StreamExt;
+use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
use log::{debug, error, info, warn};
use quiche::h3;
use ring::rand::SecureRandom;
use std::collections::HashMap;
+use std::ffi::CString;
use std::net::{IpAddr, SocketAddr};
+use std::ops::Deref;
use std::os::unix::io::{AsRawFd, RawFd};
+use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::{ptr, slice};
@@ -32,21 +37,15 @@
use tokio::runtime::{Builder, Runtime};
use tokio::sync::{mpsc, oneshot};
use tokio::task;
-use tokio::time::Duration;
+use tokio::time::{timeout, Duration, Instant};
use url::Url;
-lazy_static! {
- /// Tokio runtime used to perform doh-handler tasks.
- static ref RUNTIME_STATIC: Arc<Runtime> = Arc::new(
- Builder::new_multi_thread()
- .worker_threads(2)
- .max_blocking_threads(1)
- .enable_all()
- .thread_name("doh-handler")
- .build()
- .expect("Failed to create tokio runtime")
- );
-}
+/// The return code of doh_query means that there is no answer.
+pub const RESULT_INTERNAL_ERROR: ssize_t = -1;
+/// The return code of doh_query means that query can't be sent.
+pub const RESULT_CAN_NOT_SEND: ssize_t = -2;
+/// The return code of doh_query to indicate that the query timed out.
+pub const RESULT_TIMEOUT: ssize_t = -255;
const MAX_BUFFERED_CMD_SIZE: usize = 400;
const MAX_INCOMING_BUFFER_SIZE_WHOLE: u64 = 10000000;
@@ -56,224 +55,638 @@
const DOH_PORT: u16 = 443;
const QUICHE_IDLE_TIMEOUT_MS: u64 = 180000;
const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
+const NS_T_AAAA: u8 = 28;
+const NS_C_IN: u8 = 1;
+// Used to randomly generate query prefix and query id.
+const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
+ abcdefghijklmnopqrstuvwxyz\
+ 0123456789";
type SCID = [u8; quiche::MAX_CONN_ID_LEN];
-type Query = Vec<u8>;
-type Response = Vec<u8>;
-type CmdSender = mpsc::Sender<Command>;
-type CmdReceiver = mpsc::Receiver<Command>;
-type QueryResponder = oneshot::Sender<Option<Response>>;
+type Base64Query = String;
+type CmdSender = mpsc::Sender<DohCommand>;
+type CmdReceiver = mpsc::Receiver<DohCommand>;
+type QueryResponder = oneshot::Sender<Response>;
+type DnsRequest = Vec<quiche::h3::Header>;
+type DnsRequestArg = [quiche::h3::Header];
+type ValidationCallback =
+ extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
#[derive(Debug)]
-enum Command {
- DohQuery { query: Query, resp: QueryResponder },
+enum QueryError {
+ BrokenServer,
+ ConnectionError,
+ ServerNotReady,
+ Unexpected,
+}
+
+#[derive(Eq, PartialEq, Debug, Clone)]
+struct ServerInfo {
+ net_id: u32,
+ url: Url,
+ peer_addr: SocketAddr,
+ domain: Option<String>,
+ sk_mark: u32,
+ cert_path: Option<String>,
+}
+
+#[derive(Debug)]
+enum Response {
+ Error { error: QueryError },
+ Success { answer: Vec<u8> },
+}
+
+#[derive(Debug)]
+enum DohCommand {
+ Probe { info: ServerInfo, timeout: Duration },
+ Query { net_id: u32, base64_query: Base64Query, timeout: Duration, resp: QueryResponder },
+ Clear { net_id: u32 },
+ Exit,
+}
+
+#[derive(Eq, PartialEq, Debug, Clone)]
+enum ConnectionStatus {
+ Idle,
+ Ready,
+ Pending,
+ Fail,
+}
+
+trait OptionDeref<T: Deref> {
+ fn as_deref(&self) -> Option<&T::Target>;
+}
+
+impl<T: Deref> OptionDeref<T> for Option<T> {
+ fn as_deref(&self) -> Option<&T::Target> {
+ self.as_ref().map(Deref::deref)
+ }
}
/// Context for a running DoH engine.
pub struct DohDispatcher {
- /// Used to submit queries to the I/O thread.
- query_sender: CmdSender,
-
+ /// Used to submit cmds to the I/O task.
+ cmd_sender: CmdSender,
join_handle: task::JoinHandle<Result<()>>,
-}
-
-fn make_doh_udp_socket(ip_addr: &str, mark: u32) -> Result<(SocketAddr, std::net::UdpSocket)> {
- let peer_addr = SocketAddr::new(IpAddr::from_str(&ip_addr)?, DOH_PORT);
- let bind_addr = match peer_addr {
- std::net::SocketAddr::V4(_) => "0.0.0.0:0",
- std::net::SocketAddr::V6(_) => "[::]:0",
- };
- let udp_sk = std::net::UdpSocket::bind(bind_addr)?;
- udp_sk.set_nonblocking(true)?;
- mark_socket(udp_sk.as_raw_fd(), mark)?;
- udp_sk.connect(peer_addr)?;
-
- debug!("connecting to {:} from {:}", peer_addr, udp_sk.local_addr()?);
- Ok((peer_addr, udp_sk))
+ runtime: Arc<Runtime>,
}
// DoH dispatcher
impl DohDispatcher {
- fn new(
- url: &str,
- ip_addr: &str,
- mark: u32,
- cert_path: Option<&str>,
- ) -> Result<Box<DohDispatcher>> {
- // Setup socket
- let (peer_addr, udp_sk) = make_doh_udp_socket(&ip_addr, mark)?;
- DohDispatcher::new_with_socket(url, ip_addr, peer_addr, mark, cert_path, udp_sk)
- }
-
- fn new_with_socket(
- url: &str,
- ip_addr: &str,
- peer_addr: SocketAddr,
- mark: u32,
- cert_path: Option<&str>,
- udp_sk: std::net::UdpSocket,
- ) -> Result<Box<DohDispatcher>> {
- let url = Url::parse(&url.to_string())?;
- if url.domain().is_none() {
- return Err(anyhow!("no domain"));
- }
- // Setup quiche config
- let config = create_quiche_config(cert_path)?;
- let h3_config = h3::Config::new()?;
- let mut scid = [0; quiche::MAX_CONN_ID_LEN];
- ring::rand::SystemRandom::new().fill(&mut scid[..]).context("failed to generate scid")?;
-
- let (cmd_sender, cmd_receiver) = mpsc::channel::<Command>(MAX_BUFFERED_CMD_SIZE);
- debug!(
- "Creating a doh handler task: url={}, ip_addr={}, mark={:#x}, scid {:x?}",
- url, ip_addr, mark, &scid
+ fn new(validation_fn: ValidationCallback) -> Result<Box<DohDispatcher>> {
+ let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE);
+ let runtime = Arc::new(
+ Builder::new_multi_thread()
+ .worker_threads(2)
+ .enable_all()
+ .thread_name("doh-handler")
+ .build()
+ .expect("Failed to create tokio runtime"),
);
- let join_handle = RUNTIME_STATIC.spawn(doh_handler(
- url,
- peer_addr,
- udp_sk,
- config,
- h3_config,
- scid,
- cmd_receiver,
- ));
- Ok(Box::new(DohDispatcher { query_sender: cmd_sender, join_handle }))
+ let join_handle = runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn));
+ Ok(Box::new(DohDispatcher { cmd_sender, join_handle, runtime }))
}
- fn query(&self, cmd: Command) -> Result<()> {
- self.query_sender.blocking_send(cmd)?;
+ fn send_cmd(&self, cmd: DohCommand) -> Result<()> {
+ self.cmd_sender.blocking_send(cmd)?;
Ok(())
}
- fn abort_handler(&self) {
- self.join_handle.abort();
+ fn exit_handler(&mut self) {
+ if self.cmd_sender.blocking_send(DohCommand::Exit).is_err() {
+ return;
+ }
+ let _ = self.runtime.block_on(&mut self.join_handle);
}
}
-async fn doh_handler(
- url: url::Url,
- peer_addr: SocketAddr,
- udp_sk: std::net::UdpSocket,
- mut config: quiche::Config,
- h3_config: h3::Config,
+struct DohConnection {
+ net_id: u32,
scid: SCID,
- mut rx: CmdReceiver,
-) -> Result<()> {
- debug!("doh_handler: url={:?}", url);
+ quic_conn: Pin<Box<quiche::Connection>>,
+ udp_sk: UdpSocket,
+ h3_conn: Option<h3::Connection>,
+ status: ConnectionStatus,
+ query_map: HashMap<u64, QueryResponder>,
+ pending_queries: Vec<(DnsRequest, QueryResponder, Option<Instant>)>,
+ cached_session: Option<Vec<u8>>,
+}
- let connid = quiche::ConnectionId::from_ref(&scid);
- let sk = UdpSocket::from_std(udp_sk)?;
- let mut conn = quiche::connect(url.domain(), &connid, peer_addr, &mut config)?;
- let recv_info = quiche::RecvInfo { from: peer_addr };
- let mut quic_conn_start = std::time::Instant::now();
- let mut h3_conn: Option<h3::Connection> = None;
- let mut is_idle = false;
- let mut buf = [0; 65535];
+impl DohConnection {
+ fn new(info: &ServerInfo, config: &mut quiche::Config) -> Result<DohConnection> {
+ let udp_sk_std = make_doh_udp_socket(info.peer_addr, info.sk_mark)?;
+ let udp_sk = UdpSocket::from_std(udp_sk_std)?;
+ let mut scid = [0; quiche::MAX_CONN_ID_LEN];
+ ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?;
+ let connid = quiche::ConnectionId::from_ref(&scid);
+ let quic_conn = quiche::connect(info.domain.as_deref(), &connid, info.peer_addr, config)?;
- let mut query_map = HashMap::<u64, QueryResponder>::new();
- let mut pending_cmds: Vec<Command> = Vec::new();
+ Ok(DohConnection {
+ net_id: info.net_id,
+ scid,
+ quic_conn,
+ udp_sk,
+ h3_conn: None,
+ status: ConnectionStatus::Pending,
+ query_map: HashMap::new(),
+ pending_queries: Vec::new(),
+ cached_session: None,
+ })
+ }
- let mut ts = Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS);
- loop {
- tokio::select! {
- size = sk.recv(&mut buf) => {
- debug!("recv {:?} ", size);
- match size {
- Ok(size) => {
- let processed = match conn.recv(&mut buf[..size], recv_info) {
- Ok(l) => l,
- Err(e) => {
- error!("quic recv failed: {:?}", e);
- continue;
- }
- };
- debug!("processed {} bytes", processed);
- },
- Err(e) => {
- error!("socket recv failed: {:?}", e);
- continue;
- },
- };
- }
- Some(cmd) = rx.recv() => {
- debug!("recv {:?}", cmd);
- pending_cmds.push(cmd);
- }
- _ = tokio::time::sleep(ts) => {
- conn.on_timeout();
- debug!("quic connection timeout");
- }
- }
- if conn.is_closed() {
- // Show connection statistics after it's closed
- if !is_idle {
- info!("connection closed, {:?}, {:?}", quic_conn_start.elapsed(), conn.stats());
- is_idle = true;
- if !conn.is_established() {
- error!("connection handshake timed out after {:?}", quic_conn_start.elapsed());
+ async fn probe(&mut self, req: DnsRequest) -> Result<()> {
+ self.connect().await?;
+ info!("probe start for {}", self.net_id);
+ // Send the probe query.
+ let req_id = self.send_dns_query(&req).await?;
+ loop {
+ self.recv_rx().await?;
+ self.flush_tx().await?;
+ if let Ok((stream_id, _buf)) = self.recv_query() {
+ if stream_id == req_id {
+ // TODO: Verify the answer
+ break;
}
}
+ }
+ Ok(())
+ }
- // If there is any pending query, resume the quic connection.
- if !pending_cmds.is_empty() {
- info!("still some pending queries but connection is not avaiable, resume it");
- conn = quiche::connect(url.domain(), &connid, peer_addr, &mut config)?;
- quic_conn_start = std::time::Instant::now();
- h3_conn = None;
- is_idle = false;
+ async fn connect(&mut self) -> Result<()> {
+ while !self.quic_conn.is_established() {
+ self.flush_tx().await?;
+ self.recv_rx().await?;
+ }
+ self.cached_session = self.quic_conn.session();
+ let h3_config = h3::Config::new()?;
+ self.h3_conn =
+ Some(quiche::h3::Connection::with_transport(&mut self.quic_conn, &h3_config)?);
+ self.status = ConnectionStatus::Ready;
+ info!("connected to Network {}", self.net_id);
+ Ok(())
+ }
+
+ async fn send_dns_query(&mut self, req: &DnsRequestArg) -> Result<u64> {
+ if !self.quic_conn.is_established() {
+ return Err(anyhow!("quic connection is not ready"));
+ }
+ let h3_conn = self.h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
+ let stream_id = h3_conn.send_request(&mut self.quic_conn, &req, false /*fin*/)?;
+ self.flush_tx().await?;
+ Ok(stream_id)
+ }
+
+ async fn try_send_doh_query(
+ &mut self,
+ req: DnsRequest,
+ timeout: Duration,
+ resp: QueryResponder,
+ ) {
+ match self.status {
+ ConnectionStatus::Ready => {
+ // Send an query to probe the server.
+ match self.send_dns_query(&req).await {
+ Ok(req_id) => {
+ self.query_map.insert(req_id, resp);
+ }
+ Err(_) => {
+ resp.send(Response::Error { error: QueryError::ConnectionError }).ok();
+ }
+ }
+ }
+ ConnectionStatus::Pending => {
+ self.pending_queries.push((req, resp, Instant::now().checked_add(timeout)));
+ }
+ // Should not happen
+ _ => {
+ error!("Try to send query but status error {}", self.net_id);
}
}
+ }
- // Create a new HTTP/3 connection once the QUIC connection is established.
- if conn.is_established() && h3_conn.is_none() {
- info!("quic ready, creating h3 conn");
- h3_conn = Some(quiche::h3::Connection::with_transport(&mut conn, &h3_config)?);
+ fn resume_connection(&mut self, quic_conn: Pin<Box<quiche::Connection>>) {
+ self.quic_conn = quic_conn;
+ if let Some(session) = &self.cached_session {
+ if self.quic_conn.set_session(&session).is_err() {
+ warn!("can't restore session for network {}", self.net_id);
+ }
}
- // Try to receive query answers from h3 connection.
- if let Some(h3) = h3_conn.as_mut() {
- recv_query(h3, &mut conn, &mut query_map).await;
+ self.status = ConnectionStatus::Pending;
+ // TODO: Also do a re-probe?
+ }
+
+ async fn process_queries(&mut self) -> Result<()> {
+ if self.status == ConnectionStatus::Pending {
+ self.connect().await?;
}
- // Update the next timeout of quic connection.
- ts = conn.timeout().unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));
- info!("next connection timouts {:?}", ts);
-
- // Process the pending queries
- while !pending_cmds.is_empty() && conn.is_established() {
- if let Some(cmd) = pending_cmds.pop() {
- match cmd {
- Command::DohQuery { query, resp } => {
- match send_dns_query(&query, &url, &mut h3_conn, &mut conn) {
- Ok(stream_id) => {
- query_map.insert(stream_id, resp);
- }
- Err(e) => {
- info!("failed to send query {}", e);
- pending_cmds.push(Command::DohQuery { query, resp });
+ loop {
+ while !self.pending_queries.is_empty() {
+ if let Some((req, resp, exp_time)) = self.pending_queries.pop() {
+ // TODO: check if req is expired.
+ match self.send_dns_query(&req).await {
+ Ok(req_id) => {
+ self.query_map.insert(req_id, resp);
+ }
+ Err(e) => {
+ if let Ok(quiche::h3::Error::StreamBlocked) =
+ e.downcast::<quiche::h3::Error>()
+ {
+ self.pending_queries.push((req, resp, exp_time));
+ break;
+ } else {
+ resp.send(Response::Error { error: QueryError::ConnectionError })
+ .ok();
}
}
}
}
}
+ // TODO: clean up the expired queries.
+ self.recv_rx().await?;
+ self.flush_tx().await?;
+ if let Ok((stream_id, buf)) = self.recv_query() {
+ if let Some(resp) = self.query_map.remove(&stream_id) {
+ resp.send(Response::Success { answer: buf }).unwrap_or_else(|e| {
+ warn!("the receiver dropped {:?}", e);
+ });
+ } else {
+ // Should not happen
+ warn!("No associated receiver found");
+ }
+ }
+ if self.quic_conn.is_closed() || !self.quic_conn.is_established() {
+ self.status = ConnectionStatus::Idle;
+ return Err(anyhow!("connection become idle"));
+ }
}
- flush_tx(&sk, &mut conn).await.unwrap_or_else(|e| {
- error!("flush error {:?} ", e);
- });
+ }
+
+ fn recv_query(&mut self) -> Result<(u64, Vec<u8>)> {
+ let h3_conn = self.h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
+ loop {
+ match h3_conn.poll(&mut self.quic_conn) {
+ // Process HTTP/3 events.
+ Ok((stream_id, quiche::h3::Event::Data)) => {
+ debug!("quiche::h3::Event::Data");
+ let mut buf = vec![0; MAX_DATAGRAM_SIZE];
+ if let Ok(read) = h3_conn.recv_body(&mut self.quic_conn, stream_id, &mut buf) {
+ debug!(
+ "got {} bytes of response data on stream {}: {:x?}",
+ read,
+ stream_id,
+ &buf[..read]
+ );
+ buf.truncate(read);
+ return Ok((stream_id, buf));
+ }
+ }
+ Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
+ debug!(
+ "got response headers {:?} on stream id {} has_body {}",
+ list, stream_id, has_body
+ );
+ }
+ Ok((_stream_id, quiche::h3::Event::Finished)) => {
+ debug!("quiche::h3::Event::Finished");
+ }
+ Ok((_stream_id, quiche::h3::Event::Datagram)) => {
+ debug!("quiche::h3::Event::Datagram");
+ }
+ Ok((_stream_id, quiche::h3::Event::GoAway)) => {
+ debug!("quiche::h3::Event::GoAway");
+ }
+ Err(e) => {
+ return Err(anyhow!(e));
+ }
+ }
+ }
+ }
+
+ async fn recv_rx(&mut self) -> Result<()> {
+ // TODO: Evaluate if we could make the buffer smaller.
+ let mut buf = [0; 65535];
+ let ts = self
+ .quic_conn
+ .timeout()
+ .unwrap_or_else(|| Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS));
+ match timeout(ts, self.udp_sk.recv_from(&mut buf)).await {
+ Ok(v) => {
+ match v {
+ Ok((size, from)) => {
+ let recv_info = quiche::RecvInfo { from };
+ let processed = match self.quic_conn.recv(&mut buf[..size], recv_info) {
+ Ok(l) => l,
+ Err(e) => {
+ return Err(anyhow!("quic recv failed: {:?}", e));
+ }
+ };
+ debug!("processed {} bytes", processed);
+ return Ok(());
+ }
+ Err(e) => {
+ return Err(anyhow!("socket recv failed: {:?}", e));
+ }
+ };
+ }
+ Err(_) => {
+ warn!("timeout did not receive value within {:?} ms, {}", ts, self.net_id);
+ self.quic_conn.on_timeout();
+ return Ok(());
+ }
+ }
+ }
+
+ async fn flush_tx(&mut self) -> Result<()> {
+ let mut out = [0; MAX_DATAGRAM_SIZE];
+ debug!("flush_tx entry ");
+ loop {
+ let (write, _) = match self.quic_conn.send(&mut out) {
+ Ok(v) => v,
+ Err(quiche::Error::Done) => {
+ debug!("done writing");
+ break;
+ }
+ Err(e) => {
+ self.quic_conn.close(false, 0x1, b"fail").ok();
+ return Err(anyhow::Error::new(e));
+ }
+ };
+ self.udp_sk.send(&out[..write]).await?;
+ debug!("written {}", write);
+ }
+ Ok(())
}
}
-fn send_dns_query(
- query: &[u8],
- url: &url::Url,
- h3_conn: &mut Option<quiche::h3::Connection>,
- mut conn: &mut quiche::Connection,
-) -> Result<u64> {
- let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
+fn report_private_dns_validation(
+ info: &ServerInfo,
+ status: &ConnectionStatus,
+ runtime: Arc<Runtime>,
+ validation_fn: ValidationCallback,
+) {
+ let (ip_addr, domain) = match (
+ CString::new(info.peer_addr.ip().to_string()),
+ CString::new(info.domain.clone().unwrap_or_default()),
+ ) {
+ (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
+ _ => {
+ error!("report_private_dns_validation bad input");
+ return;
+ }
+ };
+ let netd_id = info.net_id;
+ let status = status.clone();
+ runtime.spawn_blocking(move || {
+ validation_fn(netd_id, status == ConnectionStatus::Ready, ip_addr.as_ptr(), domain.as_ptr())
+ });
+}
+fn handle_probe_result(
+ result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>),
+ doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
+ runtime: Arc<Runtime>,
+ validation_fn: ValidationCallback,
+) {
+ let (info, doh_conn) = match result {
+ (info, Ok(doh_conn)) => {
+ info!("probing_task success on net_id: {}", info.net_id);
+ (info, doh_conn)
+ }
+ (info, Err((e, mut doh_conn))) => {
+ error!("probe failed on network {}, {:?}", e, info.net_id);
+ doh_conn.status = ConnectionStatus::Fail;
+ (info, doh_conn)
+ // TODO: Retry probe?
+ }
+ };
+ // If the network is removed or the server is replaced before probing,
+ // ignore the probe result.
+ match doh_conn_map.get(&info.net_id) {
+ Some((server_info, _)) => {
+ if *server_info != info {
+ warn!(
+ "The previous configuration for network {} was replaced before probe finished",
+ info.net_id
+ );
+ return;
+ }
+ }
+ _ => {
+ warn!("network {} was removed before probe finished", info.net_id);
+ return;
+ }
+ }
+ report_private_dns_validation(&info, &doh_conn.status, runtime, validation_fn);
+ doh_conn_map.insert(info.net_id, (info, Some(doh_conn)));
+}
+
+async fn probe_task(
+ info: ServerInfo,
+ mut doh: DohConnection,
+ t: Duration,
+) -> (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>) {
+ let req = match make_probe_query() {
+ Ok(q) => match make_dns_request(&q, &info.url) {
+ Ok(req) => req,
+ Err(e) => return (info, Err((anyhow!(e), doh))),
+ },
+ Err(e) => return (info, Err((anyhow!(e), doh))),
+ };
+ match timeout(t, doh.probe(req)).await {
+ Ok(v) => match v {
+ Ok(_) => (info, Ok(doh)),
+ Err(e) => (info, Err((e, doh))),
+ },
+ Err(e) => (info, Err((anyhow!(e), doh))),
+ }
+}
+
+fn make_connection_if_needed(
+ info: &ServerInfo,
+ doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
+ config_cache: &mut QuicheConfigCache,
+) -> Result<Option<DohConnection>> {
+ // Check if connection exists.
+ match doh_conn_map.get(&info.net_id) {
+ // The connection exists but has failed. Re-probe.
+ Some((server_info, Some(doh)))
+ if *server_info == *info && doh.status == ConnectionStatus::Fail =>
+ {
+ let (_, doh) = doh_conn_map
+ .insert(info.net_id, (info.clone(), None))
+ .ok_or_else(|| anyhow!("unexpected error, missing connection"))?;
+ return Ok(doh);
+ }
+ // The connection exists or the connection is under probing, ignore.
+ Some((server_info, _)) if *server_info == *info => return Ok(None),
+ // TODO: change the inner connection instead of removing?
+ _ => doh_conn_map.remove(&info.net_id),
+ };
+ match &info.cert_path {
+ // The cert path is not either empty or SYSTEM_CERT_PATH, which means it's used by tests,
+ // it's not necessary to cache the config.
+ Some(cert_path) if cert_path != SYSTEM_CERT_PATH => {
+ let mut config = create_quiche_config(Some(&cert_path))?;
+ let doh = DohConnection::new(&info, &mut config)?;
+ doh_conn_map.insert(info.net_id, (info.clone(), None));
+ Ok(Some(doh))
+ }
+ // The normal cases, get the config from config cache.
+ cert_path => {
+ let config =
+ config_cache.get(&cert_path)?.ok_or_else(|| anyhow!("no quiche config"))?;
+ let doh = DohConnection::new(&info, config)?;
+ doh_conn_map.insert(info.net_id, (info.clone(), None));
+ Ok(Some(doh))
+ }
+ }
+}
+
+struct QuicheConfigCache {
+ cert_path: Option<String>,
+ config: Option<quiche::Config>,
+}
+
+impl QuicheConfigCache {
+ fn get(&mut self, cert_path: &Option<String>) -> Result<Option<&mut quiche::Config>> {
+ if !cert_path.as_ref().map_or(true, |path| path == SYSTEM_CERT_PATH) {
+ return Err(anyhow!("Custom cert_path is not allowed for config cache"));
+ }
+ // No config is cached or the cached config isn't matched with the input cert_path
+ // Create it with the input cert_path.
+ if self.config.is_none() || self.cert_path != *cert_path {
+ self.config = Some(create_quiche_config(cert_path.as_deref())?);
+ self.cert_path = cert_path.clone();
+ }
+ return Ok(self.config.as_mut());
+ }
+}
+
+fn resume_connection(
+ info: &ServerInfo,
+ quic_conn: &mut DohConnection,
+ config_cache: &mut QuicheConfigCache,
+) -> Result<()> {
+ let mut c = config_cache.get(&info.cert_path)?.ok_or_else(|| anyhow!("no quiche config"))?;
+ let connid = quiche::ConnectionId::from_ref(&quic_conn.scid);
+ let new_quic_conn = quiche::connect(info.domain.as_deref(), &connid, info.peer_addr, &mut c)?;
+ quic_conn.resume_connection(new_quic_conn);
+ Ok(())
+}
+
+async fn handle_query_cmd(
+ net_id: u32,
+ base64_query: Base64Query,
+ timeout: Duration,
+ resp: QueryResponder,
+ doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
+ config_cache: &mut QuicheConfigCache,
+) {
+ if let Some((info, quic_conn)) = doh_conn_map.get_mut(&net_id) {
+ match (&info.domain, quic_conn) {
+ // Connection is not ready, strict mode
+ (Some(_), None) => {
+ let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
+ }
+ // Connection is not ready, Opportunistic mode
+ (None, None) => {
+ let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
+ }
+ // Connection is ready
+ (_, Some(quic_conn)) => {
+ if quic_conn.status == ConnectionStatus::Fail {
+ let _ = resp.send(Response::Error { error: QueryError::BrokenServer });
+ return;
+ } else if quic_conn.status == ConnectionStatus::Idle {
+ if let Err(e) = resume_connection(info, quic_conn, config_cache) {
+ error!("resume_connection failed {:?}", e);
+ let _ = resp.send(Response::Error { error: QueryError::BrokenServer });
+ return;
+ }
+ }
+ if let Ok(req) = make_dns_request(&base64_query, &info.url) {
+ debug!("Try to send query");
+ quic_conn.try_send_doh_query(req, timeout, resp).await;
+ } else {
+ let _ = resp.send(Response::Error { error: QueryError::Unexpected });
+ }
+ }
+ }
+ } else {
+ error!("No connection is associated with the given net id {}", net_id);
+ let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
+ }
+}
+
+async fn doh_handler(
+ mut cmd_rx: CmdReceiver,
+ runtime: Arc<Runtime>,
+ validation_fn: ValidationCallback,
+) -> Result<()> {
+ info!("doh_dispatcher entry");
+ let mut config_cache: QuicheConfigCache = QuicheConfigCache { cert_path: None, config: None };
+
+ // Currently, only support 1 server per network.
+ let mut doh_conn_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new();
+ let mut probe_futures = FuturesUnordered::new();
+ loop {
+ tokio::select! {
+ _ = async {
+ let mut futures = vec![];
+ for (_, doh_conn) in doh_conn_map.values_mut() {
+ if let Some(doh_conn) = doh_conn {
+ if doh_conn.status != ConnectionStatus::Fail {
+ futures.push(doh_conn.process_queries());
+ }
+ }
+ }
+ join_all(futures).await
+ } , if !doh_conn_map.is_empty() => {},
+ Some(result) = probe_futures.next() => {
+ let runtime_clone = runtime.clone();
+ handle_probe_result(result, &mut doh_conn_map, runtime_clone, validation_fn);
+ info!("probe_futures remaining size: {}", probe_futures.len());
+ },
+ Some(cmd) = cmd_rx.recv() => {
+ info!("recv {:?}", cmd);
+ match cmd {
+ DohCommand::Probe { info, timeout: t } => {
+ match make_connection_if_needed(&info, &mut doh_conn_map, &mut config_cache) {
+ Ok(Some(doh)) => {
+ // Create a new async task associated to the DoH connection.
+ probe_futures.push(probe_task(info, doh, t));
+ debug!("probe_map size: {}", probe_futures.len());
+ }
+ Ok(None) => {
+ // No further probe is needed.
+ warn!("connection for network {} already exists", info.net_id);
+ // TODO: Report the status again?
+ }
+ Err(e) => {
+ error!("create connection for network {} error {:?}", info.net_id, e);
+ report_private_dns_validation(&info, &ConnectionStatus::Fail, runtime.clone(), validation_fn);
+ }
+ }
+ },
+ DohCommand::Query { net_id, base64_query, timeout, resp } => {
+ handle_query_cmd(net_id, base64_query, timeout, resp, &mut doh_conn_map, &mut config_cache).await;
+ },
+ DohCommand::Clear { net_id } => {
+ doh_conn_map.remove(&net_id);
+ info!("Doh Clear server for netid: {}", net_id);
+ },
+ DohCommand::Exit => return Ok(()),
+ }
+ }
+ }
+ }
+}
+
+fn make_dns_request(base64_query: &str, url: &url::Url) -> Result<DnsRequest> {
let mut path = String::from(url.path());
path.push_str("?dns=");
- path.push_str(std::str::from_utf8(&query)?);
- let _req = vec![
+ path.push_str(&base64_query);
+ let req = vec![
quiche::h3::Header::new(b":method", b"GET"),
quiche::h3::Header::new(b":scheme", b"https"),
quiche::h3::Header::new(
@@ -286,86 +699,36 @@
// TODO: is content-length required?
];
- Ok(h3_conn.send_request(&mut conn, &_req, false /*fin*/)?)
+ Ok(req)
}
-async fn recv_query(
- h3_conn: &mut h3::Connection,
- mut conn: &mut quiche::Connection,
- map: &mut HashMap<u64, QueryResponder>,
-) {
- // Process HTTP/3 events.
- let mut buf = [0; MAX_DATAGRAM_SIZE];
- loop {
- match h3_conn.poll(&mut conn) {
- Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
- info!(
- "got response headers {:?} on stream id {} has_body {}",
- list, stream_id, has_body
- );
- }
- Ok((stream_id, quiche::h3::Event::Data)) => {
- debug!("quiche::h3::Event::Data");
- if let Ok(read) = h3_conn.recv_body(&mut conn, stream_id, &mut buf) {
- info!(
- "got {} bytes of response data on stream {}: {:x?}",
- read,
- stream_id,
- &buf[..read]
- );
- if let Some(resp) = map.remove(&stream_id) {
- resp.send(Some(buf[..read].to_vec())).unwrap_or_else(|e| {
- warn!("the receiver dropped {:?}", e);
- });
- }
- }
- }
- Ok((_stream_id, quiche::h3::Event::Finished)) => {
- debug!("quiche::h3::Event::Finished");
- }
- Ok((_stream_id, quiche::h3::Event::Datagram)) => {
- debug!("quiche::h3::Event::Datagram");
- }
- Ok((_stream_id, quiche::h3::Event::GoAway)) => {
- debug!("quiche::h3::Event::GoAway");
- }
- Err(quiche::h3::Error::Done) => {
- debug!("quiche::h3::Error::Done");
- break;
- }
- Err(e) => {
- error!("HTTP/3 processing failed: {:?}", e);
- break;
- }
- }
+fn make_doh_udp_socket(peer_addr: SocketAddr, mark: u32) -> Result<std::net::UdpSocket> {
+ let bind_addr = match peer_addr {
+ std::net::SocketAddr::V4(_) => "0.0.0.0:0",
+ std::net::SocketAddr::V6(_) => "[::]:0",
+ };
+ let udp_sk = std::net::UdpSocket::bind(bind_addr)?;
+ udp_sk.set_nonblocking(true)?;
+ if mark_socket(udp_sk.as_raw_fd(), mark).is_err() {
+ warn!("Mark socket failed, is it a test?");
}
-}
+ udp_sk.connect(peer_addr)?;
-async fn flush_tx(sk: &UdpSocket, conn: &mut quiche::Connection) -> Result<()> {
- let mut out = [0; MAX_DATAGRAM_SIZE];
- loop {
- let (write, _) = match conn.send(&mut out) {
- Ok(v) => v,
- Err(quiche::Error::Done) => {
- debug!("done writing");
- break;
- }
- Err(e) => {
- conn.close(false, 0x1, b"fail").ok();
- return Err(anyhow::Error::new(e));
- }
- };
- sk.send(&out[..write]).await?;
- debug!("written {}", write);
- }
- Ok(())
+ info!("connecting to {:} from {:}", peer_addr, udp_sk.local_addr()?);
+ Ok(udp_sk)
}
fn create_quiche_config(cert_path: Option<&str>) -> Result<quiche::Config> {
let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
config.set_application_protos(h3::APPLICATION_PROTOCOL)?;
- config.verify_peer(true);
- config.load_verify_locations_from_directory(cert_path.unwrap_or(SYSTEM_CERT_PATH))?;
+ match cert_path {
+ Some(path) => {
+ config.verify_peer(true);
+ config.load_verify_locations_from_directory(path)?;
+ }
+ None => config.verify_peer(false),
+ }
+
// Some of these configs are necessary, or the server can't respond the HTTP/3 request.
config.set_max_idle_timeout(QUICHE_IDLE_TIMEOUT_MS);
config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE);
@@ -398,208 +761,191 @@
}
}
-/// Performs static initialization fo the DoH engine.
-#[no_mangle]
-pub extern "C" fn doh_init() -> *const c_char {
- android_logger::init_once(android_logger::Config::default().with_min_level(log::Level::Trace));
- static VERSION: &str = "1.0\0";
- VERSION.as_ptr() as *const c_char
+#[rustfmt::skip]
+fn make_probe_query() -> Result<String> {
+ let mut rnd = [0; 8];
+ ring::rand::SystemRandom::new().fill(&mut rnd).context("failed to generate probe rnd")?;
+ let c = |byte| CHARSET[(byte as usize) % CHARSET.len()];
+ let query = vec![
+ rnd[6], rnd[7], // [0-1] query ID
+ 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
+ 0, 1, // [4-5] QDCOUNT (number of queries)
+ 0, 0, // [6-7] ANCOUNT (number of answers)
+ 0, 0, // [8-9] NSCOUNT (number of name server records)
+ 0, 0, // [10-11] ARCOUNT (number of additional records)
+ 19, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), b'-', b'd', b'n',
+ b's', b'o', b'h', b't', b't', b'p', b's', b'-', b'd', b's',
+ 6, b'm', b'e', b't', b'r', b'i', b'c', 7, b'g', b's',
+ b't', b'a', b't', b'i', b'c', 3, b'c', b'o', b'm',
+ 0, // null terminator of FQDN (root TLD)
+ 0, NS_T_AAAA, // QTYPE
+ 0, NS_C_IN // QCLASS
+ ];
+ Ok(base64::encode_config(query, base64::URL_SAFE_NO_PAD))
}
+/// Performs static initialization for the DoH engine.
/// Creates and returns a DoH engine instance.
-/// The returned object must be freed with doh_delete().
-/// # Safety
-/// All the pointer args are null terminated strings.
#[no_mangle]
-pub unsafe extern "C" fn doh_new(
- url: *const c_char,
- ip_addr: *const c_char,
- mark: libc::uint32_t,
- cert_path: *const c_char,
-) -> *mut DohDispatcher {
- let (url, ip_addr, cert_path) = match (
- std::ffi::CStr::from_ptr(url).to_str(),
- std::ffi::CStr::from_ptr(ip_addr).to_str(),
- std::ffi::CStr::from_ptr(cert_path).to_str(),
- ) {
- (Ok(url), Ok(ip_addr), Ok(cert_path)) => {
- if !cert_path.is_empty() {
- (url, ip_addr, Some(cert_path))
- } else {
- (url, ip_addr, None)
- }
- }
- _ => {
- error!("bad input");
- return ptr::null_mut();
- }
- };
- match DohDispatcher::new(url, ip_addr, mark, cert_path) {
+pub extern "C" fn doh_dispatcher_new(ptr: ValidationCallback) -> *mut DohDispatcher {
+ android_logger::init_once(android_logger::Config::default().with_min_level(log::Level::Info));
+ match DohDispatcher::new(ptr) {
Ok(c) => Box::into_raw(c),
Err(e) => {
- error!("doh_new: failed: {:?}", e);
+ error!("doh_dispatcher_new: failed: {:?}", e);
ptr::null_mut()
}
}
}
-/// Deletes a DoH engine created by doh_new().
+/// Deletes a DoH engine created by doh_dispatcher_new().
/// # Safety
-/// `doh` must be a non-null pointer previously created by `doh_new()`
-/// and not yet deleted by `doh_delete()`.
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
#[no_mangle]
-pub unsafe extern "C" fn doh_delete(doh: *mut DohDispatcher) {
- Box::from_raw(doh).abort_handler()
+pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
+ Box::from_raw(doh).exit_handler()
}
-/// Sends a DNS query and waits for the response.
+/// Probes and stores the DoH server with the given configurations.
+/// Use the negative errno-style codes as the return value to represent the result.
/// # Safety
-/// `doh` must be a non-null pointer previously created by `doh_new()`
-/// and not yet deleted by `doh_delete()`.
-/// `query` must point to a buffer at least `query_len` in size.
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+/// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
+#[no_mangle]
+pub unsafe extern "C" fn doh_net_new(
+ doh: &mut DohDispatcher,
+ net_id: uint32_t,
+ url: *const c_char,
+ domain: *const c_char,
+ ip_addr: *const c_char,
+ sk_mark: libc::uint32_t,
+ cert_path: *const c_char,
+ timeout_ms: libc::uint64_t,
+) -> int32_t {
+ let (url, domain, ip_addr, cert_path) = match (
+ std::ffi::CStr::from_ptr(url).to_str(),
+ std::ffi::CStr::from_ptr(domain).to_str(),
+ std::ffi::CStr::from_ptr(ip_addr).to_str(),
+ std::ffi::CStr::from_ptr(cert_path).to_str(),
+ ) {
+ (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
+ if domain.is_empty() {
+ (url, None, ip_addr.to_string(), None)
+ } else if !cert_path.is_empty() {
+ (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
+ } else {
+ (
+ url,
+ Some(domain.to_string()),
+ ip_addr.to_string(),
+ Some(SYSTEM_CERT_PATH.to_string()),
+ )
+ }
+ }
+ _ => {
+ error!("bad input"); // Should not happen
+ return -libc::EINVAL;
+ }
+ };
+
+ let (url, ip_addr) = match (url::Url::parse(url), IpAddr::from_str(&ip_addr)) {
+ (Ok(url), Ok(ip_addr)) => (url, ip_addr),
+ _ => {
+ error!("bad ip or url"); // Should not happen
+ return -libc::EINVAL;
+ }
+ };
+ let cmd = DohCommand::Probe {
+ info: ServerInfo {
+ net_id,
+ url,
+ peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
+ domain,
+ sk_mark,
+ cert_path,
+ },
+ timeout: Duration::from_millis(timeout_ms),
+ };
+ if let Err(e) = doh.send_cmd(cmd) {
+ error!("Failed to send the probe: {:?}", e);
+ return -libc::EPIPE;
+ }
+ 0
+}
+
+/// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
+/// The return code should be either one of the public constant RESULT_* to indicate the error or
+/// the size of the answer.
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+/// `dns_query` must point to a buffer at least `dns_query_len` in size.
/// `response` must point to a buffer at least `response_len` in size.
#[no_mangle]
pub unsafe extern "C" fn doh_query(
doh: &mut DohDispatcher,
- query: *mut u8,
- query_len: size_t,
+ net_id: uint32_t,
+ dns_query: *mut u8,
+ dns_query_len: size_t,
response: *mut u8,
response_len: size_t,
+ timeout_ms: uint64_t,
) -> ssize_t {
- let q = slice::from_raw_parts_mut(query, query_len);
+ let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
+
+ let t = Duration::from_millis(timeout_ms);
let (resp_tx, resp_rx) = oneshot::channel();
- let cmd = Command::DohQuery { query: q.to_vec(), resp: resp_tx };
- if let Err(e) = doh.query(cmd) {
+ let cmd = DohCommand::Query {
+ net_id,
+ base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD),
+ timeout: t,
+ resp: resp_tx,
+ };
+
+ if let Err(e) = doh.send_cmd(cmd) {
error!("Failed to send the query: {:?}", e);
- return -1;
+ return RESULT_CAN_NOT_SEND;
}
- match RUNTIME_STATIC.block_on(resp_rx) {
- Ok(value) => {
- if let Some(resp) = value {
- if resp.len() > response_len || resp.len() > isize::MAX as usize {
- return -1;
+ if let Ok(rt) = Runtime::new() {
+ let local = task::LocalSet::new();
+ match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
+ Ok(v) => match v {
+ Ok(v) => match v {
+ Response::Success { answer } => {
+ if answer.len() > response_len || answer.len() > isize::MAX as usize {
+ return RESULT_INTERNAL_ERROR;
+ }
+ let response = slice::from_raw_parts_mut(response, answer.len());
+ response.copy_from_slice(&answer);
+ answer.len() as ssize_t
+ }
+ Response::Error { error: QueryError::ServerNotReady } => RESULT_CAN_NOT_SEND,
+ _ => RESULT_INTERNAL_ERROR,
+ },
+ Err(e) => {
+ error!("no result {}", e);
+ RESULT_INTERNAL_ERROR
}
- let response = slice::from_raw_parts_mut(response, resp.len());
- response.copy_from_slice(&resp);
- return resp.len() as ssize_t;
+ },
+ Err(e) => {
+ error!("timeout: {}", e);
+ RESULT_TIMEOUT
}
- -1
}
- Err(e) => {
- error!("no result {}", e);
- -1
- }
+ } else {
+ RESULT_INTERNAL_ERROR
}
}
-#[cfg(test)]
-mod tests {
- use super::*;
- use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
-
- const TEST_MARK: u32 = 0xD0033;
- const LOOPBACK_ADDR: &str = "127.0.0.1";
-
- #[test]
- fn dohdispatcher_invalid_args() {
- let test_args = [
- // Bad url
- ("foo", "bar"),
- ("https://1", "bar"),
- ("https:/", "bar"),
- // Bad ip
- ("https://dns.google", "bar"),
- ("https://dns.google", "256.256.256.256"),
- ];
- for args in &test_args {
- assert!(
- DohDispatcher::new(args.0, args.1, 0, None).is_err(),
- "doh dispatcher should not be created"
- )
- }
- }
-
- #[test]
- fn make_doh_udp_socket() {
- // Bad ip
- for ip in &["foo", "1", "333.333.333.333"] {
- assert!(super::make_doh_udp_socket(ip, 0).is_err(), "udp socket should not be created");
- }
- // Make a socket connecting to loopback with a test mark.
- let (_, sk) = super::make_doh_udp_socket(LOOPBACK_ADDR, TEST_MARK).unwrap();
- // Check if the socket is connected to loopback.
- assert_eq!(
- sk.peer_addr().unwrap(),
- SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT))
- );
-
- // Check if the socket mark is correct.
- let fd: RawFd = sk.as_raw_fd();
-
- let mut mark: u32 = 50;
- let mut size = std::mem::size_of::<u32>() as libc::socklen_t;
- unsafe {
- // Safety: fd must be valid.
- assert_eq!(
- libc::getsockopt(
- fd,
- libc::SOL_SOCKET,
- libc::SO_MARK,
- &mut mark as *mut _ as *mut libc::c_void,
- &mut size as *mut _ as *mut libc::socklen_t,
- ),
- 0
- );
- }
- assert_eq!(mark, TEST_MARK);
-
- // Check if the socket is non-blocking.
- unsafe {
- // Safety: fd must be valid.
- assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK);
- }
- }
-
- #[test]
- fn create_quiche_config() {
- assert!(
- super::create_quiche_config(None).is_ok(),
- "quiche config without cert creating failed"
- );
- assert!(
- super::create_quiche_config(Some("data/local/tmp/")).is_ok(),
- "quiche config with cert creating failed"
- );
- }
-
- const GOOGLE_DNS_URL: &str = "https://dns.google/dns-query";
- const GOOGLE_DNS_IP: &str = "8.8.8.8";
- // qtype: A, qname: www.example.com
- const SAMPLE_QUERY: &str = "q80BAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB";
- #[test]
- fn close_doh() {
- let (peer_socket, udp_sk) = super::make_doh_udp_socket(LOOPBACK_ADDR, TEST_MARK).unwrap();
- let doh = DohDispatcher::new_with_socket(
- GOOGLE_DNS_URL,
- GOOGLE_DNS_IP,
- peer_socket,
- 0,
- None,
- udp_sk,
- )
- .unwrap();
- let (resp_tx, resp_rx) = oneshot::channel();
- let cmd = Command::DohQuery { query: SAMPLE_QUERY.as_bytes().to_vec(), resp: resp_tx };
- assert!(doh.query(cmd).is_ok(), "Send query failed");
- doh.abort_handler();
- assert!(RUNTIME_STATIC.block_on(resp_rx).is_err(), "channel should already be closed");
- }
-
- #[test]
- fn doh_init() {
- unsafe {
- // Safety: the returned pointer of doh_init() must be a null terminated string.
- assert_eq!(std::ffi::CStr::from_ptr(super::doh_init()).to_str().unwrap(), "1.0");
- }
+/// Clears the DoH servers associated with the given |netid|.
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+#[no_mangle]
+pub extern "C" fn doh_net_delete(doh: &mut DohDispatcher, net_id: uint32_t) {
+ if let Err(e) = doh.send_cmd(DohCommand::Clear { net_id }) {
+ error!("Failed to send the query: {:?}", e);
}
}
diff --git a/tests/Android.bp b/tests/Android.bp
index b917388..e83a4a7 100644
--- a/tests/Android.bp
+++ b/tests/Android.bp
@@ -304,5 +304,8 @@
"libring-core",
"libssl",
],
+ shared_libs: [
+ "libnetd_client",
+ ],
min_sdk_version: "29",
}
diff --git a/tests/doh_ffi_test.cpp b/tests/doh_ffi_test.cpp
index 5004156..6e80ec2 100644
--- a/tests/doh_ffi_test.cpp
+++ b/tests/doh_ffi_test.cpp
@@ -16,18 +16,60 @@
#include "doh.h"
+#include <chrono>
+#include <condition_variable>
+#include <mutex>
+
+#include <resolv.h>
+
+#include <NetdClient.h>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
+static const char* GOOGLE_SERVER_IP = "8.8.8.8";
+static const int TIMEOUT_MS = 3000;
+constexpr int MAXPACKET = (8 * 1024);
+constexpr unsigned int MINIMAL_NET_ID = 100;
+
+std::mutex m;
+std::condition_variable cv;
+unsigned int dnsNetId;
+
TEST(DoHFFITest, SmokeTest) {
- EXPECT_STREQ(doh_init(), "1.0");
- DohServer* doh = doh_new("https://dns.google/dns-query", "8.8.8.8", 0, "");
+ getNetworkForDns(&dnsNetId);
+ // To ensure that we have a real network.
+ ASSERT_GE(dnsNetId, MINIMAL_NET_ID) << "No available networks";
+
+ auto callback = [](uint32_t netId, bool success, const char* ip_addr, const char* host) {
+ EXPECT_EQ(netId, dnsNetId);
+ EXPECT_TRUE(success);
+ EXPECT_STREQ(ip_addr, GOOGLE_SERVER_IP);
+ EXPECT_STREQ(host, "");
+ cv.notify_one();
+ };
+ DohDispatcher* doh = doh_dispatcher_new(callback);
EXPECT_TRUE(doh != nullptr);
- // www.example.com
- uint8_t query[] = "q80BAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB";
+ // TODO: Use a local server instead of dns.google.
+ // sk_mark doesn't matter here because this test doesn't have permission to set sk_mark.
+ // The DNS packet would be sent via default network.
+ EXPECT_EQ(doh_net_new(doh, dnsNetId, "https://dns.google/dns-query", /* domain */ "",
+ GOOGLE_SERVER_IP,
+ /* sk_mark */ 0, /* cert_path */ "", TIMEOUT_MS),
+ 0);
+ {
+ std::unique_lock<std::mutex> lk(m);
+ EXPECT_EQ(cv.wait_for(lk, std::chrono::milliseconds(TIMEOUT_MS)),
+ std::cv_status::no_timeout);
+ }
+
+ std::vector<uint8_t> buf(MAXPACKET, 0);
+ ssize_t len = res_mkquery(ns_o_query, "www.example.com", ns_c_in, ns_t_aaaa, nullptr, 0,
+ nullptr, buf.data(), MAXPACKET);
uint8_t answer[8192];
- ssize_t len = doh_query(doh, query, sizeof query, answer, sizeof answer);
+
+ len = doh_query(doh, dnsNetId, buf.data(), len, answer, sizeof answer, TIMEOUT_MS);
EXPECT_GT(len, 0);
- doh_delete(doh);
+ doh_net_delete(doh, dnsNetId);
+ doh_dispatcher_delete(doh);
}