shill: Add http_proxy class
The http_proxy adds a device/connection based proxy that guarantees
to the caller that its HTTP request will go out a particular device's
connection. DNS requests occur through a bound socket to this device
and goes to DNS servers configured on this connection. HTTP requests
will also be bound to this interface. This facility will be used by
a number of peripheral bits including portal detection, activation and
cashew.
BUG=chromium-os:21664
TEST=New unit test. New (disabled) functional test, against which I
can run "curl -x" and Chrome with manual proxy settings.
Change-Id: I0d59bf0ae27d3538ef359f786742f5c2f1d6fef9
Reviewed-on: https://gerrit.chromium.org/gerrit/10165
Reviewed-by: Thieu Le <thieule@chromium.org>
Tested-by: Paul Stewart <pstew@chromium.org>
Commit-Ready: Paul Stewart <pstew@chromium.org>
diff --git a/Makefile b/Makefile
index 06c96dc..82331f9 100644
--- a/Makefile
+++ b/Makefile
@@ -71,6 +71,7 @@
DBUS_BINDINGS = $(DBUS_ADAPTOR_BINDINGS) $(DBUS_PROXY_BINDINGS)
SHILL_OBJS = \
+ async_connection.o \
byte_string.o \
cellular.o \
cellular_capability.o \
@@ -102,6 +103,7 @@
glib.o \
glib_io_ready_handler.o \
glib_io_input_handler.o \
+ http_proxy.o \
ip_address.o \
ipconfig.o \
ipconfig_dbus_adaptor.o \
@@ -151,6 +153,7 @@
TEST_BIN = shill_unittest
TEST_OBJS = \
+ async_connection_unittest.o \
byte_string_unittest.o \
cellular_capability_cdma_unittest.o \
cellular_capability_gsm_unittest.o \
@@ -169,12 +172,14 @@
dhcp_provider_unittest.o \
dns_client_unittest.o \
error_unittest.o \
+ http_proxy_unittest.o \
ip_address_unittest.o \
ipconfig_unittest.o \
key_file_store_unittest.o \
manager_unittest.o \
mock_adaptors.o \
mock_ares.o \
+ mock_async_connection.o \
mock_control.o \
mock_dbus_properties_proxy.o \
mock_device.o \
@@ -182,6 +187,7 @@
mock_dhcp_config.o \
mock_dhcp_provider.o \
mock_dhcp_proxy.o \
+ mock_dns_client.o \
mock_event_dispatcher.o \
mock_glib.o \
mock_ipconfig.o \
diff --git a/async_connection.cc b/async_connection.cc
new file mode 100644
index 0000000..63530ec
--- /dev/null
+++ b/async_connection.cc
@@ -0,0 +1,113 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/async_connection.h"
+
+#include <errno.h>
+#include <netinet/in.h>
+
+#include <string>
+
+#include "shill/event_dispatcher.h"
+#include "shill/ip_address.h"
+#include "shill/sockets.h"
+
+using std::string;
+
+namespace shill {
+
+AsyncConnection::AsyncConnection(const string &interface_name,
+ EventDispatcher *dispatcher,
+ Sockets *sockets,
+ Callback2<bool, int>::Type *callback)
+ : interface_name_(interface_name),
+ dispatcher_(dispatcher),
+ sockets_(sockets),
+ callback_(callback),
+ fd_(-1),
+ connect_completion_callback_(
+ NewCallback(this, &AsyncConnection::OnConnectCompletion)) { }
+
+AsyncConnection::~AsyncConnection() {
+ Stop();
+}
+
+bool AsyncConnection::Start(const IPAddress &address, int port) {
+ DCHECK(fd_ < 0);
+
+ fd_ = sockets_->Socket(PF_INET, SOCK_STREAM, 0);
+ if (fd_ < 0 ||
+ sockets_->SetNonBlocking(fd_) < 0) {
+ error_ = sockets_->ErrorString();
+ PLOG(ERROR) << "Async socket setup failed";
+ Stop();
+ return false;
+ }
+
+ if (!interface_name_.empty() &&
+ sockets_->BindToDevice(fd_, interface_name_) < 0) {
+ error_ = sockets_->ErrorString();
+ PLOG(ERROR) << "Async socket failed to bind to device";
+ Stop();
+ return false;
+ }
+
+ struct sockaddr_in iaddr;
+ CHECK_EQ(sizeof(iaddr.sin_addr.s_addr), address.GetLength());
+
+ memset(&iaddr, 0, sizeof(iaddr));
+ iaddr.sin_family = AF_INET;
+ memcpy(&iaddr.sin_addr.s_addr, address.address().GetConstData(),
+ sizeof(iaddr.sin_addr.s_addr));
+ iaddr.sin_port = htons(port);
+
+ socklen_t addrlen = sizeof(iaddr);
+ int ret = sockets_->Connect(fd_,
+ reinterpret_cast<struct sockaddr *>(&iaddr),
+ addrlen);
+ if (ret == 0) {
+ callback_->Run(true, fd_); // Passes ownership
+ fd_ = -1;
+ return true;
+ }
+
+ if (sockets_->Error() != EINPROGRESS) {
+ error_ = sockets_->ErrorString();
+ PLOG(ERROR) << "Async socket connection failed";
+ Stop();
+ return false;
+ }
+
+ connect_completion_handler_.reset(
+ dispatcher_->CreateReadyHandler(fd_,
+ IOHandler::kModeOutput,
+ connect_completion_callback_.get()));
+ error_ = string();
+
+ return true;
+}
+
+void AsyncConnection::Stop() {
+ connect_completion_handler_.reset();
+ if (fd_ >= 0) {
+ sockets_->Close(fd_);
+ fd_ = -1;
+ }
+}
+
+void AsyncConnection::OnConnectCompletion(int fd) {
+ CHECK_EQ(fd_, fd);
+
+ if (sockets_->GetSocketError(fd_) != 0) {
+ error_ = sockets_->ErrorString();
+ PLOG(ERROR) << "Async GetSocketError returns failure";
+ callback_->Run(false, -1);
+ } else {
+ callback_->Run(true, fd_); // Passes ownership
+ fd_ = -1;
+ }
+ Stop();
+}
+
+} // namespace shill
diff --git a/async_connection.h b/async_connection.h
new file mode 100644
index 0000000..7e0675d
--- /dev/null
+++ b/async_connection.h
@@ -0,0 +1,79 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_ASYNC_CONNECTION_
+#define SHILL_ASYNC_CONNECTION_
+
+#include <string>
+
+#include <base/callback_old.h>
+#include <base/memory/scoped_ptr.h>
+
+#include "shill/refptr_types.h"
+
+namespace shill {
+
+class EventDispatcher;
+class IOHandler;
+class IPAddress;
+class Sockets;
+
+// The AsyncConnection class implements an asynchronous
+// outgoing TCP connection. When passed an IPAddress and
+// port, and it will notify the caller when the connection
+// is made. It can also be passed an interface name to
+// bind the local side of the connection.
+class AsyncConnection {
+ public:
+ // If non-empty |interface_name| specifies an local interface from which
+ // to originate the connection.
+ AsyncConnection(const std::string &interface_name,
+ EventDispatcher *dispatcher,
+ Sockets *sockets,
+ Callback2<bool, int>::Type *callback);
+ virtual ~AsyncConnection();
+
+ // Open a connection given an IP address and port (in host order).
+ // When the connection completes, |callback| will be called with the
+ // a boolean (indicating success if true) and an fd of the opened socket
+ // (in the success case). If successful, ownership of this open fd is
+ // passed to the caller on execution of the callback.
+ //
+ // This function (Start) returns true if the connection is in progress,
+ // or if the connection has immediately succeeded (the callback will be
+ // called in this case). On success the callback may be called before
+ // Start() returns to its caller. On failure to start the connection,
+ // this function returns false, but does not execute the callback.
+ //
+ // Calling Start() on an AsyncConnection that is already Start()ed is
+ // an error.
+ virtual bool Start(const IPAddress &address, int port);
+
+ // Stop the open connection, closing any fds that are still owned.
+ // Calling Stop() on an unstarted or Stop()ped AsyncConnection is
+ // a no-op.
+ virtual void Stop();
+
+ std::string error() const { return error_; }
+
+ private:
+ friend class AsyncConnectionTest;
+
+ void OnConnectCompletion(int fd);
+
+ std::string interface_name_;
+ EventDispatcher *dispatcher_;
+ Sockets *sockets_;
+ Callback2<bool, int>::Type *callback_;
+ std::string error_;
+ int fd_;
+ scoped_ptr<Callback1<int>::Type> connect_completion_callback_;
+ scoped_ptr<IOHandler> connect_completion_handler_;
+
+ DISALLOW_COPY_AND_ASSIGN(AsyncConnection);
+};
+
+} // namespace shill
+
+#endif // SHILL_ASYNC_CONNECTION_
diff --git a/async_connection_unittest.cc b/async_connection_unittest.cc
new file mode 100644
index 0000000..b78781d
--- /dev/null
+++ b/async_connection_unittest.cc
@@ -0,0 +1,251 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/async_connection.h"
+
+#include <netinet/in.h>
+
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "shill/ip_address.h"
+#include "shill/mock_event_dispatcher.h"
+#include "shill/mock_sockets.h"
+
+using std::string;
+using ::testing::_;
+using ::testing::Return;
+using ::testing::ReturnNew;
+using ::testing::StrEq;
+using ::testing::StrictMock;
+using ::testing::Test;
+
+namespace shill {
+
+namespace {
+const char kInterfaceName[] = "int0";
+const char kConnectAddress[] = "10.11.12.13";
+const int kConnectPort = 10203;
+const int kErrorNumber = 30405;
+const int kSocketFD = 60708;
+} // namespace {}
+
+class AsyncConnectionTest : public Test {
+ public:
+ AsyncConnectionTest()
+ : async_connection_(kInterfaceName, &dispatcher_, &sockets_,
+ callback_target_.callback()),
+ address_(IPAddress::kFamilyIPv4) { }
+
+ virtual void SetUp() {
+ EXPECT_TRUE(address_.SetAddressFromString(kConnectAddress));
+ }
+ virtual void TearDown() {
+ if (async_connection_.fd_ >= 0) {
+ EXPECT_CALL(sockets(), Close(kSocketFD))
+ .WillOnce(Return(0));
+ }
+ }
+
+ protected:
+ class ConnectCallbackTarget {
+ public:
+ ConnectCallbackTarget()
+ : callback_(NewCallback(this, &ConnectCallbackTarget::CallTarget)) {}
+
+ MOCK_METHOD2(CallTarget, void(bool success, int fd));
+ Callback2<bool, int>::Type *callback() { return callback_.get(); }
+
+ private:
+ scoped_ptr<Callback2<bool, int>::Type> callback_;
+ };
+
+ void ExpectReset() {
+ EXPECT_STREQ(kInterfaceName, async_connection_.interface_name_.c_str());
+ EXPECT_EQ(&dispatcher_, async_connection_.dispatcher_);
+ EXPECT_EQ(&sockets_, async_connection_.sockets_);
+ EXPECT_EQ(callback_target_.callback(), async_connection_.callback_);
+ EXPECT_EQ(-1, async_connection_.fd_);
+ EXPECT_TRUE(async_connection_.connect_completion_callback_.get());
+ EXPECT_FALSE(async_connection_.connect_completion_handler_.get());
+ }
+
+ void StartConnection() {
+ EXPECT_CALL(sockets_, Socket(_, _, _))
+ .WillOnce(Return(kSocketFD));
+ EXPECT_CALL(sockets_, SetNonBlocking(kSocketFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets_, BindToDevice(kSocketFD, StrEq(kInterfaceName)))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), Connect(kSocketFD, _, _))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(sockets_, Error())
+ .WillOnce(Return(EINPROGRESS));
+ EXPECT_CALL(dispatcher(), CreateReadyHandler(kSocketFD,
+ IOHandler::kModeOutput,
+ connect_completion_callback()))
+ .WillOnce(ReturnNew<IOHandler>());
+ EXPECT_TRUE(async_connection().Start(address_, kConnectPort));
+ }
+
+ void OnConnectCompletion(int fd) {
+ async_connection_.OnConnectCompletion(fd);
+ }
+ AsyncConnection &async_connection() { return async_connection_; }
+ StrictMock<MockSockets> &sockets() { return sockets_; }
+ MockEventDispatcher &dispatcher() { return dispatcher_; }
+ const IPAddress &address() { return address_; }
+ int fd() { return async_connection_.fd_; }
+ void set_fd(int fd) { async_connection_.fd_ = fd; }
+ StrictMock<ConnectCallbackTarget> &callback_target() {
+ return callback_target_;
+ }
+ Callback1<int>::Type *connect_completion_callback() {
+ return async_connection_.connect_completion_callback_.get();
+ }
+
+ private:
+ MockEventDispatcher dispatcher_;
+ StrictMock<MockSockets> sockets_;
+ StrictMock<ConnectCallbackTarget> callback_target_;
+ AsyncConnection async_connection_;
+ IPAddress address_;
+};
+
+TEST_F(AsyncConnectionTest, InitState) {
+ ExpectReset();
+ EXPECT_EQ(string(), async_connection().error());
+}
+
+TEST_F(AsyncConnectionTest, StartSocketFailure) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(sockets(), Error())
+ .WillOnce(Return(kErrorNumber));
+ EXPECT_FALSE(async_connection().Start(address(), kConnectPort));
+ ExpectReset();
+ EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
+}
+
+TEST_F(AsyncConnectionTest, StartNonBlockingFailure) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kSocketFD));
+ EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(sockets(), Error())
+ .WillOnce(Return(kErrorNumber));
+ EXPECT_CALL(sockets(), Close(kSocketFD))
+ .WillOnce(Return(0));
+ EXPECT_FALSE(async_connection().Start(address(), kConnectPort));
+ ExpectReset();
+ EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
+}
+
+TEST_F(AsyncConnectionTest, StartBindToDeviceFailure) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kSocketFD));
+ EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(sockets(), Error())
+ .WillOnce(Return(kErrorNumber));
+ EXPECT_CALL(sockets(), Close(kSocketFD))
+ .WillOnce(Return(0));
+ EXPECT_FALSE(async_connection().Start(address(), kConnectPort));
+ ExpectReset();
+ EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
+}
+
+TEST_F(AsyncConnectionTest, SynchronousFailure) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kSocketFD));
+ EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), Connect(kSocketFD, _, _))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(sockets(), Error())
+ .Times(2)
+ .WillRepeatedly(Return(0));
+ EXPECT_CALL(sockets(), Close(kSocketFD))
+ .WillOnce(Return(0));
+ EXPECT_FALSE(async_connection().Start(address(), kConnectPort));
+ ExpectReset();
+}
+
+MATCHER_P2(IsSocketAddress, address, port, "") {
+ const struct sockaddr_in *arg_saddr =
+ reinterpret_cast<const struct sockaddr_in *>(arg);
+ IPAddress arg_addr(IPAddress::kFamilyIPv4,
+ ByteString(reinterpret_cast<const unsigned char *>(
+ &arg_saddr->sin_addr.s_addr),
+ sizeof(arg_saddr->sin_addr.s_addr)));
+ return address.Equals(arg_addr) && arg_saddr->sin_port == htons(port);
+}
+
+TEST_F(AsyncConnectionTest, SynchronousStart) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kSocketFD));
+ EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), Connect(kSocketFD,
+ IsSocketAddress(address(), kConnectPort),
+ sizeof(struct sockaddr_in)))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(dispatcher(),
+ CreateReadyHandler(kSocketFD,
+ IOHandler::kModeOutput,
+ connect_completion_callback()))
+ .WillOnce(ReturnNew<IOHandler>());
+ EXPECT_CALL(sockets(), Error())
+ .WillOnce(Return(EINPROGRESS));
+ EXPECT_TRUE(async_connection().Start(address(), kConnectPort));
+ EXPECT_EQ(kSocketFD, fd());
+}
+
+TEST_F(AsyncConnectionTest, AsynchronousFailure) {
+ StartConnection();
+ EXPECT_CALL(sockets(), GetSocketError(kSocketFD))
+ .WillOnce(Return(1));
+ EXPECT_CALL(sockets(), Error())
+ .WillOnce(Return(kErrorNumber));
+ EXPECT_CALL(callback_target(), CallTarget(false, -1));
+ EXPECT_CALL(sockets(), Close(kSocketFD))
+ .WillOnce(Return(0));
+ OnConnectCompletion(kSocketFD);
+ ExpectReset();
+ EXPECT_STREQ(strerror(kErrorNumber), async_connection().error().c_str());
+}
+
+TEST_F(AsyncConnectionTest, AsynchronousSuccess) {
+ StartConnection();
+ EXPECT_CALL(sockets(), GetSocketError(kSocketFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD));
+ OnConnectCompletion(kSocketFD);
+ ExpectReset();
+}
+
+TEST_F(AsyncConnectionTest, SynchronousSuccess) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kSocketFD));
+ EXPECT_CALL(sockets(), SetNonBlocking(kSocketFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), BindToDevice(kSocketFD, StrEq(kInterfaceName)))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), Connect(kSocketFD,
+ IsSocketAddress(address(), kConnectPort),
+ sizeof(struct sockaddr_in)))
+ .WillOnce(Return(0));
+ EXPECT_CALL(callback_target(), CallTarget(true, kSocketFD));
+ EXPECT_TRUE(async_connection().Start(address(), kConnectPort));
+ ExpectReset();
+}
+
+} // namespace shill
diff --git a/byte_string.h b/byte_string.h
index 22ac750..efe94f4 100644
--- a/byte_string.h
+++ b/byte_string.h
@@ -52,6 +52,7 @@
bool IsZero() const;
bool Equals(const ByteString &b) const;
void Append(const ByteString &b);
+ void Clear() { data_.clear(); }
void Resize(int size) {
data_.resize(size, 0);
}
diff --git a/device.cc b/device.cc
index ad9e766..a40caf2 100644
--- a/device.cc
+++ b/device.cc
@@ -24,6 +24,7 @@
#include "shill/dhcp_provider.h"
#include "shill/error.h"
#include "shill/event_dispatcher.h"
+#include "shill/http_proxy.h"
#include "shill/manager.h"
#include "shill/property_accessor.h"
#include "shill/refptr_types.h"
diff --git a/dns_client.h b/dns_client.h
index dd47784..9f13c20 100644
--- a/dns_client.h
+++ b/dns_client.h
@@ -46,12 +46,12 @@
int timeout_ms,
EventDispatcher *dispatcher,
Callback1<bool>::Type *callback);
- ~DNSClient();
+ virtual ~DNSClient();
- bool Start(const std::string &hostname);
- void Stop();
- const IPAddress &address() const { return address_; }
- const std::string &error() const { return error_; }
+ virtual bool Start(const std::string &hostname);
+ virtual void Stop();
+ virtual const IPAddress &address() const { return address_; }
+ virtual const std::string &error() const { return error_; }
private:
friend class DNSClientTest;
diff --git a/glib_io_input_handler.cc b/glib_io_input_handler.cc
index 1f45652..a4418f4 100644
--- a/glib_io_input_handler.cc
+++ b/glib_io_input_handler.cc
@@ -17,6 +17,7 @@
unsigned char buf[4096];
gsize len;
GIOError err;
+ gboolean ret = TRUE;
if (cond & (G_IO_NVAL | G_IO_HUP | G_IO_ERR))
return FALSE;
@@ -26,13 +27,14 @@
if (err) {
if (err == G_IO_ERROR_AGAIN)
return TRUE;
- return FALSE;
+ len = 0;
+ ret = FALSE;
}
InputData input_data(buf, len);
callback->Run(&input_data);
- return TRUE;
+ return ret;
}
GlibIOInputHandler::GlibIOInputHandler(int fd,
diff --git a/glib_io_ready_handler.cc b/glib_io_ready_handler.cc
index bb4fa7c..b72ca0e 100644
--- a/glib_io_ready_handler.cc
+++ b/glib_io_ready_handler.cc
@@ -14,13 +14,14 @@
static gboolean DispatchIOHandler(GIOChannel *chan,
GIOCondition cond,
gpointer data) {
- Callback1<int>::Type *callback = static_cast<Callback1<int>::Type *>(data);
+ Callback1<int>::Type *callback =
+ reinterpret_cast<Callback1<int>::Type *>(data);
+
+ callback->Run(g_io_channel_unix_get_fd(chan));
if (cond & (G_IO_NVAL | G_IO_HUP | G_IO_ERR))
return FALSE;
- callback->Run(g_io_channel_unix_get_fd(chan));
-
return TRUE;
}
diff --git a/http_proxy.cc b/http_proxy.cc
new file mode 100644
index 0000000..60bc281
--- /dev/null
+++ b/http_proxy.cc
@@ -0,0 +1,652 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/http_proxy.h"
+
+#include <errno.h>
+#include <netinet/in.h>
+#include <linux/if.h> // Needs definitions from netinet/in.h
+#include <stdio.h>
+#include <time.h>
+
+#include <string>
+#include <vector>
+
+#include <base/logging.h>
+#include <base/string_number_conversions.h>
+#include <base/string_split.h>
+#include <base/string_util.h>
+#include <base/stringprintf.h>
+
+#include "shill/async_connection.h"
+#include "shill/dns_client.h"
+#include "shill/event_dispatcher.h"
+#include "shill/ip_address.h"
+#include "shill/sockets.h"
+
+using base::StringPrintf;
+using std::string;
+using std::vector;
+
+namespace shill {
+
+const int HTTPProxy::kClientHeaderTimeoutSeconds = 1;
+const int HTTPProxy::kConnectTimeoutSeconds = 10;
+const int HTTPProxy::kDNSTimeoutSeconds = 5;
+const int HTTPProxy::kDefaultServerPort = 80;
+const int HTTPProxy::kInputTimeoutSeconds = 30;
+const size_t HTTPProxy::kMaxClientQueue = 10;
+const size_t HTTPProxy::kMaxHeaderCount = 128;
+const size_t HTTPProxy::kMaxHeaderSize = 2048;
+const int HTTPProxy::kTransactionTimeoutSeconds = 600;
+
+const char HTTPProxy::kHTTPURLDelimiters[] = " /#?";
+const char HTTPProxy::kHTTPURLPrefix[] = "http://";
+const char HTTPProxy::kHTTPVersionPrefix[] = " HTTP/1";
+const char HTTPProxy::kInternalErrorMsg[] = "Proxy Failed: Internal Error";
+
+
+HTTPProxy::HTTPProxy(const std::string &interface_name,
+ const std::vector<std::string> &dns_servers)
+ : state_(kStateIdle),
+ interface_name_(interface_name),
+ dns_servers_(dns_servers),
+ accept_callback_(NewCallback(this, &HTTPProxy::AcceptClient)),
+ connect_completion_callback_(
+ NewCallback(this, &HTTPProxy::OnConnectCompletion)),
+ dns_client_callback_(NewCallback(this, &HTTPProxy::GetDNSResult)),
+ read_client_callback_(NewCallback(this, &HTTPProxy::ReadFromClient)),
+ read_server_callback_(NewCallback(this, &HTTPProxy::ReadFromServer)),
+ write_client_callback_(NewCallback(this, &HTTPProxy::WriteToClient)),
+ write_server_callback_(NewCallback(this, &HTTPProxy::WriteToServer)),
+ task_factory_(this),
+ dispatcher_(NULL),
+ dns_client_(NULL),
+ proxy_port_(-1),
+ proxy_socket_(-1),
+ server_async_connection_(NULL),
+ sockets_(NULL),
+ client_socket_(-1),
+ server_port_(kDefaultServerPort),
+ server_socket_(-1),
+ idle_timeout_(NULL) { }
+
+HTTPProxy::~HTTPProxy() {
+ Stop();
+}
+
+bool HTTPProxy::Start(EventDispatcher *dispatcher,
+ Sockets *sockets) {
+ VLOG(3) << "In " << __func__;
+
+ if (sockets_) {
+ // We are already running.
+ return true;
+ }
+
+ proxy_socket_ = sockets->Socket(PF_INET, SOCK_STREAM, 0);
+ if (proxy_socket_ < 0) {
+ PLOG(ERROR) << "Failed to open proxy socket";
+ return false;
+ }
+
+ struct sockaddr_in addr;
+ socklen_t addrlen = sizeof(addr);
+ memset(&addr, 0, sizeof(addr));
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ if (sockets->Bind(proxy_socket_,
+ reinterpret_cast<struct sockaddr *>(&addr),
+ sizeof(addr)) < 0 ||
+ sockets->GetSockName(proxy_socket_,
+ reinterpret_cast<struct sockaddr *>(&addr),
+ &addrlen) < 0 ||
+ sockets->SetNonBlocking(proxy_socket_) < 0 ||
+ sockets->Listen(proxy_socket_, kMaxClientQueue) < 0) {
+ sockets->Close(proxy_socket_);
+ proxy_socket_ = -1;
+ PLOG(ERROR) << "HTTPProxy socket setup failed";
+ return false;
+ }
+
+ accept_handler_.reset(
+ dispatcher->CreateReadyHandler(proxy_socket_, IOHandler::kModeInput,
+ accept_callback_.get()));
+ dispatcher_ = dispatcher;
+ dns_client_.reset(new DNSClient(IPAddress::kFamilyIPv4,
+ interface_name_,
+ dns_servers_,
+ kDNSTimeoutSeconds * 1000,
+ dispatcher,
+ dns_client_callback_.get()));
+ proxy_port_ = ntohs(addr.sin_port);
+ server_async_connection_.reset(
+ new AsyncConnection(interface_name_, dispatcher, sockets,
+ connect_completion_callback_.get()));
+ sockets_ = sockets;
+ state_ = kStateWaitConnection;
+ return true;
+}
+
+void HTTPProxy::Stop() {
+ VLOG(3) << "In " << __func__;
+
+ if (!sockets_ ) {
+ return;
+ }
+
+ StopClient();
+
+ accept_handler_.reset();
+ dispatcher_ = NULL;
+ dns_client_.reset();
+ proxy_port_ = -1;
+ server_async_connection_.reset();
+ sockets_->Close(proxy_socket_);
+ proxy_socket_ = -1;
+ sockets_ = NULL;
+ state_ = kStateIdle;
+}
+
+// IOReadyHandler callback routine fired when a client connects to the
+// proxy's socket. We Accept() the client and start reading a request
+// from it.
+void HTTPProxy::AcceptClient(int fd) {
+ VLOG(3) << "In " << __func__;
+
+ int client_fd = sockets_->Accept(fd, NULL, NULL);
+ if (client_fd < 0) {
+ PLOG(ERROR) << "Client accept failed";
+ return;
+ }
+
+ accept_handler_->Stop();
+
+ client_socket_ = client_fd;
+
+ sockets_->SetNonBlocking(client_socket_);
+ read_client_handler_.reset(
+ dispatcher_->CreateInputHandler(client_socket_,
+ read_client_callback_.get()));
+ // Overall transaction timeout.
+ dispatcher_->PostDelayedTask(
+ task_factory_.NewRunnableMethod(&HTTPProxy::StopClient),
+ kTransactionTimeoutSeconds * 1000);
+
+ state_ = kStateReadClientHeader;
+ StartIdleTimeout();
+}
+
+bool HTTPProxy::ConnectServer(const IPAddress &address, int port) {
+ state_ = kStateConnectServer;
+ if (!server_async_connection_->Start(address, port)) {
+ SendClientError(500, "Could not create socket to connect to server");
+ return false;
+ }
+ StartIdleTimeout();
+ return true;
+}
+
+// DNSClient callback that fires when the DNS request completes.
+void HTTPProxy::GetDNSResult(bool result) {
+ if (!result) {
+ SendClientError(502, string("Could not resolve hostname: ") +
+ dns_client_->error());
+ return;
+ }
+ ConnectServer(dns_client_->address(), server_port_);
+}
+
+// IOReadyHandler callback routine which fires when the asynchronous Connect()
+// to the remote server completes (or fails).
+void HTTPProxy::OnConnectCompletion(bool success, int fd) {
+ if (!success) {
+ SendClientError(500, string("Socket connection delayed failure: ") +
+ server_async_connection_->error());
+ return;
+ }
+ server_socket_ = fd;
+ state_ = kStateTunnelData;
+ StartTransmit();
+}
+
+// Read through the header lines from the client, modifying or adding
+// lines as necessary. Perform final determination of the hostname/port
+// we should connect to and either start a DNS request or connect to a
+// numeric address.
+bool HTTPProxy::ParseClientRequest() {
+ VLOG(3) << "In " << __func__;
+
+ string host;
+ bool found_via = false;
+ bool found_connection = false;
+ for (vector<string>::iterator it = client_headers_.begin();
+ it != client_headers_.end(); ++it) {
+ if (StartsWithASCII(*it, "Host:", false)) {
+ host = it->substr(5);
+ } else if (StartsWithASCII(*it, "Via:", false)) {
+ found_via = true;
+ (*it).append(StringPrintf(", %s shill-proxy", client_version_.c_str()));
+ } else if (StartsWithASCII(*it, "Connection:", false)) {
+ found_connection = true;
+ (*it).assign("Connection: close");
+ } else if (StartsWithASCII(*it, "Proxy-Connection:", false)) {
+ (*it).assign("Proxy-Connection: close");
+ }
+ }
+
+ if (!found_connection) {
+ client_headers_.push_back("Connection: close");
+ }
+ if (!found_via) {
+ client_headers_.push_back(
+ StringPrintf("Via: %s shill-proxy", client_version_.c_str()));
+ }
+
+ // Assemble the request as it will be sent to the server.
+ client_data_.Clear();
+ for (vector<string>::iterator it = client_headers_.begin();
+ it != client_headers_.end(); ++it) {
+ client_data_.Append(ByteString(*it + "\r\n", false));
+ }
+ client_data_.Append(ByteString(string("\r\n"), false));
+
+ TrimWhitespaceASCII(host, TRIM_ALL, &host);
+ if (host.empty()) {
+ // Revert to using the hostname in the URL if no "Host:" header exists.
+ host = server_hostname_;
+ }
+
+ if (host.empty()) {
+ SendClientError(400, "I don't know what host you want me to connect to");
+ return false;
+ }
+
+ server_port_ = 80;
+ vector<string> host_parts;
+ base::SplitString(host, ':', &host_parts);
+
+ if (host_parts.size() > 2) {
+ SendClientError(400, "Too many colons in hostname");
+ return false;
+ } else if (host_parts.size() == 2) {
+ server_hostname_ = host_parts[0];
+ if (!base::StringToInt(host_parts[1], &server_port_)) {
+ SendClientError(400, "Could not parse port number");
+ return false;
+ }
+ } else {
+ server_hostname_ = host;
+ }
+
+ IPAddress addr(IPAddress::kFamilyIPv4);
+ if (addr.SetAddressFromString(server_hostname_)) {
+ if (!ConnectServer(addr, server_port_)) {
+ return false;
+ }
+ } else {
+ VLOG(3) << "Looking up host: " << server_hostname_;
+ if (!dns_client_->Start(server_hostname_)) {
+ SendClientError(502, "Could not resolve hostname");
+ return false;
+ }
+ state_ = kStateLookupServer;
+ }
+ return true;
+}
+
+// Accept a new line into the client headers. Returns false if a parse
+// error occurs.
+bool HTTPProxy::ProcessLastHeaderLine() {
+ string *header = &client_headers_.back();
+ TrimString(*header, "\r", header);
+
+ if (header->empty()) {
+ // Empty line terminates client headers.
+ client_headers_.pop_back();
+ if (!ParseClientRequest()) {
+ return false;
+ }
+ }
+
+ // Is this is the first header line?
+ if (client_headers_.size() == 1) {
+ if (!ReadClientHTTPVersion(header) || !ReadClientHostname(header)) {
+ return false;
+ }
+ }
+
+ if (client_headers_.size() >= kMaxHeaderCount) {
+ SendClientError(500, kInternalErrorMsg);
+ return false;
+ }
+
+ return true;
+}
+
+// Split input from client into header lines, and consume parsed lines
+// from InputData. The passed in |data| is modified to indicate the
+// characters consumed.
+bool HTTPProxy::ReadClientHeaders(InputData *data) {
+ unsigned char *ptr = data->buf;
+ unsigned char *end = ptr + data->len;
+
+ if (client_headers_.empty()) {
+ client_headers_.push_back(string());
+ }
+
+ for (; ptr < end && state_ == kStateReadClientHeader; ++ptr) {
+ if (*ptr == '\n') {
+ if (!ProcessLastHeaderLine()) {
+ return false;
+ }
+
+ // Start a new line. New chararacters we receive will be appended there.
+ client_headers_.push_back(string());
+ continue;
+ }
+
+ string *header = &client_headers_.back();
+ // Is the first character of the header line a space or tab character?
+ if (header->empty() && (*ptr == ' ' || *ptr == '\t') &&
+ client_headers_.size() > 1) {
+ // Line Continuation: Add this character to the previous header line.
+ // This way, all of the data (including newlines and line continuation
+ // characters) related to a specific header will be contained within
+ // a single element of |client_headers_|, and manipulation of headers
+ // such as appending will be simpler. This is accomplished by removing
+ // the empty line we started, and instead appending the whitespace
+ // and following characters to the previous line.
+ client_headers_.pop_back();
+ header = &client_headers_.back();
+ header->append("\r\n");
+ }
+
+ if (header->length() >= kMaxHeaderSize) {
+ SendClientError(500, kInternalErrorMsg);
+ return false;
+ }
+ header->push_back(*ptr);
+ }
+
+ // Return the remaining data to the caller -- this could be POST data
+ // or other non-header data sent with the client request.
+ data->buf = ptr;
+ data->len = end - ptr;
+
+ return true;
+}
+
+// Finds the URL in the first line of an HTTP client header, and extracts
+// and removes the hostname (and port) from the URL. Returns false if a
+// parse error occurs, and true otherwise (whether or not the hostname was
+// found).
+bool HTTPProxy::ReadClientHostname(string *header) {
+ const string http_url_prefix(kHTTPURLPrefix);
+ size_t url_idx = header->find(http_url_prefix);
+ if (url_idx != string::npos) {
+ size_t host_start = url_idx + http_url_prefix.length();
+ size_t host_end =
+ header->find_first_of(kHTTPURLDelimiters, host_start);
+ if (host_end != string::npos) {
+ server_hostname_ = header->substr(host_start,
+ host_end - host_start);
+ // Modify the URL passed upstream to remove "http://<hostname>".
+ header->erase(url_idx, host_end - url_idx);
+ if ((*header)[url_idx] != '/') {
+ header->insert(url_idx, "/");
+ }
+ } else {
+ LOG(ERROR) << "Could not find end of hostname in request. Line was: "
+ << *header;
+ SendClientError(500, kInternalErrorMsg);
+ return false;
+ }
+ }
+ return true;
+}
+
+// Extract the HTTP version number from the first line of the client headers.
+// Returns true if found.
+bool HTTPProxy::ReadClientHTTPVersion(string *header) {
+ const string http_version_prefix(kHTTPVersionPrefix);
+ size_t http_ver_pos = header->find(http_version_prefix);
+ if (http_ver_pos != string::npos) {
+ client_version_ =
+ header->substr(http_ver_pos + http_version_prefix.length() - 1);
+ } else {
+ SendClientError(501, "Server only accepts HTTP/1.x requests");
+ return false;
+ }
+ return true;
+}
+
+// IOInputHandler callback that fires when data is read from the client.
+// This could be header data, or perhaps POST data that follows the headers.
+void HTTPProxy::ReadFromClient(InputData *data) {
+ VLOG(3) << "In " << __func__ << " length " << data->len;
+
+ if (data->len == 0) {
+ // EOF from client.
+ StopClient();
+ return;
+ }
+
+ if (state_ == kStateReadClientHeader) {
+ if (!ReadClientHeaders(data)) {
+ return;
+ }
+ if (state_ == kStateReadClientHeader) {
+ // Still consuming client headers; restart the input timer.
+ StartIdleTimeout();
+ return;
+ }
+ }
+
+ // Check data->len again since ReadClientHeaders() may have consumed some
+ // part of it.
+ if (data->len != 0) {
+ // The client sent some information after its headers. Buffer the client
+ // input and temporarily disable input events from the client.
+ client_data_.Append(ByteString(data->buf, data->len));
+ read_client_handler_->Stop();
+ StartTransmit();
+ }
+}
+
+// IOInputHandler callback which fires when data has been read from the
+// server.
+void HTTPProxy::ReadFromServer(InputData *data) {
+ VLOG(3) << "In " << __func__ << " length " << data->len;
+ if (data->len == 0) {
+ // Server closed connection.
+ if (server_data_.IsEmpty()) {
+ StopClient();
+ return;
+ }
+ state_ = kStateFlushResponse;
+ } else {
+ read_server_handler_->Stop();
+ }
+
+ server_data_.Append(ByteString(data->buf, data->len));
+
+ StartTransmit();
+}
+
+// Return an HTTP error message back to the client.
+void HTTPProxy::SendClientError(int code, const string &error) {
+ VLOG(3) << "In " << __func__;
+ LOG(ERROR) << "Sending error " << error;
+
+ string error_msg = StringPrintf("HTTP/1.1 %d ERROR\r\n"
+ "Content-Type: text/plain\r\n\r\n"
+ "%s", code, error.c_str());
+ server_data_ = ByteString(error_msg, false);
+ state_ = kStateFlushResponse;
+ StartTransmit();
+}
+
+// Start a timeout for "the next event". This timeout augments the overall
+// transaction timeout to make sure there is some activity occurring at
+// reasonable intervals.
+void HTTPProxy::StartIdleTimeout() {
+ int timeout_seconds = 0;
+ switch (state_) {
+ case kStateReadClientHeader:
+ timeout_seconds = kClientHeaderTimeoutSeconds;
+ break;
+ case kStateConnectServer:
+ timeout_seconds = kConnectTimeoutSeconds;
+ break;
+ case kStateLookupServer:
+ // DNSClient has its own internal timeout, so we need not set one here.
+ timeout_seconds = 0;
+ break;
+ default:
+ timeout_seconds = kInputTimeoutSeconds;
+ break;
+ }
+ if (idle_timeout_) {
+ idle_timeout_->Cancel();
+ idle_timeout_ = NULL;
+ }
+ if (timeout_seconds != 0) {
+ idle_timeout_ = task_factory_.NewRunnableMethod(&HTTPProxy::StopClient);
+ dispatcher_->PostDelayedTask(idle_timeout_, timeout_seconds * 1000);
+ }
+}
+
+// Start the various input handlers. Listen for new data only if we have
+// completely written the last data we've received to the other end.
+void HTTPProxy::StartReceive() {
+ if (state_ == kStateTunnelData && client_data_.IsEmpty()) {
+ read_client_handler_->Start();
+ }
+ if (server_data_.IsEmpty()) {
+ if (state_ == kStateTunnelData) {
+ if (read_server_handler_.get()) {
+ read_server_handler_->Start();
+ } else {
+ read_server_handler_.reset(
+ dispatcher_->CreateInputHandler(server_socket_,
+ read_server_callback_.get()));
+ }
+ } else if (state_ == kStateFlushResponse) {
+ StopClient();
+ return;
+ }
+ }
+ StartIdleTimeout();
+}
+
+// Start the various output-ready handlers for the endpoints we have
+// data waiting for.
+void HTTPProxy::StartTransmit() {
+ if (state_ == kStateTunnelData && !client_data_.IsEmpty()) {
+ if (write_server_handler_.get()) {
+ write_server_handler_->Start();
+ } else {
+ write_server_handler_.reset(
+ dispatcher_->CreateReadyHandler(server_socket_,
+ IOHandler::kModeOutput,
+ write_server_callback_.get()));
+ }
+ }
+ if ((state_ == kStateFlushResponse || state_ == kStateTunnelData) &&
+ !server_data_.IsEmpty()) {
+ if (write_client_handler_.get()) {
+ write_client_handler_->Start();
+ } else {
+ write_client_handler_.reset(
+ dispatcher_->CreateReadyHandler(client_socket_,
+ IOHandler::kModeOutput,
+ write_client_callback_.get()));
+ }
+ }
+ StartIdleTimeout();
+}
+
+// End the transaction with the current client, restart the IOHandler
+// which alerts us to new clients connecting. This function is called
+// during various error conditions and is a callback for all timeouts.
+void HTTPProxy::StopClient() {
+ VLOG(3) << "In " << __func__;
+
+ write_client_handler_.reset();
+ read_client_handler_.reset();
+ if (client_socket_ != -1) {
+ sockets_->Close(client_socket_);
+ client_socket_ = -1;
+ }
+ client_headers_.clear();
+ client_version_.clear();
+ server_port_ = kDefaultServerPort;
+ write_server_handler_.reset();
+ read_server_handler_.reset();
+ if (server_socket_ != -1) {
+ sockets_->Close(server_socket_);
+ server_socket_ = -1;
+ }
+ server_hostname_.clear();
+ client_data_.Clear();
+ server_data_.Clear();
+ dns_client_->Stop();
+ server_async_connection_->Stop();
+ task_factory_.RevokeAll();
+ idle_timeout_ = NULL;
+ accept_handler_->Start();
+ state_ = kStateWaitConnection;
+}
+
+// Output ReadyHandler callback which fires when the client socket is
+// ready for data to be sent to it.
+void HTTPProxy::WriteToClient(int fd) {
+ CHECK_EQ(client_socket_, fd);
+ int ret = sockets_->Send(fd, server_data_.GetConstData(),
+ server_data_.GetLength(), 0);
+ VLOG(3) << "In " << __func__ << " wrote " << ret << " of " <<
+ server_data_.GetLength();
+ if (ret < 0) {
+ LOG(ERROR) << "Server write failed";
+ StopClient();
+ return;
+ }
+
+ server_data_ = ByteString(server_data_.GetConstData() + ret,
+ server_data_.GetLength() - ret);
+
+ if (server_data_.IsEmpty()) {
+ write_client_handler_->Stop();
+ }
+
+ StartReceive();
+}
+
+// Output ReadyHandler callback which fires when the server socket is
+// ready for data to be sent to it.
+void HTTPProxy::WriteToServer(int fd) {
+ CHECK_EQ(server_socket_, fd);
+ int ret = sockets_->Send(fd, client_data_.GetConstData(),
+ client_data_.GetLength(), 0);
+ VLOG(3) << "In " << __func__ << " wrote " << ret << " of " <<
+ client_data_.GetLength();
+
+ if (ret < 0) {
+ LOG(ERROR) << "Client write failed";
+ StopClient();
+ return;
+ }
+
+ client_data_ = ByteString(client_data_.GetConstData() + ret,
+ client_data_.GetLength() - ret);
+
+ if (client_data_.IsEmpty()) {
+ write_server_handler_->Stop();
+ }
+
+ StartReceive();
+}
+
+} // namespace shill
diff --git a/http_proxy.h b/http_proxy.h
new file mode 100644
index 0000000..6bc0601
--- /dev/null
+++ b/http_proxy.h
@@ -0,0 +1,154 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_HTTP_PROXY_
+#define SHILL_HTTP_PROXY_
+
+#include <string>
+#include <vector>
+
+#include <base/callback_old.h>
+#include <base/memory/ref_counted.h>
+#include <base/memory/scoped_ptr.h>
+#include <base/task.h>
+
+#include "shill/byte_string.h"
+#include "shill/refptr_types.h"
+
+namespace shill {
+
+class AsyncConnection;
+class EventDispatcher;
+class DNSClient;
+class InputData;
+class IOHandler;
+class IPAddress;
+class Sockets;
+
+// The HTTPProxy class implements a simple web proxy that
+// is bound to a specific interface and name server. This
+// allows us to specify which connection a URL should be
+// fetched through, even though many connections
+// could be active at the same time.
+//
+// This service is meant to be low-performance, since we
+// do not want to divert resources from the rest of the
+// connection manager. As such, we serve one client request
+// at a time. This is probably okay since the use case is
+// limited -- only portal detection, activation and Cashew
+// are planned to be full-time users.
+class HTTPProxy {
+ public:
+ enum State {
+ kStateIdle,
+ kStateWaitConnection,
+ kStateReadClientHeader,
+ kStateLookupServer,
+ kStateConnectServer,
+ kStateTunnelData,
+ kStateFlushResponse,
+ };
+
+ HTTPProxy(const std::string &interface_name,
+ const std::vector<std::string> &dns_servers);
+ virtual ~HTTPProxy();
+
+ // Start HTTP proxy.
+ bool Start(EventDispatcher *dispatcher, Sockets *sockets);
+
+ // Shutdown.
+ void Stop();
+
+ int proxy_port() const { return proxy_port_; }
+
+ private:
+ friend class HTTPProxyTest;
+
+ // Time to wait for initial headers from client.
+ static const int kClientHeaderTimeoutSeconds;
+ // Time to wait for connection to remote server.
+ static const int kConnectTimeoutSeconds;
+ // Time to wait for DNS server.
+ static const int kDNSTimeoutSeconds;
+ // Default port on remote server to connect to.
+ static const int kDefaultServerPort;
+ // Time to wait for any input from either server or client.
+ static const int kInputTimeoutSeconds;
+ // Maximum clients to be kept waiting.
+ static const size_t kMaxClientQueue;
+ // Maximum number of header lines to accept.
+ static const size_t kMaxHeaderCount;
+ // Maximum length of an individual header line.
+ static const size_t kMaxHeaderSize;
+ // Timeout for whole transaction.
+ static const int kTransactionTimeoutSeconds;
+
+ static const char kHTTPURLDelimiters[];
+ static const char kHTTPURLPrefix[];
+ static const char kHTTPVersionPrefix[];
+ static const char kHTTPVersionErrorMsg[];
+ static const char kInternalErrorMsg[]; // Message to send on failure.
+
+ void AcceptClient(int fd);
+ bool ConnectServer(const IPAddress &address, int port);
+ void GetDNSResult(bool result);
+ void OnConnectCompletion(bool success, int fd);
+ bool ParseClientRequest();
+ bool ProcessLastHeaderLine();
+ bool ReadClientHeaders(InputData *data);
+ bool ReadClientHostname(std::string *header);
+ bool ReadClientHTTPVersion(std::string *header);
+ void ReadFromClient(InputData *data);
+ void ReadFromServer(InputData *data);
+ void SendClientError(int code, const std::string &error);
+ void StartIdleTimeout();
+ void StartReceive();
+ void StartTransmit();
+ void StopClient();
+ void WriteToClient(int fd);
+ void WriteToServer(int fd);
+
+ // State held for the lifetime of the proxy.
+ State state_;
+ const std::string interface_name_;
+ std::vector<std::string> dns_servers_;
+ scoped_ptr<Callback1<int>::Type> accept_callback_;
+ scoped_ptr<Callback2<bool, int>::Type> connect_completion_callback_;
+ scoped_ptr<Callback1<bool>::Type> dns_client_callback_;
+ scoped_ptr<Callback1<InputData *>::Type> read_client_callback_;
+ scoped_ptr<Callback1<InputData *>::Type> read_server_callback_;
+ scoped_ptr<Callback1<int>::Type> write_client_callback_;
+ scoped_ptr<Callback1<int>::Type> write_server_callback_;
+ ScopedRunnableMethodFactory<HTTPProxy> task_factory_;
+
+ // State held while proxy is started (even if no transaction is active).
+ scoped_ptr<IOHandler> accept_handler_;
+ EventDispatcher *dispatcher_;
+ scoped_ptr<DNSClient> dns_client_;
+ int proxy_port_;
+ int proxy_socket_;
+ scoped_ptr<AsyncConnection> server_async_connection_;
+ Sockets *sockets_;
+
+ // State held while proxy is started and a transaction is active.
+ int client_socket_;
+ std::string client_version_;
+ int server_port_;
+ int server_socket_;
+ CancelableTask *idle_timeout_;
+ std::vector<std::string> client_headers_;
+ std::string server_hostname_;
+ ByteString client_data_;
+ ByteString server_data_;
+ scoped_ptr<IOHandler> read_client_handler_;
+ scoped_ptr<IOHandler> write_client_handler_;
+ scoped_ptr<IOHandler> read_server_handler_;
+ scoped_ptr<IOHandler> write_server_handler_;
+
+ DISALLOW_COPY_AND_ASSIGN(HTTPProxy);
+};
+
+} // namespace shill
+
+#endif // SHILL_HTTP_PROXY_
diff --git a/http_proxy_unittest.cc b/http_proxy_unittest.cc
new file mode 100644
index 0000000..07e9cb7
--- /dev/null
+++ b/http_proxy_unittest.cc
@@ -0,0 +1,738 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/http_proxy.h"
+
+#include <netinet/in.h>
+
+#include <string>
+#include <vector>
+
+#include <base/stringprintf.h>
+#include <gtest/gtest.h>
+
+#include "shill/ip_address.h"
+#include "shill/mock_async_connection.h"
+#include "shill/mock_dns_client.h"
+#include "shill/mock_event_dispatcher.h"
+#include "shill/mock_sockets.h"
+
+using base::StringPrintf;
+using std::string;
+using std::vector;
+using ::testing::_;
+using ::testing::AtLeast;
+using ::testing::DoAll;
+using ::testing::Invoke;
+using ::testing::Return;
+using ::testing::ReturnArg;
+using ::testing::ReturnNew;
+using ::testing::ReturnRef;
+using ::testing::SetArgumentPointee;
+using ::testing::StrEq;
+using ::testing::StrictMock;
+using ::testing::Test;
+
+namespace shill {
+
+namespace {
+const char kBadHeader[] = "BLAH\r\n";
+const char kBadHostnameLine[] = "GET HTTP/1.1 http://hostname\r\n";
+const char kBasicGetHeader[] = "GET / HTTP/1.1\r\n";
+const char kBasicGetHeaderWithURL[] =
+ "GET http://www.chromium.org/ HTTP/1.1\r\n";
+const char kBasicGetHeaderWithURLNoTrailingSlash[] =
+ "GET http://www.chromium.org HTTP/1.1\r\n";
+const char kQueryTemplate[] = "GET %s HTTP/%s\r\n%s"
+ "User-Agent: Mozilla/5.0 (X11; CrOS i686 1299.0.2011) "
+ "AppleWebKit/535.8 (KHTML, like Gecko) Chrome/17.0.936.0 Safari/535.8\r\n"
+ "Accept: text/html,application/xhtml+xml,application/xml;"
+ "q=0.9,*/*;q=0.8\r\n"
+ "Accept-Encoding: gzip,deflate,sdch\r\n"
+ "Accept-Language: en-US,en;q=0.8,ja;q=0.6\r\n"
+ "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3\r\n"
+ "Cookie: PREF=ID=xxxxxxxxxxxxxxxx:U=xxxxxxxxxxxxxxxx:FF=0:"
+ "TM=1317340083:LM=1317390705:GM=1:S=_xxxxxxxxxxxxxxx; "
+ "NID=52=xxxxxxxxxxxxxxxxxxxx-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+ "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx_xxxxxxxxxxxxxxxxxxxxxxx; "
+ "HSID=xxxxxxxxxxxx-xxxx; APISID=xxxxxxxxxxxxxxxx/xxxxxxxxxxxxxxxxx; "
+ "SID=xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx_xxxxxxxxxxx"
+ "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx-xxxxxxxxxxxxxxx"
+ "xxx_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx-xxxxxxxxxxxxxxxxxx"
+ "_xxxxx-xxxxxxxxxxxxxxxxxxxxxxxxxx-xx-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+ "xxxxxxxxxxxxxxxx\r\n\r\n";
+const char kInterfaceName[] = "int0";
+const char kDNSServer0[] = "8.8.8.8";
+const char kDNSServer1[] = "8.8.4.4";
+const char kServerAddress[] = "10.10.10.10";
+const char *kDNSServers[] = { kDNSServer0, kDNSServer1 };
+const int kProxyFD = 10203;
+const int kServerFD = 10204;
+const int kClientFD = 10205;
+const int kServerPort = 40506;
+} // namespace {}
+
+MATCHER_P(IsIPAddress, address, "") {
+ IPAddress ip_address(IPAddress::kFamilyIPv4);
+ EXPECT_TRUE(ip_address.SetAddressFromString(address));
+ return ip_address.Equals(arg);
+}
+
+class HTTPProxyTest : public Test {
+ public:
+ HTTPProxyTest()
+ : server_async_connection_(NULL),
+ dns_servers_(kDNSServers, kDNSServers + 2),
+ dns_client_(NULL),
+ proxy_(kInterfaceName, dns_servers_) { }
+ protected:
+ virtual void TearDown() {
+ if (proxy_.sockets_) {
+ ExpectStop();
+ }
+ }
+ string CreateRequest(const string &url, const string &http_version,
+ const string &extra_lines) {
+ string append_lines(extra_lines);
+ if (append_lines.size()) {
+ append_lines.append("\r\n");
+ }
+ return StringPrintf(kQueryTemplate, url.c_str(), http_version.c_str(),
+ append_lines.c_str());
+ }
+ int InvokeGetSockName(int fd, struct sockaddr *addr_out,
+ socklen_t *sockaddr_size) {
+ struct sockaddr_in addr;
+ EXPECT_EQ(kProxyFD, fd);
+ EXPECT_GE(sizeof(sockaddr_in), *sockaddr_size);
+ addr.sin_addr.s_addr = 0;
+ addr.sin_port = kServerPort;
+ memcpy(addr_out, &addr, sizeof(addr));
+ *sockaddr_size = sizeof(sockaddr_in);
+ return 0;
+ }
+ void InvokeSyncConnect(const IPAddress &/*address*/, int /*port*/) {
+ proxy_.OnConnectCompletion(true, kServerFD);
+ }
+ size_t FindInRequest(const string &find_string) {
+ const ByteString &request_data = GetClientData();
+ string request_string(
+ reinterpret_cast<const char *>(request_data.GetConstData()),
+ request_data.GetLength());
+ return request_string.find(find_string);
+ }
+ // Accessors
+ const ByteString &GetClientData() {
+ return proxy_.client_data_;
+ }
+ HTTPProxy *proxy() { return &proxy_; }
+ HTTPProxy::State GetProxyState() {
+ return proxy_.state_;
+ }
+ const ByteString &GetServerData() {
+ return proxy_.server_data_;
+ }
+ MockSockets &sockets() { return sockets_; }
+ MockEventDispatcher &dispatcher() { return dispatcher_; }
+
+
+ // Expectations
+ void ExpectClientReset() {
+ EXPECT_EQ(-1, proxy_.client_socket_);
+ EXPECT_TRUE(proxy_.client_version_.empty());
+ EXPECT_EQ(HTTPProxy::kDefaultServerPort, proxy_.server_port_);
+ EXPECT_EQ(-1, proxy_.server_socket_);
+ EXPECT_FALSE(proxy_.idle_timeout_);
+ EXPECT_TRUE(proxy_.client_headers_.empty());
+ EXPECT_TRUE(proxy_.server_hostname_.empty());
+ EXPECT_TRUE(proxy_.client_data_.IsEmpty());
+ EXPECT_TRUE(proxy_.server_data_.IsEmpty());
+ EXPECT_FALSE(proxy_.read_client_handler_.get());
+ EXPECT_FALSE(proxy_.write_client_handler_.get());
+ EXPECT_FALSE(proxy_.read_server_handler_.get());
+ EXPECT_FALSE(proxy_.write_server_handler_.get());
+ }
+ void ExpectReset() {
+ EXPECT_FALSE(proxy_.accept_handler_.get());
+ EXPECT_FALSE(proxy_.dispatcher_);
+ EXPECT_FALSE(proxy_.dns_client_.get());
+ EXPECT_EQ(-1, proxy_.proxy_port_);
+ EXPECT_EQ(-1, proxy_.proxy_socket_);
+ EXPECT_FALSE(proxy_.server_async_connection_.get());
+ EXPECT_FALSE(proxy_.sockets_);
+ EXPECT_EQ(HTTPProxy::kStateIdle, proxy_.state_);
+ ExpectClientReset();
+ }
+ void ExpectStart() {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kProxyFD));
+ EXPECT_CALL(sockets(), Bind(kProxyFD, _, _))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), GetSockName(kProxyFD, _, _))
+ .WillOnce(Invoke(this, &HTTPProxyTest::InvokeGetSockName));
+ EXPECT_CALL(sockets(), SetNonBlocking(kProxyFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), Listen(kProxyFD, _))
+ .WillOnce(Return(0));
+ EXPECT_CALL(dispatcher_, CreateReadyHandler(kProxyFD,
+ IOHandler::kModeInput,
+ proxy_.accept_callback_.get()))
+ .WillOnce(ReturnNew<IOHandler>());
+ }
+ void ExpectStop() {
+ if (dns_client_) {
+ EXPECT_CALL(*dns_client_, Stop())
+ .Times(AtLeast(1));
+ }
+ if (server_async_connection_) {
+ EXPECT_CALL(*server_async_connection_, Stop())
+ .Times(AtLeast(1));
+ }
+ }
+ void ExpectClientInput(int fd) {
+ EXPECT_CALL(sockets(), Accept(kProxyFD, _, _))
+ .WillOnce(Return(fd));
+ EXPECT_CALL(sockets(), SetNonBlocking(fd))
+ .WillOnce(Return(0));
+ EXPECT_CALL(dispatcher(),
+ CreateInputHandler(fd, proxy_.read_client_callback_.get()))
+ .WillOnce(ReturnNew<IOHandler>());
+ ExpectTransactionTimeout();
+ ExpectClientHeaderTimeout();
+ }
+ void ExpectTimeout(int timeout) {
+ EXPECT_CALL(dispatcher_, PostDelayedTask(_, timeout * 1000))
+ .WillOnce(Return(true));
+ }
+ void ExpectClientHeaderTimeout() {
+ ExpectTimeout(HTTPProxy::kClientHeaderTimeoutSeconds);
+ }
+ void ExpectConnectTimeout() {
+ ExpectTimeout(HTTPProxy::kConnectTimeoutSeconds);
+ }
+ void ExpectInputTimeout() {
+ ExpectTimeout(HTTPProxy::kInputTimeoutSeconds);
+ }
+ void ExpectRepeatedInputTimeout() {
+ EXPECT_CALL(dispatcher_,
+ PostDelayedTask(_, HTTPProxy::kInputTimeoutSeconds * 1000))
+ .WillRepeatedly(Return(true));
+ }
+ void ExpectTransactionTimeout() {
+ ExpectTimeout(HTTPProxy::kTransactionTimeoutSeconds);
+ }
+ void ExpectClientError(int code, const string &error) {
+ EXPECT_EQ(HTTPProxy::kStateFlushResponse, GetProxyState());
+ string status_line = StringPrintf("HTTP/1.1 %d ERROR", code);
+ string server_data(reinterpret_cast<char *>(proxy_.server_data_.GetData()),
+ proxy_.server_data_.GetLength());
+ EXPECT_NE(string::npos, server_data.find(status_line));
+ EXPECT_NE(string::npos, server_data.find(error));
+ }
+ void ExpectClientInternalError() {
+ ExpectClientError(500, HTTPProxy::kInternalErrorMsg);
+ }
+ void ExpectClientVersion(const string &version) {
+ EXPECT_EQ(version, proxy_.client_version_);
+ }
+ void ExpectServerHostname(const string &hostname) {
+ EXPECT_EQ(hostname, proxy_.server_hostname_);
+ }
+ void ExpectFirstLine(const string &line) {
+ EXPECT_EQ(line, proxy_.client_headers_[0] + "\r\n");
+ }
+ void ExpectDNSRequest(const string &host, bool return_value) {
+ EXPECT_CALL(*dns_client_, Start(StrEq(host)))
+ .WillOnce(Return(return_value));
+ }
+ void ExpectDNSFailure(const string &error) {
+ EXPECT_CALL(*dns_client_, error())
+ .WillOnce(ReturnRef(error));
+ }
+ void ExpectAsyncConnect(const string &address, int port,
+ bool return_value) {
+ EXPECT_CALL(*server_async_connection_, Start(IsIPAddress(address), port))
+ .WillOnce(Return(return_value));
+ }
+ void ExpectSyncConnect(const string &address, int port) {
+ EXPECT_CALL(*server_async_connection_, Start(IsIPAddress(address), port))
+ .WillOnce(DoAll(Invoke(this, &HTTPProxyTest::InvokeSyncConnect),
+ Return(true)));
+ }
+ void ExpectClientResult() {
+ EXPECT_CALL(dispatcher(),
+ CreateReadyHandler(kClientFD,
+ IOHandler::kModeOutput,
+ proxy_.write_client_callback_.get()))
+ .WillOnce(ReturnNew<IOHandler>());
+ ExpectInputTimeout();
+ }
+ void ExpectServerInput() {
+ EXPECT_CALL(dispatcher(),
+ CreateInputHandler(kServerFD,
+ proxy_.read_server_callback_.get()))
+ .WillOnce(ReturnNew<IOHandler>());
+ ExpectInputTimeout();
+ }
+ void ExpectServerOutput() {
+ EXPECT_CALL(dispatcher(),
+ CreateReadyHandler(kServerFD,
+ IOHandler::kModeOutput,
+ proxy_.write_server_callback_.get()))
+ .WillOnce(ReturnNew<IOHandler>());
+ ExpectInputTimeout();
+ }
+ void ExpectRepeatedServerOutput() {
+ EXPECT_CALL(dispatcher(),
+ CreateReadyHandler(kServerFD,
+ IOHandler::kModeOutput,
+ proxy_.write_server_callback_.get()))
+ .WillOnce(ReturnNew<IOHandler>());
+ ExpectRepeatedInputTimeout();
+ }
+
+ void ExpectTunnelClose() {
+ EXPECT_CALL(sockets(), Close(kClientFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), Close(kServerFD))
+ .WillOnce(Return(0));
+ ExpectStop();
+ }
+
+ // Callers for various private routines in the proxy
+ bool StartProxy() {
+ bool ret = proxy_.Start(&dispatcher_, &sockets_);
+ if (ret) {
+ dns_client_ = new StrictMock<MockDNSClient>();
+ // Passes ownership.
+ proxy_.dns_client_.reset(dns_client_);
+ server_async_connection_ = new StrictMock<MockAsyncConnection>();
+ // Passes ownership.
+ proxy_.server_async_connection_.reset(server_async_connection_);
+ }
+ return ret;
+ }
+ void AcceptClient(int fd) {
+ proxy_.AcceptClient(fd);
+ }
+ void GetDNSResult(bool result) {
+ proxy_.GetDNSResult(result);
+ }
+ void OnConnectCompletion(bool result, int sockfd) {
+ proxy_.OnConnectCompletion(result, sockfd);
+ }
+ void ReadFromClient(const string &data) {
+ const unsigned char *ptr =
+ reinterpret_cast<const unsigned char *>(data.c_str());
+ vector<unsigned char> data_bytes(ptr, ptr + data.length());
+ InputData proxy_data(data_bytes.data(), data_bytes.size());
+ proxy_.ReadFromClient(&proxy_data);
+ }
+ void ReadFromServer(const string &data) {
+ const unsigned char *ptr =
+ reinterpret_cast<const unsigned char *>(data.c_str());
+ vector<unsigned char> data_bytes(ptr, ptr + data.length());
+ InputData proxy_data(data_bytes.data(), data_bytes.size());
+ proxy_.ReadFromServer(&proxy_data);
+ }
+ void SendClientError(int code, const string &error) {
+ proxy_.SendClientError(code, error);
+ EXPECT_FALSE(proxy_.server_data_.IsEmpty());
+ }
+ void StopClient() {
+ EXPECT_CALL(*dns_client_, Stop());
+ EXPECT_CALL(*server_async_connection_, Stop());
+ proxy_.StopClient();
+ }
+ void StopProxy() {
+ ExpectStop();
+ proxy_.Stop();
+ server_async_connection_ = NULL;
+ dns_client_ = NULL;
+ ExpectReset();
+ }
+ void WriteToClient(int fd) {
+ proxy_.WriteToClient(fd);
+ }
+ void WriteToServer(int fd) {
+ proxy_.WriteToServer(fd);
+ }
+
+ void SetupClient() {
+ ExpectStart();
+ ASSERT_TRUE(StartProxy());
+ ExpectClientInput(kClientFD);
+ AcceptClient(kProxyFD);
+ EXPECT_EQ(HTTPProxy::kStateReadClientHeader, GetProxyState());
+ }
+ void SetupConnectWithRequest(const string &url, const string &http_version,
+ const string &extra_lines) {
+ ExpectDNSRequest("www.chromium.org", true);
+ ReadFromClient(CreateRequest(url, http_version, extra_lines));
+ IPAddress addr(IPAddress::kFamilyIPv4);
+ EXPECT_TRUE(addr.SetAddressFromString(kServerAddress));
+ EXPECT_CALL(*dns_client_, address())
+ .WillOnce(ReturnRef(addr));;
+ GetDNSResult(true);
+ }
+ void SetupConnect() {
+ SetupConnectWithRequest("/", "1.1", "Host: www.chromium.org:40506");
+ }
+ void SetupConnectAsync() {
+ SetupClient();
+ ExpectAsyncConnect(kServerAddress, kServerPort, true);
+ ExpectConnectTimeout();
+ SetupConnect();
+ }
+ void SetupConnectComplete() {
+ SetupConnectAsync();
+ ExpectServerOutput();
+ OnConnectCompletion(true, kServerFD);
+ EXPECT_EQ(HTTPProxy::kStateTunnelData, GetProxyState());
+ }
+
+ private:
+ // Owned by the HTTPProxy, but tracked here for EXPECT().
+ StrictMock<MockAsyncConnection> *server_async_connection_;
+ vector<string> dns_servers_;
+ // Owned by the HTTPProxy, but tracked here for EXPECT().
+ StrictMock<MockDNSClient> *dns_client_;
+ MockEventDispatcher dispatcher_;
+ HTTPProxy proxy_;
+ StrictMock<MockSockets> sockets_;
+};
+
+TEST_F(HTTPProxyTest, StartFailSocket) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(-1));
+ EXPECT_FALSE(StartProxy());
+ ExpectReset();
+}
+
+TEST_F(HTTPProxyTest, StartFailBind) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kProxyFD));
+ EXPECT_CALL(sockets(), Bind(kProxyFD, _, _))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(sockets(), Close(kProxyFD))
+ .WillOnce(Return(0));
+ EXPECT_FALSE(StartProxy());
+ ExpectReset();
+}
+
+TEST_F(HTTPProxyTest, StartFailGetSockName) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kProxyFD));
+ EXPECT_CALL(sockets(), Bind(kProxyFD, _, _))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), GetSockName(kProxyFD, _, _))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(sockets(), Close(kProxyFD))
+ .WillOnce(Return(0));
+ EXPECT_FALSE(StartProxy());
+ ExpectReset();
+}
+
+TEST_F(HTTPProxyTest, StartFailSetNonBlocking) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kProxyFD));
+ EXPECT_CALL(sockets(), Bind(kProxyFD, _, _))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), GetSockName(kProxyFD, _, _))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), SetNonBlocking(kProxyFD))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(sockets(), Close(kProxyFD))
+ .WillOnce(Return(0));
+ EXPECT_FALSE(StartProxy());
+ ExpectReset();
+}
+
+TEST_F(HTTPProxyTest, StartFailListen) {
+ EXPECT_CALL(sockets(), Socket(_, _, _))
+ .WillOnce(Return(kProxyFD));
+ EXPECT_CALL(sockets(), Bind(kProxyFD, _, _))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), GetSockName(kProxyFD, _, _))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), SetNonBlocking(kProxyFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), Listen(kProxyFD, _))
+ .WillOnce(Return(-1));
+ EXPECT_CALL(sockets(), Close(kProxyFD))
+ .WillOnce(Return(0));
+ EXPECT_FALSE(StartProxy());
+ ExpectReset();
+}
+
+TEST_F(HTTPProxyTest, StartSuccess) {
+ ExpectStart();
+ EXPECT_TRUE(StartProxy());
+}
+
+TEST_F(HTTPProxyTest, SendClientError) {
+ SetupClient();
+ ExpectClientResult();
+ SendClientError(500, "This is an error");
+ ExpectClientError(500, "This is an error");
+
+ // We succeed in sending all but one byte of the client response.
+ int buf_len = GetServerData().GetLength();
+ EXPECT_CALL(sockets(), Send(kClientFD, _, buf_len, 0))
+ .WillOnce(Return(buf_len - 1));
+ ExpectInputTimeout();
+ WriteToClient(kClientFD);
+ EXPECT_EQ(1, GetServerData().GetLength());
+ EXPECT_EQ(HTTPProxy::kStateFlushResponse, GetProxyState());
+
+ // When we are able to send the last byte, we close the connection.
+ EXPECT_CALL(sockets(), Send(kClientFD, _, 1, 0))
+ .WillOnce(Return(1));
+ EXPECT_CALL(sockets(), Close(kClientFD))
+ .WillOnce(Return(0));
+ ExpectStop();
+ WriteToClient(kClientFD);
+ EXPECT_EQ(HTTPProxy::kStateWaitConnection, GetProxyState());
+}
+
+TEST_F(HTTPProxyTest, ReadBadFirstLine) {
+ SetupClient();
+ ExpectClientResult();
+ ReadFromClient(kBadHeader);
+ ExpectClientError(501, "Server only accepts HTTP/1.x requests");
+}
+
+TEST_F(HTTPProxyTest, ReadBadHostname) {
+ SetupClient();
+ ExpectClientResult();
+ ReadFromClient(kBadHostnameLine);
+ ExpectClientInternalError();
+}
+
+TEST_F(HTTPProxyTest, GoodFirstLineWithoutURL) {
+ SetupClient();
+ ExpectClientHeaderTimeout();
+ ReadFromClient(kBasicGetHeader);
+ ExpectClientVersion("1.1");
+ ExpectServerHostname("");
+ ExpectFirstLine(kBasicGetHeader);
+}
+
+TEST_F(HTTPProxyTest, GoodFirstLineWithURL) {
+ SetupClient();
+ ExpectClientHeaderTimeout();
+ ReadFromClient(kBasicGetHeaderWithURL);
+ ExpectClientVersion("1.1");
+ ExpectServerHostname("www.chromium.org");
+ ExpectFirstLine(kBasicGetHeader);
+}
+
+TEST_F(HTTPProxyTest, GoodFirstLineWithURLNoSlash) {
+ SetupClient();
+ ExpectClientHeaderTimeout();
+ ReadFromClient(kBasicGetHeaderWithURLNoTrailingSlash);
+ ExpectClientVersion("1.1");
+ ExpectServerHostname("www.chromium.org");
+ ExpectFirstLine(kBasicGetHeader);
+}
+
+TEST_F(HTTPProxyTest, NoHostInRequest) {
+ SetupClient();
+ ExpectClientResult();
+ ReadFromClient(CreateRequest("/", "1.1", ""));
+ ExpectClientError(400, "I don't know what host you want me to connect to");
+}
+
+TEST_F(HTTPProxyTest, TooManyColonsInHost) {
+ SetupClient();
+ ExpectClientResult();
+ ReadFromClient(CreateRequest("/", "1.1", "Host: www.chromium.org:80:40506"));
+ ExpectClientError(400, "Too many colons in hostname");
+}
+
+TEST_F(HTTPProxyTest, DNSRequestFailure) {
+ SetupClient();
+ ExpectDNSRequest("www.chromium.org", false);
+ ExpectClientResult();
+ ReadFromClient(CreateRequest("/", "1.1", "Host: www.chromium.org:40506"));
+ ExpectClientError(502, "Could not resolve hostname");
+}
+
+TEST_F(HTTPProxyTest, DNSRequestDelayedFailure) {
+ SetupClient();
+ ExpectDNSRequest("www.chromium.org", true);
+ ReadFromClient(CreateRequest("/", "1.1", "Host: www.chromium.org:40506"));
+ ExpectClientResult();
+ const std::string not_found_error(DNSClient::kErrorNotFound);
+ ExpectDNSFailure(not_found_error);
+ GetDNSResult(false);
+ ExpectClientError(502, string("Could not resolve hostname: ") +
+ not_found_error);
+}
+
+TEST_F(HTTPProxyTest, TrailingClientData) {
+ SetupClient();
+ ExpectDNSRequest("www.chromium.org", true);
+ const string trailing_data("Trailing client data");
+ ReadFromClient(CreateRequest("/", "1.1", "Host: www.chromium.org:40506") +
+ trailing_data);
+ EXPECT_EQ(GetClientData().GetLength() - trailing_data.length(),
+ FindInRequest(trailing_data));
+ EXPECT_EQ(HTTPProxy::kStateLookupServer, GetProxyState());
+}
+
+TEST_F(HTTPProxyTest, LineContinuation) {
+ SetupClient();
+ ExpectDNSRequest("www.chromium.org", true);
+ string text_to_keep("X-Long-Header: this is one line\r\n"
+ "\tand this is another");
+ ReadFromClient(CreateRequest("http://www.chromium.org/", "1.1",
+ text_to_keep));
+ EXPECT_NE(string::npos, FindInRequest(text_to_keep));
+}
+
+// NB: This tests two different things:
+// 1) That the system replaces the value for "Proxy-Connection" headers.
+// 2) That when it replaces a header, it also removes the text in the line
+// continuation.
+TEST_F(HTTPProxyTest, LineContinuationRemoval) {
+ SetupClient();
+ ExpectDNSRequest("www.chromium.org", true);
+ string text_to_remove("remove this text please");
+ ReadFromClient(CreateRequest("http://www.chromium.org/", "1.1",
+ string("Proxy-Connection: stuff\r\n\t") +
+ text_to_remove));
+ EXPECT_EQ(string::npos, FindInRequest(text_to_remove));
+ EXPECT_NE(string::npos, FindInRequest("Proxy-Connection: close\r\n"));
+}
+
+TEST_F(HTTPProxyTest, ConnectSynchronousFailure) {
+ SetupClient();
+ ExpectAsyncConnect(kServerAddress, kServerPort, false);
+ ExpectClientResult();
+ SetupConnect();
+ ExpectClientError(500, "Could not create socket to connect to server");
+}
+
+TEST_F(HTTPProxyTest, ConnectAsyncConnectFailure) {
+ SetupConnectAsync();
+ ExpectClientResult();
+ OnConnectCompletion(false, -1);
+ ExpectClientError(500, "Socket connection delayed failure");
+}
+
+TEST_F(HTTPProxyTest, ConnectSynchronousSuccess) {
+ SetupClient();
+ ExpectSyncConnect(kServerAddress, 999);
+ ExpectRepeatedServerOutput();
+ SetupConnectWithRequest("/", "1.1", "Host: www.chromium.org:999");
+ EXPECT_EQ(HTTPProxy::kStateTunnelData, GetProxyState());
+}
+
+TEST_F(HTTPProxyTest, ConnectIPAddresss) {
+ SetupClient();
+ ExpectSyncConnect(kServerAddress, 999);
+ ExpectRepeatedServerOutput();
+ ReadFromClient(CreateRequest("/", "1.1",
+ StringPrintf("Host: %s:999", kServerAddress)));
+ EXPECT_EQ(HTTPProxy::kStateTunnelData, GetProxyState());
+}
+
+TEST_F(HTTPProxyTest, ConnectAsyncConnectSuccess) {
+ SetupConnectComplete();
+}
+
+TEST_F(HTTPProxyTest, TunnelData) {
+ SetupConnectComplete();
+
+ // The proxy is waiting for the server to be ready to accept data.
+ EXPECT_CALL(sockets(), Send(kServerFD, _, _, 0))
+ .WillOnce(Return(10));
+ ExpectServerInput();
+ WriteToServer(kServerFD);
+ EXPECT_CALL(sockets(), Send(kServerFD, _, _, 0))
+ .WillOnce(ReturnArg<2>());
+ ExpectInputTimeout();
+ WriteToServer(kServerFD);
+ EXPECT_EQ(HTTPProxy::kStateTunnelData, GetProxyState());
+
+ // Tunnel a reply back to the client.
+ const string server_result("200 OK ... and so on");
+ ExpectClientResult();
+ ReadFromServer(server_result);
+ EXPECT_EQ(server_result,
+ string(reinterpret_cast<const char *>(
+ GetServerData().GetConstData()),
+ GetServerData().GetLength()));
+
+ // Allow part of the result string to be sent to the client.
+ const int part = server_result.length() / 2;
+ EXPECT_CALL(sockets(), Send(kClientFD, _, server_result.length(), 0))
+ .WillOnce(Return(part));
+ ExpectInputTimeout();
+ WriteToClient(kClientFD);
+ EXPECT_EQ(HTTPProxy::kStateTunnelData, GetProxyState());
+
+ // The Server closes the connection while the client is still reading.
+ ExpectInputTimeout();
+ ReadFromServer("");
+ EXPECT_EQ(HTTPProxy::kStateFlushResponse, GetProxyState());
+
+ // When the last part of the response is written to the client, we close
+ // all connections.
+ EXPECT_CALL(sockets(), Send(kClientFD, _, server_result.length() - part, 0))
+ .WillOnce(ReturnArg<2>());
+ ExpectTunnelClose();
+ WriteToClient(kClientFD);
+ EXPECT_EQ(HTTPProxy::kStateWaitConnection, GetProxyState());
+}
+
+TEST_F(HTTPProxyTest, TunnelDataFailWriteClient) {
+ SetupConnectComplete();
+ EXPECT_CALL(sockets(), Send(kClientFD, _, _, 0))
+ .WillOnce(Return(-1));
+ ExpectTunnelClose();
+ WriteToClient(kClientFD);
+ ExpectClientReset();
+ EXPECT_EQ(HTTPProxy::kStateWaitConnection, GetProxyState());
+}
+
+TEST_F(HTTPProxyTest, TunnelDataFailWriteServer) {
+ SetupConnectComplete();
+ EXPECT_CALL(sockets(), Send(kServerFD, _, _, 0))
+ .WillOnce(Return(-1));
+ ExpectTunnelClose();
+ WriteToServer(kServerFD);
+ ExpectClientReset();
+ EXPECT_EQ(HTTPProxy::kStateWaitConnection, GetProxyState());
+}
+
+TEST_F(HTTPProxyTest, TunnelDataFailClientClose) {
+ SetupConnectComplete();
+ ExpectTunnelClose();
+ ReadFromClient("");
+ ExpectClientReset();
+ EXPECT_EQ(HTTPProxy::kStateWaitConnection, GetProxyState());
+}
+
+TEST_F(HTTPProxyTest, TunnelDataFailServerClose) {
+ SetupConnectComplete();
+ ExpectTunnelClose();
+ ReadFromServer("");
+ ExpectClientReset();
+ EXPECT_EQ(HTTPProxy::kStateWaitConnection, GetProxyState());
+}
+
+TEST_F(HTTPProxyTest, StopClient) {
+ SetupConnectComplete();
+ EXPECT_CALL(sockets(), Close(kClientFD))
+ .WillOnce(Return(0));
+ EXPECT_CALL(sockets(), Close(kServerFD))
+ .WillOnce(Return(0));
+ StopClient();
+ ExpectClientReset();
+ EXPECT_EQ(HTTPProxy::kStateWaitConnection, GetProxyState());
+}
+
+} // namespace shill
diff --git a/mock_async_connection.cc b/mock_async_connection.cc
new file mode 100644
index 0000000..cd36973
--- /dev/null
+++ b/mock_async_connection.cc
@@ -0,0 +1,16 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/mock_async_connection.h"
+
+#include "shill/ip_address.h"
+
+namespace shill {
+
+MockAsyncConnection::MockAsyncConnection()
+ : AsyncConnection("", NULL, NULL, NULL) {}
+
+MockAsyncConnection::~MockAsyncConnection() {}
+
+} // namespace shill
diff --git a/mock_async_connection.h b/mock_async_connection.h
new file mode 100644
index 0000000..7f714c1
--- /dev/null
+++ b/mock_async_connection.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_MOCK_ASYNC_CONNECTION_H_
+#define SHILL_MOCK_ASYNC_CONNECTION_H_
+
+#include <base/basictypes.h>
+#include <gmock/gmock.h>
+
+#include "shill/async_connection.h"
+
+namespace shill {
+
+class MockAsyncConnection : public AsyncConnection {
+ public:
+ MockAsyncConnection();
+ virtual ~MockAsyncConnection();
+
+ MOCK_METHOD2(Start, bool(const IPAddress &address, int port));
+ MOCK_METHOD0(Stop, void());
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(MockAsyncConnection);
+};
+
+} // namespace shill
+
+#endif // SHILL_MOCK_ASYNC_CONNECTION_H_
diff --git a/mock_dns_client.cc b/mock_dns_client.cc
new file mode 100644
index 0000000..f964602
--- /dev/null
+++ b/mock_dns_client.cc
@@ -0,0 +1,22 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/mock_dns_client.h"
+
+#include <vector>
+#include <string>
+
+#include "shill/ip_address.h"
+
+using std::string;
+using std::vector;
+
+namespace shill {
+
+MockDNSClient::MockDNSClient()
+ : DNSClient(IPAddress::kFamilyIPv4, "", vector<string>(), 0, NULL, NULL) {}
+
+MockDNSClient::~MockDNSClient() {}
+
+} // namespace shill
diff --git a/mock_dns_client.h b/mock_dns_client.h
new file mode 100644
index 0000000..73cf62d
--- /dev/null
+++ b/mock_dns_client.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_MOCK_DNS_CLIENT_H_
+#define SHILL_MOCK_DNS_CLIENT_H_
+
+#include "shill/dns_client.h"
+
+#include <base/basictypes.h>
+#include <gmock/gmock.h>
+
+namespace shill {
+
+class MockDNSClient : public DNSClient {
+ public:
+ MockDNSClient();
+ virtual ~MockDNSClient();
+
+ MOCK_METHOD1(Start, bool(const std::string &hostname));
+ MOCK_METHOD0(Stop, void());
+ MOCK_CONST_METHOD0(address, const IPAddress &());
+ MOCK_CONST_METHOD0(error, const std::string &());
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(MockDNSClient);
+};
+
+} // namespace shill
+
+#endif // SHILL_MOCK_DNS_CLIENT_H_
diff --git a/mock_sockets.h b/mock_sockets.h
index 8d80269..b095ce9 100644
--- a/mock_sockets.h
+++ b/mock_sockets.h
@@ -17,10 +17,20 @@
MockSockets();
virtual ~MockSockets();
+ MOCK_METHOD3(Accept,
+ int(int sockfd, struct sockaddr *addr, socklen_t *addrlen));
MOCK_METHOD3(Bind,
int(int sockfd, const struct sockaddr *addr, socklen_t addrlen));
+ MOCK_METHOD2(BindToDevice, int(int sockfd, const std::string &device));
MOCK_METHOD1(Close, int(int fd));
+ MOCK_METHOD3(Connect,
+ int(int sockfd, const struct sockaddr *addr, socklen_t addrlen));
+ MOCK_METHOD0(Error, int());
+ MOCK_METHOD3(GetSockName,
+ int(int sockfd, struct sockaddr *addr, socklen_t *addrlen));
+ MOCK_METHOD1(GetSocketError, int(int fd));
MOCK_METHOD3(Ioctl, int(int d, int request, void *argp));
+ MOCK_METHOD2(Listen, int(int d, int backlog));
MOCK_METHOD4(Send,
ssize_t(int sockfd, const void *buf, size_t len, int flags));
MOCK_METHOD6(SendTo, ssize_t(int sockfd,
@@ -29,6 +39,7 @@
int flags,
const struct sockaddr *dest_addr,
socklen_t addrlen));
+ MOCK_METHOD1(SetNonBlocking, int(int sockfd));
MOCK_METHOD3(Socket, int(int domain, int type, int protocol));
private:
diff --git a/sockets.cc b/sockets.cc
index b1662f7..032ee4f 100644
--- a/sockets.cc
+++ b/sockets.cc
@@ -2,28 +2,81 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+#include "shill/sockets.h"
+
+#include <errno.h>
+#include <fcntl.h>
+#include <net/if.h>
+#include <stdio.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <unistd.h>
-#include "shill/sockets.h"
+#include <base/logging.h>
namespace shill {
Sockets::~Sockets() {}
+int Sockets::Accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
+ return accept(sockfd, addr, addrlen);
+}
+
int Sockets::Bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
return bind(sockfd, addr, addrlen);
}
+int Sockets::BindToDevice(int sockfd, const std::string &device) {
+ char dev_name[IFNAMSIZ];
+ CHECK_GT(sizeof(dev_name), device.length());
+ memset(&dev_name, 0, sizeof(dev_name));
+ snprintf(dev_name, sizeof(dev_name), "%s", device.c_str());
+ return setsockopt(sockfd, SOL_SOCKET, SO_BINDTODEVICE, &dev_name,
+ sizeof(dev_name));
+}
+
int Sockets::Close(int fd) {
return close(fd);
}
+int Sockets::Connect(int sockfd, const struct sockaddr *addr,
+ socklen_t addrlen) {
+ return connect(sockfd, addr, addrlen);
+}
+
+int Sockets::Error() {
+ return errno;
+}
+
+std::string Sockets::ErrorString() {
+ return std::string(strerror(Error()));
+}
+
+int Sockets::GetSockName(int sockfd,
+ struct sockaddr *addr,
+ socklen_t *addrlen) {
+ return getsockname(sockfd, addr, addrlen);
+}
+
+
+int Sockets::GetSocketError(int sockfd) {
+ int error;
+ socklen_t optlen = sizeof(error);
+ if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &error, &optlen) == 0) {
+ return error;
+ }
+ return -1;
+}
+
+
int Sockets::Ioctl(int d, int request, void *argp) {
return ioctl(d, request, argp);
}
+int Sockets::Listen(int sockfd, int backlog) {
+ return listen(sockfd, backlog);
+}
+
ssize_t Sockets::Send(int sockfd, const void *buf, size_t len, int flags) {
return send(sockfd, buf, len, flags);
}
@@ -33,6 +86,10 @@
return sendto(sockfd, buf, len, flags, dest_addr, addrlen);
}
+int Sockets::SetNonBlocking(int sockfd) {
+ return fcntl(sockfd, F_SETFL, fcntl(sockfd, F_GETFL) | O_NONBLOCK);
+}
+
int Sockets::Socket(int domain, int type, int protocol) {
return socket(domain, type, protocol);
}
diff --git a/sockets.h b/sockets.h
index 31d718e..e0cc70b 100644
--- a/sockets.h
+++ b/sockets.h
@@ -5,6 +5,8 @@
#ifndef SHILL_SOCKETS_H_
#define SHILL_SOCKETS_H_
+#include <string>
+
#include <sys/socket.h>
#include <sys/types.h>
@@ -17,15 +19,41 @@
public:
virtual ~Sockets();
+ // accept
+ virtual int Accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen);
+
// bind
virtual int Bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen);
+ // setsockopt(s, SOL_SOCKET, SO_BINDTODEVICE ...)
+ virtual int BindToDevice(int sockfd, const std::string &device);
+
// close
virtual int Close(int fd);
+ // connect
+ virtual int Connect(int sockfd, const struct sockaddr *addr,
+ socklen_t addrlen);
+
+ // errno
+ virtual int Error();
+
+ // errno
+ virtual std::string ErrorString();
+
+ // getsockname
+ virtual int GetSockName(int sockfd, struct sockaddr *addr,
+ socklen_t *addrlen);
+
+ // getsockopt(sockfd, SOL_SOCKET, SO_ERROR, ...)
+ virtual int GetSocketError(int sockfd);
+
// ioctl
virtual int Ioctl(int d, int request, void *argp);
+ // listen
+ virtual int Listen(int sockfd, int backlog);
+
// send
virtual ssize_t Send(int sockfd, const void *buf, size_t len, int flags);
@@ -33,6 +61,9 @@
virtual ssize_t SendTo(int sockfd, const void *buf, size_t len, int flags,
const struct sockaddr *dest_addr, socklen_t addrlen);
+ // fcntl(sk, F_SETFL, fcntl(sk, F_GETFL) | O_NONBLOCK)
+ virtual int SetNonBlocking(int sockfd);
+
// socket
virtual int Socket(int domain, int type, int protocol);
};