shill: Add Portal Detection object
Add a utility object that will perform a repeated test of an
HTTP URL and return the result to a callback.
BUG=chromium-os:23318
TEST=New unit tests
Change-Id: I0449dbe51fb1dcef2ecd3bb88de1bcaf2950f749
Reviewed-on: https://gerrit.chromium.org/gerrit/15472
Commit-Ready: Paul Stewart <pstew@chromium.org>
Reviewed-by: Paul Stewart <pstew@chromium.org>
Tested-by: Paul Stewart <pstew@chromium.org>
diff --git a/Makefile b/Makefile
index f3bc5ab..7050a3e 100644
--- a/Makefile
+++ b/Makefile
@@ -142,6 +142,7 @@
modem_manager_proxy.o \
modem_proxy.o \
modem_simple_proxy.o \
+ portal_detector.o \
power_manager_proxy.o \
profile.o \
profile_dbus_adaptor.o \
@@ -220,6 +221,7 @@
mock_dns_client.o \
mock_event_dispatcher.o \
mock_glib.o \
+ mock_http_request.o \
mock_ipconfig.o \
mock_manager.o \
mock_metrics.o \
@@ -247,6 +249,7 @@
modem_manager_unittest.o \
modem_unittest.o \
nice_mock_control.o \
+ portal_detector_unittest.o \
profile_dbus_property_exporter_unittest.o \
profile_unittest.o \
property_accessor_unittest.o \
diff --git a/default_profile.cc b/default_profile.cc
index 27607e1..280856a 100644
--- a/default_profile.cc
+++ b/default_profile.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
@@ -13,6 +13,7 @@
#include "shill/adaptor_interfaces.h"
#include "shill/control_interface.h"
#include "shill/manager.h"
+#include "shill/portal_detector.h"
#include "shill/store_interface.h"
using std::vector;
@@ -30,6 +31,8 @@
const char DefaultProfile::kStorageName[] = "Name";
// static
const char DefaultProfile::kStorageOfflineMode[] = "OfflineMode";
+// static
+const char DefaultProfile::kStoragePortalURL[] = "PortalURL";
DefaultProfile::DefaultProfile(ControlInterface *control,
Manager *manager,
@@ -58,6 +61,10 @@
storage()->GetString(kStorageId,
kStorageCheckPortalList,
&manager_props->check_portal_list);
+ if (!storage()->GetString(kStorageId, kStoragePortalURL,
+ &manager_props->portal_url)) {
+ manager_props->portal_url = PortalDetector::kDefaultURL;
+ }
return true;
}
@@ -68,6 +75,9 @@
storage()->SetString(kStorageId,
kStorageCheckPortalList,
props_.check_portal_list);
+ storage()->SetString(kStorageId,
+ kStoragePortalURL,
+ props_.portal_url);
vector<DeviceRefPtr>::iterator it;
for (it = manager()->devices_begin(); it != manager()->devices_end(); ++it) {
if (!(*it)->Save(storage())) {
diff --git a/default_profile.h b/default_profile.h
index 7e0f5cf..58f619b 100644
--- a/default_profile.h
+++ b/default_profile.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
@@ -49,6 +49,7 @@
private:
friend class DefaultProfileTest;
FRIEND_TEST(DefaultProfileTest, GetStoragePath);
+ FRIEND_TEST(DefaultProfileTest, LoadManagerDefaultProperties);
FRIEND_TEST(DefaultProfileTest, LoadManagerProperties);
FRIEND_TEST(DefaultProfileTest, Save);
@@ -58,6 +59,7 @@
static const char kStorageHostName[];
static const char kStorageName[];
static const char kStorageOfflineMode[];
+ static const char kStoragePortalURL[];
const FilePath storage_path_;
const Manager::Properties &props_;
diff --git a/default_profile_unittest.cc b/default_profile_unittest.cc
index 695e092..e66a469 100644
--- a/default_profile_unittest.cc
+++ b/default_profile_unittest.cc
@@ -19,6 +19,7 @@
#include "shill/mock_control.h"
#include "shill/mock_device.h"
#include "shill/mock_store.h"
+#include "shill/portal_detector.h"
#include "shill/property_store_unittest.h"
using std::map;
@@ -115,6 +116,10 @@
DefaultProfile::kStorageCheckPortalList,
""))
.WillOnce(Return(true));
+ EXPECT_CALL(*storage.get(), SetString(DefaultProfile::kStorageId,
+ DefaultProfile::kStoragePortalURL,
+ ""))
+ .WillOnce(Return(true));
EXPECT_CALL(*storage.get(), Flush()).WillOnce(Return(true));
EXPECT_CALL(*device_.get(), Save(storage.get())).WillOnce(Return(true));
@@ -125,6 +130,34 @@
manager()->DeregisterDevice(device_);
}
+TEST_F(DefaultProfileTest, LoadManagerDefaultProperties) {
+ scoped_ptr<MockStore> storage(new MockStore);
+ EXPECT_CALL(*storage.get(), GetString(DefaultProfile::kStorageId,
+ DefaultProfile::kStorageHostName,
+ _))
+ .WillOnce(Return(false));
+ EXPECT_CALL(*storage.get(), GetBool(DefaultProfile::kStorageId,
+ DefaultProfile::kStorageOfflineMode,
+ _))
+ .WillOnce(Return(false));
+ EXPECT_CALL(*storage.get(), GetString(DefaultProfile::kStorageId,
+ DefaultProfile::kStorageCheckPortalList,
+ _))
+ .WillOnce(Return(false));
+ EXPECT_CALL(*storage.get(), GetString(DefaultProfile::kStorageId,
+ DefaultProfile::kStoragePortalURL,
+ _))
+ .WillOnce(Return(false));
+ profile_->set_storage(storage.release());
+
+ Manager::Properties manager_props;
+ ASSERT_TRUE(profile_->LoadManagerProperties(&manager_props));
+ EXPECT_EQ("", manager_props.host_name);
+ EXPECT_FALSE(manager_props.offline_mode);
+ EXPECT_EQ("", manager_props.check_portal_list);
+ EXPECT_EQ(PortalDetector::kDefaultURL, manager_props.portal_url);
+}
+
TEST_F(DefaultProfileTest, LoadManagerProperties) {
scoped_ptr<MockStore> storage(new MockStore);
const string host_name("hostname");
@@ -141,6 +174,11 @@
DefaultProfile::kStorageCheckPortalList,
_))
.WillOnce(DoAll(SetArgumentPointee<2>(portal_list), Return(true)));
+ const string portal_url("http://www.chromium.org");
+ EXPECT_CALL(*storage.get(), GetString(DefaultProfile::kStorageId,
+ DefaultProfile::kStoragePortalURL,
+ _))
+ .WillOnce(DoAll(SetArgumentPointee<2>(portal_url), Return(true)));
profile_->set_storage(storage.release());
Manager::Properties manager_props;
@@ -148,6 +186,7 @@
EXPECT_EQ(host_name, manager_props.host_name);
EXPECT_TRUE(manager_props.offline_mode);
EXPECT_EQ(portal_list, manager_props.check_portal_list);
+ EXPECT_EQ(portal_url, manager_props.portal_url);
}
TEST_F(DefaultProfileTest, GetStoragePath) {
diff --git a/dns_client.cc b/dns_client.cc
index ba1961d..7fdbfbe 100644
--- a/dns_client.cc
+++ b/dns_client.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
@@ -114,7 +114,7 @@
}
running_ = true;
- time_->GetTimeOfDay(&resolver_state_->start_time_, NULL);
+ time_->GetTimeMonotonic(&resolver_state_->start_time_);
error_.clear();
ares_->GetHostByName(resolver_state_->channel, hostname.c_str(),
address_.family(), ReceiveDNSReplyCB, this);
@@ -283,7 +283,7 @@
// Schedule timer event for the earlier of our timeout or one requested by
// the resolver library.
struct timeval now, elapsed_time, timeout_tv;
- time_->GetTimeOfDay(&now, NULL);
+ time_->GetTimeMonotonic(&now);
timersub(&now, &resolver_state_->start_time_, &elapsed_time);
timeout_tv.tv_sec = timeout_ms_ / 1000;
timeout_tv.tv_usec = (timeout_ms_ % 1000) * 1000;
diff --git a/dns_client_unittest.cc b/dns_client_unittest.cc
index 2df51cc..720240e 100644
--- a/dns_client_unittest.cc
+++ b/dns_client_unittest.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
@@ -63,7 +63,7 @@
}
virtual void SetUp() {
- EXPECT_CALL(time_, GetTimeOfDay(_, _))
+ EXPECT_CALL(time_, GetTimeMonotonic(_))
.WillRepeatedly(DoAll(SetArgumentPointee<0>(time_val_), Return(0)));
SetInActive();
}
@@ -79,7 +79,7 @@
void AdvanceTime(int time_ms) {
struct timeval adv_time = { time_ms/1000, (time_ms % 1000) * 1000 };
timeradd(&time_val_, &adv_time, &time_val_);
- EXPECT_CALL(time_, GetTimeOfDay(_, _))
+ EXPECT_CALL(time_, GetTimeMonotonic(_))
.WillRepeatedly(DoAll(SetArgumentPointee<0>(time_val_), Return(0)));
}
@@ -274,7 +274,7 @@
SetupRequest(kGoodName, kGoodServer);
struct timeval init_time_val = time_val_;
AdvanceTime(kAresTimeoutMS);
- EXPECT_CALL(time_, GetTimeOfDay(_, _))
+ EXPECT_CALL(time_, GetTimeMonotonic(_))
.WillOnce(DoAll(SetArgumentPointee<0>(init_time_val), Return(0)))
.WillRepeatedly(DoAll(SetArgumentPointee<0>(time_val_), Return(0)));
EXPECT_CALL(callback_target_, CallTarget(false));
diff --git a/http_request.h b/http_request.h
index 406a94d..1700ab7 100644
--- a/http_request.h
+++ b/http_request.h
@@ -66,15 +66,15 @@
// This (Start) function returns a failure result if the request
// failed during initialization, or kResultInProgress if the request
// has started successfully and is now in progress.
- Result Start(const HTTPURL &url,
- Callback1<int>::Type *read_event_callback,
- Callback1<Result>::Type *result_callback);
+ virtual Result Start(const HTTPURL &url,
+ Callback1<int>::Type *read_event_callback,
+ Callback1<Result>::Type *result_callback);
// Stop the current HTTPRequest. No callback is called as a side
// effect of this function.
- void Stop();
+ virtual void Stop();
- const ByteString &response_data() const { return response_data_; }
+ virtual const ByteString &response_data() const { return response_data_; }
private:
friend class HTTPRequestTest;
diff --git a/mock_http_request.cc b/mock_http_request.cc
new file mode 100644
index 0000000..75dd5d6
--- /dev/null
+++ b/mock_http_request.cc
@@ -0,0 +1,18 @@
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/mock_http_request.h"
+
+#include "shill/connection.h"
+
+namespace shill {
+
+MockHTTPRequest::MockHTTPRequest(ConnectionRefPtr connection)
+ : HTTPRequest(connection,
+ reinterpret_cast<EventDispatcher *>(NULL),
+ reinterpret_cast<Sockets *>(NULL)) {}
+
+MockHTTPRequest::~MockHTTPRequest() {}
+
+} // namespace shill
diff --git a/mock_http_request.h b/mock_http_request.h
new file mode 100644
index 0000000..dd8c911
--- /dev/null
+++ b/mock_http_request.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_MOCK_HTTP_REQUEST_H_
+#define SHILL_MOCK_HTTP_REQUEST_H_
+
+#include <base/basictypes.h>
+#include <gmock/gmock.h>
+
+#include "shill/http_request.h"
+#include "shill/http_url.h" // MOCK_METHOD3() call below needs sizeof(HTTPURL).
+
+namespace shill {
+
+class MockHTTPRequest : public HTTPRequest {
+ public:
+ MockHTTPRequest(ConnectionRefPtr connection);
+ virtual ~MockHTTPRequest();
+
+ MOCK_METHOD3(Start, HTTPRequest::Result(
+ const HTTPURL &url,
+ Callback1<int>::Type *read_event_callback,
+ Callback1<Result>::Type *result_callback));
+ MOCK_METHOD0(Stop, void());
+ MOCK_CONST_METHOD0(response_data, const ByteString &());
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(MockHTTPRequest);
+};
+
+} // namespace shill
+
+#endif // SHILL_MOCK_HTTP_REQUEST_H_
diff --git a/mock_time.h b/mock_time.h
index 7199822..ae253bc 100644
--- a/mock_time.h
+++ b/mock_time.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
@@ -17,6 +17,7 @@
MockTime();
virtual ~MockTime();
+ MOCK_METHOD1(GetTimeMonotonic, int(struct timeval *tv));
MOCK_METHOD2(GetTimeOfDay, int(struct timeval *tv, struct timezone *tz));
private:
diff --git a/portal_detector.cc b/portal_detector.cc
new file mode 100644
index 0000000..35228d1
--- /dev/null
+++ b/portal_detector.cc
@@ -0,0 +1,249 @@
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/portal_detector.h"
+
+#include <string>
+
+#include <base/logging.h>
+#include <base/string_number_conversions.h>
+#include <base/string_util.h>
+#include <base/stringprintf.h>
+
+#include "shill/async_connection.h"
+#include "shill/connection.h"
+#include "shill/dns_client.h"
+#include "shill/event_dispatcher.h"
+#include "shill/http_url.h"
+#include "shill/ip_address.h"
+#include "shill/sockets.h"
+
+using base::StringPrintf;
+using std::string;
+
+namespace shill {
+
+const char PortalDetector::kDefaultURL[] =
+ "http://clients3.google.com/generate_204";
+const char PortalDetector::kResponseExpected[] = "HTTP/1.1 204";
+
+const int PortalDetector::kMaxRequestAttempts = 3;
+const int PortalDetector::kMinTimeBetweenAttemptsSeconds = 3;
+const int PortalDetector::kRequestTimeoutSeconds = 10;
+
+const char PortalDetector::kPhaseConnectionString[] = "Connection";
+const char PortalDetector::kPhaseDNSString[] = "DNS";
+const char PortalDetector::kPhaseHTTPString[] = "HTTP";
+const char PortalDetector::kPhaseContentString[] = "Content";
+const char PortalDetector::kPhaseUnknownString[] = "Unknown";
+
+const char PortalDetector::kStatusFailureString[] = "Failure";
+const char PortalDetector::kStatusSuccessString[] = "Success";
+const char PortalDetector::kStatusTimeoutString[] = "Timeout";
+
+PortalDetector::PortalDetector(
+ ConnectionRefPtr connection,
+ EventDispatcher *dispatcher,
+ Callback1<const Result &>::Type *callback)
+ : attempt_count_(0),
+ connection_(connection),
+ dispatcher_(dispatcher),
+ portal_result_callback_(callback),
+ request_read_callback_(
+ NewCallback(this, &PortalDetector::RequestReadCallback)),
+ request_result_callback_(
+ NewCallback(this, &PortalDetector::RequestResultCallback)),
+ task_factory_(this),
+ time_(Time::GetInstance()) { }
+
+PortalDetector::~PortalDetector() {
+ Stop();
+}
+
+bool PortalDetector::Start(const std::string &url_string) {
+ VLOG(3) << "In " << __func__;
+
+ DCHECK(!request_.get());
+
+ if (!url_.ParseFromString(url_string)) {
+ LOG(ERROR) << "Failed to parse URL string: " << url_string;
+ return false;
+ }
+
+ request_.reset(new HTTPRequest(connection_, dispatcher_, &sockets_));
+ attempt_count_ = 0;
+ StartAttempt();
+ return true;
+}
+
+void PortalDetector::Stop() {
+ VLOG(3) << "In " << __func__;
+
+ if (!request_.get()) {
+ return;
+ }
+
+ StopAttempt();
+ attempt_count_ = 0;
+ request_.reset();
+}
+
+// static
+const string PortalDetector::PhaseToString(Phase phase) {
+ switch (phase) {
+ case kPhaseConnection:
+ return kPhaseConnectionString;
+ case kPhaseDNS:
+ return kPhaseDNSString;
+ case kPhaseHTTP:
+ return kPhaseHTTPString;
+ case kPhaseContent:
+ return kPhaseContentString;
+ case kPhaseUnknown:
+ default:
+ return kPhaseUnknownString;
+ }
+}
+
+// static
+const string PortalDetector::StatusToString(Status status) {
+ switch (status) {
+ case kStatusSuccess:
+ return kStatusSuccessString;
+ case kStatusTimeout:
+ return kStatusTimeoutString;
+ case kStatusFailure:
+ default:
+ return kStatusFailureString;
+ }
+}
+
+void PortalDetector::CompleteAttempt(Result result) {
+ LOG(INFO) << StringPrintf("Portal detection completed attempt %d with "
+ "phase==%s, status==%s",
+ attempt_count_,
+ PhaseToString(result.phase).c_str(),
+ StatusToString(result.status).c_str());
+ StopAttempt();
+ if (result.status != kStatusSuccess && attempt_count_ < kMaxRequestAttempts) {
+ StartAttempt();
+ } else {
+ result.final = true;
+ Stop();
+ }
+
+ portal_result_callback_->Run(result);
+}
+
+PortalDetector::Result PortalDetector::GetPortalResultForRequestResult(
+ HTTPRequest::Result result) {
+ switch (result) {
+ case HTTPRequest::kResultSuccess:
+ // The request completed without receiving the expected payload.
+ return Result(kPhaseContent, kStatusFailure);
+ case HTTPRequest::kResultDNSFailure:
+ return Result(kPhaseDNS, kStatusFailure);
+ case HTTPRequest::kResultDNSTimeout:
+ return Result(kPhaseDNS, kStatusTimeout);
+ case HTTPRequest::kResultConnectionFailure:
+ return Result(kPhaseConnection, kStatusFailure);
+ case HTTPRequest::kResultConnectionTimeout:
+ return Result(kPhaseConnection, kStatusTimeout);
+ case HTTPRequest::kResultRequestFailure:
+ case HTTPRequest::kResultResponseFailure:
+ return Result(kPhaseHTTP, kStatusFailure);
+ case HTTPRequest::kResultRequestTimeout:
+ case HTTPRequest::kResultResponseTimeout:
+ return Result(kPhaseHTTP, kStatusTimeout);
+ case HTTPRequest::kResultUnknown:
+ default:
+ return Result(kPhaseUnknown, kStatusFailure);
+ }
+}
+
+void PortalDetector::RequestReadCallback(int /*read_length*/) {
+ const string response_expected(kResponseExpected);
+ const ByteString &response_data = request_->response_data();
+ bool expected_length_received = false;
+ int compare_length = 0;
+ if (response_data.GetLength() < response_expected.length()) {
+ // There isn't enough data yet for a final decision, but we can still
+ // test to see if the partial string matches so far.
+ expected_length_received = false;
+ compare_length = response_data.GetLength();
+ } else {
+ expected_length_received = true;
+ compare_length = response_expected.length();
+ }
+
+ if (ByteString(response_expected.substr(0, compare_length), false).Equals(
+ ByteString(response_data.GetConstData(), compare_length))) {
+ if (expected_length_received) {
+ CompleteAttempt(Result(kPhaseContent, kStatusSuccess));
+ }
+ // Otherwise, we wait for more data from the server.
+ } else {
+ CompleteAttempt(Result(kPhaseContent, kStatusFailure));
+ }
+}
+
+void PortalDetector::RequestResultCallback(HTTPRequest::Result result) {
+ CompleteAttempt(GetPortalResultForRequestResult(result));
+}
+
+void PortalDetector::StartAttempt() {
+ int64 next_attempt_delay = 0;
+ if (attempt_count_ > 0) {
+ // Ensure that attempts are spaced at least by a minimal interval.
+ struct timeval now, elapsed_time;
+ time_->GetTimeMonotonic(&now);
+ timersub(&now, &attempt_start_time_, &elapsed_time);
+
+ if (elapsed_time.tv_sec < kMinTimeBetweenAttemptsSeconds) {
+ struct timeval remaining_time = { kMinTimeBetweenAttemptsSeconds, 0 };
+ timersub(&remaining_time, &elapsed_time, &remaining_time);
+ next_attempt_delay =
+ remaining_time.tv_sec * 1000 + remaining_time.tv_usec / 1000;
+ }
+ }
+ dispatcher_->PostDelayedTask(
+ task_factory_.NewRunnableMethod(&PortalDetector::StartAttemptTask),
+ next_attempt_delay);
+}
+
+void PortalDetector::StartAttemptTask() {
+ time_->GetTimeMonotonic(&attempt_start_time_);
+ ++attempt_count_;
+
+ LOG(INFO) << StringPrintf("Portal detection starting attempt %d of %d\n",
+ attempt_count_, kMaxRequestAttempts);
+
+ HTTPRequest::Result result =
+ request_->Start(url_, request_read_callback_.get(),
+ request_result_callback_.get());
+ if (result != HTTPRequest::kResultInProgress) {
+ CompleteAttempt(GetPortalResultForRequestResult(result));
+ return;
+ }
+
+ dispatcher_->PostDelayedTask(
+ task_factory_.NewRunnableMethod(&PortalDetector::TimeoutAttemptTask),
+ kRequestTimeoutSeconds * 1000);
+}
+
+void PortalDetector::StopAttempt() {
+ request_->Stop();
+ task_factory_.RevokeAll();
+}
+
+void PortalDetector::TimeoutAttemptTask() {
+ LOG(ERROR) << "Request timed out";
+ if (request_->response_data().GetLength()) {
+ CompleteAttempt(Result(kPhaseContent, kStatusTimeout));
+ } else {
+ CompleteAttempt(Result(kPhaseUnknown, kStatusTimeout));
+ }
+}
+
+} // namespace shill
diff --git a/portal_detector.h b/portal_detector.h
new file mode 100644
index 0000000..1c1079e
--- /dev/null
+++ b/portal_detector.h
@@ -0,0 +1,148 @@
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef SHILL_PORTAL_DETECTOR_
+#define SHILL_PORTAL_DETECTOR_
+
+#include <string>
+#include <vector>
+
+#include <base/callback_old.h>
+#include <base/memory/ref_counted.h>
+#include <base/memory/scoped_ptr.h>
+#include <base/task.h>
+#include <gtest/gtest_prod.h> // for FRIEND_TEST
+
+#include "shill/http_request.h"
+#include "shill/http_url.h"
+#include "shill/refptr_types.h"
+#include "shill/shill_time.h"
+#include "shill/sockets.h"
+
+namespace shill {
+
+class EventDispatcher;
+class PortalDetector;
+class Time;
+
+// The PortalDetector class implements the portal detection
+// facility in shill, which is responsible for checking to see
+// if a connection has "general internet connectivity".
+//
+// This information can be used for ranking one connection
+// against another, or for informing UI and other components
+// outside the connection manager whether the connection seems
+// available for "general use" or if further user action may be
+// necessary (e.g, click through of a WiFi Hotspot's splash
+// page).
+//
+// This is achieved by trying to access a URL and expecting a
+// specific response. Any result that deviates from this result
+// (DNS or HTTP errors, as well as deviations from the expected
+// content) are considered failures.
+class PortalDetector {
+ public:
+ enum Phase {
+ kPhaseConnection,
+ kPhaseDNS,
+ kPhaseHTTP,
+ kPhaseContent,
+ kPhaseUnknown
+ };
+
+ enum Status {
+ kStatusFailure,
+ kStatusSuccess,
+ kStatusTimeout
+ };
+
+ struct Result {
+ Result() : phase(kPhaseUnknown), status(kStatusFailure), final(false) {}
+ Result(Phase phase_in, Status status_in)
+ : phase(phase_in), status(status_in), final(false) {}
+ Result(Phase phase_in, Status status_in, bool final_in)
+ : phase(phase_in), status(status_in), final(final_in) {}
+ Phase phase;
+ Status status;
+ bool final;
+ };
+
+ static const char kDefaultURL[];
+ static const char kResponseExpected[];
+
+ PortalDetector(ConnectionRefPtr connection,
+ EventDispatcher *dispatcher,
+ Callback1<const Result&>::Type *callback);
+ virtual ~PortalDetector();
+
+ // Start a portal detection test. Returns true if |url_string| correctly
+ // parses as a URL. Returns false (and does not start) if the |url_string|
+ // fails to parse.
+ //
+ // As each attempt completes the callback handed to the constructor will
+ // be called. The PortalDetector will try up to kMaxRequestAttempts times
+ // to successfully retrieve the URL. If the attempt is successful or
+ // this is the last attempt, the "final" flag in the Result structure will
+ // be true, otherwise it will be false, and the PortalDetector will
+ // schedule the next attempt.
+ bool Start(const std::string &url_string);
+
+ // End the current portal detection process if one exists, and do not call
+ // the callback.
+ void Stop();
+
+ static const std::string PhaseToString(Phase phase);
+ static const std::string StatusToString(Status status);
+ static Result GetPortalResultForRequestResult(HTTPRequest::Result result);
+
+ private:
+ friend class PortalDetectorTest;
+ FRIEND_TEST(PortalDetectorTest, StartAttemptFailed);
+ FRIEND_TEST(PortalDetectorTest, StartAttemptRepeated);
+ FRIEND_TEST(PortalDetectorTest, AttemptCount);
+
+ // Number of times to attempt connection.
+ static const int kMaxRequestAttempts;
+ // Minimum time between attempts to connect to server.
+ static const int kMinTimeBetweenAttemptsSeconds;
+ // Time to wait for request to complete.
+ static const int kRequestTimeoutSeconds;
+
+ static const char kPhaseConnectionString[];
+ static const char kPhaseDNSString[];
+ static const char kPhaseHTTPString[];
+ static const char kPhaseContentString[];
+ static const char kPhaseUnknownString[];
+
+ static const char kStatusFailureString[];
+ static const char kStatusSuccessString[];
+ static const char kStatusTimeoutString[];
+
+ void CompleteAttempt(Result result);
+ void RequestReadCallback(int read_length);
+ void RequestResultCallback(HTTPRequest::Result result);
+ void StartAttempt();
+ void StartAttemptTask();
+ void StopAttempt();
+ void TimeoutAttemptTask();
+
+ int attempt_count_;
+ struct timeval attempt_start_time_;
+ ConnectionRefPtr connection_;
+ EventDispatcher *dispatcher_;
+ Callback1<const Result &>::Type *portal_result_callback_;
+ scoped_ptr<HTTPRequest> request_;
+ scoped_ptr<Callback1<int>::Type> request_read_callback_;
+ scoped_ptr<Callback1<HTTPRequest::Result>::Type> request_result_callback_;
+ Sockets sockets_;
+ ScopedRunnableMethodFactory<PortalDetector> task_factory_;
+ Time *time_;
+ HTTPURL url_;
+
+ DISALLOW_COPY_AND_ASSIGN(PortalDetector);
+};
+
+} // namespace shill
+
+#endif // SHILL_PORTAL_DETECTOR_
diff --git a/portal_detector_unittest.cc b/portal_detector_unittest.cc
new file mode 100644
index 0000000..6beb68c
--- /dev/null
+++ b/portal_detector_unittest.cc
@@ -0,0 +1,411 @@
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "shill/portal_detector.h"
+
+#include <string>
+
+#include <base/memory/scoped_ptr.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "shill/mock_connection.h"
+#include "shill/mock_control.h"
+#include "shill/mock_device_info.h"
+#include "shill/mock_event_dispatcher.h"
+#include "shill/mock_http_request.h"
+#include "shill/mock_time.h"
+
+using std::string;
+using std::vector;
+using testing::_;
+using testing::AtLeast;
+using testing::DoAll;
+using testing::InSequence;
+using testing::NiceMock;
+using testing::Return;
+using testing::ReturnRef;
+using testing::SetArgumentPointee;
+using testing::StrictMock;
+using testing::Test;
+
+namespace shill {
+
+namespace {
+const char kBadURL[] = "badurl";
+const char kInterfaceName[] = "int0";
+const char kURL[] = "http://www.chromium.org";
+const char kDNSServer0[] = "8.8.8.8";
+const char kDNSServer1[] = "8.8.4.4";
+const char *kDNSServers[] = { kDNSServer0, kDNSServer1 };
+} // namespace {}
+
+MATCHER_P(IsResult, result, "") {
+ return (result.phase == arg.phase &&
+ result.status == arg.status &&
+ result.final == arg.final);
+}
+
+class PortalDetectorTest : public Test {
+ public:
+ PortalDetectorTest()
+ : device_info_(new NiceMock<MockDeviceInfo>(
+ &control_,
+ reinterpret_cast<EventDispatcher *>(NULL),
+ reinterpret_cast<Metrics *>(NULL),
+ reinterpret_cast<Manager *>(NULL))),
+ connection_(new StrictMock<MockConnection>(device_info_.get())),
+ portal_detector_(new PortalDetector(
+ connection_.get(),
+ &dispatcher_,
+ callback_target_.result_callback())),
+ interface_name_(kInterfaceName),
+ dns_servers_(kDNSServers, kDNSServers + 2),
+ http_request_(NULL) {
+ current_time_.tv_sec = current_time_.tv_usec = 0;
+ }
+
+ virtual void SetUp() {
+ EXPECT_CALL(*connection_.get(), interface_name())
+ .WillRepeatedly(ReturnRef(interface_name_));
+ portal_detector_->time_ = &time_;
+ EXPECT_CALL(time_, GetTimeMonotonic(_))
+ .WillRepeatedly(Invoke(this, &PortalDetectorTest::GetTimeMonotonic));
+ EXPECT_CALL(*connection_.get(), dns_servers())
+ .WillRepeatedly(ReturnRef(dns_servers_));
+ }
+
+ virtual void TearDown() {
+ if (portal_detector_->request_.get()) {
+ EXPECT_CALL(*http_request(), Stop());
+
+ // Delete the portal detector while expectations still exist.
+ portal_detector_.reset();
+ }
+ }
+
+ protected:
+ class CallbackTarget {
+ public:
+ CallbackTarget()
+ : result_callback_(NewCallback(this, &CallbackTarget::ResultCallback)) {
+ }
+
+ MOCK_METHOD1(ResultCallback, void(const PortalDetector::Result &result));
+ Callback1<const PortalDetector::Result &>::Type *result_callback() {
+ return result_callback_.get();
+ }
+
+ private:
+ scoped_ptr<Callback1<const PortalDetector::Result &>::Type>
+ result_callback_;
+ };
+
+ void AssignHTTPRequest() {
+ http_request_ = new StrictMock<MockHTTPRequest>(connection_);
+ portal_detector_->request_.reset(http_request_); // Passes ownership.
+ }
+
+ bool StartPortalRequest(const string &url_string) {
+ bool ret = portal_detector_->Start(url_string);
+ if (ret) {
+ AssignHTTPRequest();
+ }
+ return ret;
+ }
+
+ void TimeoutAttempt() {
+ portal_detector_->TimeoutAttemptTask();
+ }
+
+ MockHTTPRequest *http_request() { return http_request_; }
+ PortalDetector *portal_detector() { return portal_detector_.get(); }
+ MockEventDispatcher &dispatcher() { return dispatcher_; }
+ CallbackTarget &callback_target() { return callback_target_; }
+ ByteString &response_data() { return response_data_; }
+
+ void ExpectReset() {
+ EXPECT_FALSE(portal_detector_->attempt_count_);
+ EXPECT_EQ(callback_target_.result_callback(),
+ portal_detector_->portal_result_callback_);
+ EXPECT_FALSE(portal_detector_->request_.get());
+ }
+
+ void ExpectAttemptRetry(const PortalDetector::Result &result) {
+ EXPECT_CALL(callback_target(),
+ ResultCallback(IsResult(result)));
+
+ // Expect the PortalDetector to stop the current request.
+ EXPECT_CALL(*http_request(), Stop());
+
+ // Expect the PortalDetector to schedule the next attempt.
+ EXPECT_CALL(
+ dispatcher(),
+ PostDelayedTask(
+ _, PortalDetector::kMinTimeBetweenAttemptsSeconds * 1000));
+ }
+
+ void AdvanceTime(int milliseconds) {
+ struct timeval tv = { milliseconds / 1000, (milliseconds % 1000) * 1000 };
+ timeradd(¤t_time_, &tv, ¤t_time_);
+ }
+
+ void StartAttempt() {
+ EXPECT_CALL(dispatcher(), PostDelayedTask(_, 0));
+ EXPECT_TRUE(StartPortalRequest(kURL));
+
+ // Expect that the request will be started -- return failure.
+ EXPECT_CALL(*http_request(), Start(_, _, _))
+ .WillOnce(Return(HTTPRequest::kResultInProgress));
+ EXPECT_CALL(dispatcher(), PostDelayedTask(
+ _, PortalDetector::kRequestTimeoutSeconds * 1000));
+
+ portal_detector()->StartAttemptTask();
+ }
+
+ void AppendReadData(const string &read_data) {
+ response_data_.Append(ByteString(read_data, false));
+ EXPECT_CALL(*http_request_, response_data())
+ .WillOnce(ReturnRef(response_data_));
+ portal_detector_->RequestReadCallback(response_data_.GetLength());
+ }
+
+ private:
+ int GetTimeMonotonic(struct timeval *tv) {
+ *tv = current_time_;
+ return 0;
+ }
+
+ StrictMock<MockEventDispatcher> dispatcher_;
+ MockControl control_;
+ scoped_ptr<MockDeviceInfo> device_info_;
+ scoped_refptr<MockConnection> connection_;
+ CallbackTarget callback_target_;
+ scoped_ptr<PortalDetector> portal_detector_;
+ StrictMock<MockTime> time_;
+ struct timeval current_time_;
+ const string interface_name_;
+ vector<string> dns_servers_;
+ ByteString response_data_;
+
+ // Owned by the PortalDetector, but tracked here for EXPECT_CALL()
+ MockHTTPRequest *http_request_;
+};
+
+TEST_F(PortalDetectorTest, Constructor) {
+ ExpectReset();
+}
+
+TEST_F(PortalDetectorTest, InvalidURL) {
+ EXPECT_FALSE(StartPortalRequest(kBadURL));
+ ExpectReset();
+}
+
+TEST_F(PortalDetectorTest, StartAttemptFailed) {
+ EXPECT_CALL(dispatcher(), PostDelayedTask(_, 0));
+ EXPECT_TRUE(StartPortalRequest(kURL));
+
+ // Expect that the request will be started -- return failure.
+ EXPECT_CALL(*http_request(), Start(_, _, _))
+ .WillOnce(Return(HTTPRequest::kResultConnectionFailure));
+ // Expect a non-final failure to be relayed to the caller.
+ ExpectAttemptRetry(PortalDetector::Result(
+ PortalDetector::kPhaseConnection,
+ PortalDetector::kStatusFailure,
+ false));
+ portal_detector()->StartAttemptTask();
+}
+
+TEST_F(PortalDetectorTest, StartAttemptRepeated) {
+ EXPECT_CALL(dispatcher(), PostDelayedTask(_, 0));
+ portal_detector()->StartAttempt();
+
+ AssignHTTPRequest();
+ EXPECT_CALL(*http_request(), Start(_, _, _))
+ .WillOnce(Return(HTTPRequest::kResultInProgress));
+ EXPECT_CALL(
+ dispatcher(),
+ PostDelayedTask(
+ _, PortalDetector::kRequestTimeoutSeconds * 1000));
+ portal_detector()->StartAttemptTask();
+
+ // A second attempt should be delayed by kMinTimeBetweenAttemptsSeconds.
+ EXPECT_CALL(
+ dispatcher(),
+ PostDelayedTask(
+ _, PortalDetector::kMinTimeBetweenAttemptsSeconds * 1000));
+ portal_detector()->StartAttempt();
+}
+
+TEST_F(PortalDetectorTest, AttemptCount) {
+ // Expect the PortalDetector to immediately post a task for the each attempt.
+ EXPECT_CALL(dispatcher(), PostDelayedTask(_, 0))
+ .Times(PortalDetector::kMaxRequestAttempts);
+ EXPECT_TRUE(StartPortalRequest(kURL));
+
+ // Expect that the request will be started -- return failure.
+ EXPECT_CALL(*http_request(), Start(_, _, _))
+ .Times(PortalDetector::kMaxRequestAttempts)
+ .WillRepeatedly(Return(HTTPRequest::kResultInProgress));
+
+ // Each HTTP request that gets started will have a request timeout.
+ EXPECT_CALL(dispatcher(), PostDelayedTask(
+ _, PortalDetector::kRequestTimeoutSeconds * 1000))
+ .Times(PortalDetector::kMaxRequestAttempts);
+
+ {
+ InSequence s;
+
+ // Expect non-final failures for all attempts but the last.
+ EXPECT_CALL(callback_target(),
+ ResultCallback(IsResult(
+ PortalDetector::Result(
+ PortalDetector::kPhaseDNS,
+ PortalDetector::kStatusFailure,
+ false))))
+ .Times(PortalDetector::kMaxRequestAttempts - 1);
+
+ // Expect a single final failure.
+ EXPECT_CALL(callback_target(),
+ ResultCallback(IsResult(
+ PortalDetector::Result(
+ PortalDetector::kPhaseDNS,
+ PortalDetector::kStatusFailure,
+ true))))
+ .Times(1);
+ }
+
+ // Expect the PortalDetector to stop the current request each time, plus
+ // an extra time in PortalDetector::Stop().
+ EXPECT_CALL(*http_request(), Stop())
+ .Times(PortalDetector::kMaxRequestAttempts + 1);
+
+
+ for (int i = 0; i < PortalDetector::kMaxRequestAttempts; i++) {
+ portal_detector()->StartAttemptTask();
+ AdvanceTime(PortalDetector::kMinTimeBetweenAttemptsSeconds * 1000);
+ portal_detector()->RequestResultCallback(HTTPRequest::kResultDNSFailure);
+ }
+
+ ExpectReset();
+}
+
+TEST_F(PortalDetectorTest, ReadBadHeader) {
+ StartAttempt();
+
+ ExpectAttemptRetry(PortalDetector::Result(
+ PortalDetector::kPhaseContent,
+ PortalDetector::kStatusFailure,
+ false));
+ AppendReadData("X");
+}
+
+TEST_F(PortalDetectorTest, RequestTimeout) {
+ StartAttempt();
+ ExpectAttemptRetry(PortalDetector::Result(
+ PortalDetector::kPhaseUnknown,
+ PortalDetector::kStatusTimeout,
+ false));
+
+ EXPECT_CALL(*http_request(), response_data())
+ .WillOnce(ReturnRef(response_data()));
+
+ TimeoutAttempt();
+}
+
+TEST_F(PortalDetectorTest, ReadPartialHeaderTimeout) {
+ StartAttempt();
+
+ const string response_expected(PortalDetector::kResponseExpected);
+ const size_t partial_size = response_expected.length() / 2;
+ AppendReadData(response_expected.substr(0, partial_size));
+
+ ExpectAttemptRetry(PortalDetector::Result(
+ PortalDetector::kPhaseContent,
+ PortalDetector::kStatusTimeout,
+ false));
+
+ EXPECT_CALL(*http_request(), response_data())
+ .WillOnce(ReturnRef(response_data()));
+
+ TimeoutAttempt();
+}
+
+TEST_F(PortalDetectorTest, ReadCompleteHeader) {
+ const string response_expected(PortalDetector::kResponseExpected);
+ const size_t partial_size = response_expected.length() / 2;
+
+ StartAttempt();
+ AppendReadData(response_expected.substr(0, partial_size));
+
+ EXPECT_CALL(callback_target(),
+ ResultCallback(IsResult(
+ PortalDetector::Result(
+ PortalDetector::kPhaseContent,
+ PortalDetector::kStatusSuccess,
+ true))));
+ EXPECT_CALL(*http_request(), Stop())
+ .Times(2);
+
+ AppendReadData(response_expected.substr(partial_size));
+}
+struct ResultMapping {
+ ResultMapping() : http_result(HTTPRequest::kResultUnknown), portal_result() {}
+ ResultMapping(HTTPRequest::Result in_http_result,
+ const PortalDetector::Result &in_portal_result)
+ : http_result(in_http_result),
+ portal_result(in_portal_result) {}
+ HTTPRequest::Result http_result;
+ PortalDetector::Result portal_result;
+};
+
+class PortalDetectorResultMappingTest
+ : public testing::TestWithParam<ResultMapping> {};
+
+TEST_P(PortalDetectorResultMappingTest, MapResult) {
+ PortalDetector::Result portal_result =
+ PortalDetector::GetPortalResultForRequestResult(GetParam().http_result);
+ EXPECT_EQ(portal_result.phase, GetParam().portal_result.phase);
+ EXPECT_EQ(portal_result.status, GetParam().portal_result.status);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ PortalResultMappingTest,
+ PortalDetectorResultMappingTest,
+ ::testing::Values(
+ ResultMapping(HTTPRequest::kResultUnknown,
+ PortalDetector::Result(PortalDetector::kPhaseUnknown,
+ PortalDetector::kStatusFailure)),
+ ResultMapping(HTTPRequest::kResultInProgress,
+ PortalDetector::Result(PortalDetector::kPhaseUnknown,
+ PortalDetector::kStatusFailure)),
+ ResultMapping(HTTPRequest::kResultDNSFailure,
+ PortalDetector::Result(PortalDetector::kPhaseDNS,
+ PortalDetector::kStatusFailure)),
+ ResultMapping(HTTPRequest::kResultDNSTimeout,
+ PortalDetector::Result(PortalDetector::kPhaseDNS,
+ PortalDetector::kStatusTimeout)),
+ ResultMapping(HTTPRequest::kResultConnectionFailure,
+ PortalDetector::Result(PortalDetector::kPhaseConnection,
+ PortalDetector::kStatusFailure)),
+ ResultMapping(HTTPRequest::kResultConnectionTimeout,
+ PortalDetector::Result(PortalDetector::kPhaseConnection,
+ PortalDetector::kStatusTimeout)),
+ ResultMapping(HTTPRequest::kResultRequestFailure,
+ PortalDetector::Result(PortalDetector::kPhaseHTTP,
+ PortalDetector::kStatusFailure)),
+ ResultMapping(HTTPRequest::kResultRequestTimeout,
+ PortalDetector::Result(PortalDetector::kPhaseHTTP,
+ PortalDetector::kStatusTimeout)),
+ ResultMapping(HTTPRequest::kResultResponseFailure,
+ PortalDetector::Result(PortalDetector::kPhaseHTTP,
+ PortalDetector::kStatusFailure)),
+ ResultMapping(HTTPRequest::kResultResponseTimeout,
+ PortalDetector::Result(PortalDetector::kPhaseHTTP,
+ PortalDetector::kStatusTimeout)),
+ ResultMapping(HTTPRequest::kResultSuccess,
+ PortalDetector::Result(PortalDetector::kPhaseContent,
+ PortalDetector::kStatusFailure))));
+
+} // namespace shill
diff --git a/shill_time.cc b/shill_time.cc
index 7bae785..0c88e41 100644
--- a/shill_time.cc
+++ b/shill_time.cc
@@ -1,9 +1,11 @@
-// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "shill/shill_time.h"
+#include <time.h>
+
namespace shill {
static base::LazyInstance<Time> g_time(base::LINKER_INITIALIZED);
@@ -16,6 +18,17 @@
return g_time.Pointer();
}
+int Time::GetTimeMonotonic(struct timeval *tv) {
+ struct timespec ts;
+ if (clock_gettime(CLOCK_MONOTONIC, &ts) != 0) {
+ return -1;
+ }
+
+ tv->tv_sec = ts.tv_sec;
+ tv->tv_usec = ts.tv_nsec / 1000;
+ return 0;
+}
+
int Time::GetTimeOfDay(struct timeval *tv, struct timezone *tz) {
return gettimeofday(tv, tz);
}
diff --git a/shill_time.h b/shill_time.h
index 846acd4..17d748f 100644
--- a/shill_time.h
+++ b/shill_time.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
@@ -18,6 +18,9 @@
static Time *GetInstance();
+ // clock_gettime(CLOCK_MONOTONIC ...
+ virtual int GetTimeMonotonic(struct timeval *tv);
+
// gettimeofday
virtual int GetTimeOfDay(struct timeval *tv, struct timezone *tz);