shill: openvpn: Allow automatic reconnect when underlying connection reconnects.
When the underlying connection drops, restart the openvpn client, thus
resetting the connection and initiating a reconnect attempt. The
reconnect will be held until a new carrier connection is established,
up to the connect timeout of one minute.
Also, explicitly process SUCCESS management API messages so that
they're not logged as WARNINGs.
BUG=chromium-os:32416
TEST=unit tests; while connected to corp OpenVPN, switch between APs
and observe VPN reconnects; while connected to corp OpenVPN,
disconnect or disable WiFi, reconnect or re-enable it in a couple of
seconds and observe VPN reconnects.
Change-Id: I672db885d589fa020419d0badb480aee9bcc851a
Reviewed-on: https://gerrit.chromium.org/gerrit/42616
Tested-by: Darin Petkov <petkov@chromium.org>
Reviewed-by: Paul Stewart <pstew@chromium.org>
Commit-Queue: Darin Petkov <petkov@chromium.org>
diff --git a/l2tp_ipsec_driver.cc b/l2tp_ipsec_driver.cc
index 89d9a3c..fdd18d4 100644
--- a/l2tp_ipsec_driver.cc
+++ b/l2tp_ipsec_driver.cc
@@ -111,7 +111,12 @@
}
void L2TPIPSecDriver::OnConnectionDisconnected() {
- LOG(ERROR) << "VPN connection disconnected.";
+ LOG(INFO) << "Underlying connection disconnected.";
+ Cleanup(Service::kStateIdle);
+}
+
+void L2TPIPSecDriver::OnConnectTimeout() {
+ VPNDriver::OnConnectTimeout();
Cleanup(Service::kStateFailure);
}
diff --git a/l2tp_ipsec_driver.h b/l2tp_ipsec_driver.h
index e7231b8..60722cf 100644
--- a/l2tp_ipsec_driver.h
+++ b/l2tp_ipsec_driver.h
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef SHILL_L2TP_IPSEC_DRIVER_
-#define SHILL_L2TP_IPSEC_DRIVER_
+#ifndef SHILL_L2TP_IPSEC_DRIVER_H_
+#define SHILL_L2TP_IPSEC_DRIVER_H_
#include <vector>
@@ -48,13 +48,15 @@
GLib *glib);
virtual ~L2TPIPSecDriver();
+ protected:
// Inherited from VPNDriver.
virtual bool ClaimInterface(const std::string &link_name,
int interface_index);
virtual void Connect(const VPNServiceRefPtr &service, Error *error);
virtual void Disconnect();
- virtual void OnConnectionDisconnected();
virtual std::string GetProviderType() const;
+ virtual void OnConnectionDisconnected();
+ virtual void OnConnectTimeout();
private:
friend class L2TPIPSecDriverTest;
@@ -147,4 +149,4 @@
} // namespace shill
-#endif // SHILL_L2TP_IPSEC_DRIVER_
+#endif // SHILL_L2TP_IPSEC_DRIVER_H_
diff --git a/l2tp_ipsec_driver_unittest.cc b/l2tp_ipsec_driver_unittest.cc
index 9c20ce4..b27b6c2 100644
--- a/l2tp_ipsec_driver_unittest.cc
+++ b/l2tp_ipsec_driver_unittest.cc
@@ -83,6 +83,30 @@
return driver_->args();
}
+ string GetProviderType() {
+ return driver_->GetProviderType();
+ }
+
+ void SetService(const VPNServiceRefPtr &service) {
+ driver_->service_ = service;
+ }
+
+ VPNServiceRefPtr GetService() {
+ return driver_->service_;
+ }
+
+ void OnConnectTimeout() {
+ driver_->OnConnectTimeout();
+ }
+
+ void StartConnectTimeout() {
+ driver_->StartConnectTimeout();
+ }
+
+ bool IsConnectTimeoutStarted() {
+ return driver_->IsConnectTimeoutStarted();
+ }
+
// Used to assert that a flag appears in the options.
void ExpectInFlags(const vector<string> &options, const string &flag,
const string &value);
@@ -165,7 +189,7 @@
}
TEST_F(L2TPIPSecDriverTest, GetProviderType) {
- EXPECT_EQ(flimflam::kProviderL2tpIpsec, driver_->GetProviderType());
+ EXPECT_EQ(flimflam::kProviderL2tpIpsec, GetProviderType());
}
TEST_F(L2TPIPSecDriverTest, Cleanup) {
@@ -452,11 +476,20 @@
TEST_F(L2TPIPSecDriverTest, OnConnectionDisconnected) {
driver_->service_ = service_;
- EXPECT_CALL(*service_, SetState(Service::kStateFailure));
+ EXPECT_CALL(*service_, SetState(Service::kStateIdle));
driver_->OnConnectionDisconnected();
EXPECT_FALSE(driver_->service_);
}
+TEST_F(L2TPIPSecDriverTest, OnConnectTimeout) {
+ StartConnectTimeout();
+ SetService(service_);
+ EXPECT_CALL(*service_, SetState(Service::kStateFailure));
+ OnConnectTimeout();
+ EXPECT_FALSE(GetService());
+ EXPECT_FALSE(IsConnectTimeoutStarted());
+}
+
TEST_F(L2TPIPSecDriverTest, InitPropertyStore) {
// Sanity test property store initialization.
PropertyStore store;
diff --git a/mock_openvpn_management_server.h b/mock_openvpn_management_server.h
index f2ddaba..16e3053 100644
--- a/mock_openvpn_management_server.h
+++ b/mock_openvpn_management_server.h
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_
-#define SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_
+#ifndef SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_H_
+#define SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_H_
#include <gmock/gmock.h>
@@ -22,6 +22,7 @@
MOCK_METHOD0(Stop, void());
MOCK_METHOD0(ReleaseHold, void());
MOCK_METHOD0(Hold, void());
+ MOCK_METHOD0(Restart, void());
private:
DISALLOW_COPY_AND_ASSIGN(MockOpenVPNManagementServer);
@@ -29,4 +30,4 @@
} // namespace shill
-#endif // SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_
+#endif // SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_H_
diff --git a/openvpn_driver.cc b/openvpn_driver.cc
index ca8aea4..5551baf 100644
--- a/openvpn_driver.cc
+++ b/openvpn_driver.cc
@@ -737,7 +737,19 @@
}
void OpenVPNDriver::OnConnectionDisconnected() {
- LOG(ERROR) << "VPN connection disconnected.";
+ LOG(INFO) << "Underlying connection disconnected.";
+ // Restart the OpenVPN client forcing a reconnect attempt.
+ management_server_->Restart();
+ // Indicate reconnect state right away to drop the VPN connection and start
+ // the connect timeout. This ensures that any miscommunication between shill
+ // and openvpn will not lead to a permanently stale connectivity state. Note
+ // that a subsequent invocation of OnReconnecting due to a RECONNECTING
+ // message will essentially be a no-op.
+ OnReconnecting();
+}
+
+void OpenVPNDriver::OnConnectTimeout() {
+ VPNDriver::OnConnectTimeout();
Cleanup(Service::kStateFailure);
}
diff --git a/openvpn_driver.h b/openvpn_driver.h
index 2fcc2cd..dc85576 100644
--- a/openvpn_driver.h
+++ b/openvpn_driver.h
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef SHILL_OPENVPN_DRIVER_
-#define SHILL_OPENVPN_DRIVER_
+#ifndef SHILL_OPENVPN_DRIVER_H_
+#define SHILL_OPENVPN_DRIVER_H_
#include <map>
#include <string>
@@ -49,19 +49,6 @@
GLib *glib);
virtual ~OpenVPNDriver();
- // Inherited from VPNDriver. |Connect| initiates the VPN connection by
- // creating a tunnel device. When the device index becomes available, this
- // instance is notified through |ClaimInterface| and resumes the connection
- // process by setting up and spawning an external 'openvpn' process. IP
- // configuration settings are passed back from the external process through
- // the |Notify| RPC service method.
- virtual void Connect(const VPNServiceRefPtr &service, Error *error);
- virtual bool ClaimInterface(const std::string &link_name,
- int interface_index);
- virtual void Disconnect();
- virtual void OnConnectionDisconnected();
- virtual std::string GetProviderType() const;
-
virtual void OnReconnecting();
virtual void Cleanup(Service::ConnectState state);
@@ -76,6 +63,21 @@
const std::string &option,
std::vector<std::string> *options);
+ protected:
+ // Inherited from VPNDriver. |Connect| initiates the VPN connection by
+ // creating a tunnel device. When the device index becomes available, this
+ // instance is notified through |ClaimInterface| and resumes the connection
+ // process by setting up and spawning an external 'openvpn' process. IP
+ // configuration settings are passed back from the external process through
+ // the |Notify| RPC service method.
+ virtual void Connect(const VPNServiceRefPtr &service, Error *error);
+ virtual bool ClaimInterface(const std::string &link_name,
+ int interface_index);
+ virtual void Disconnect();
+ virtual std::string GetProviderType() const;
+ virtual void OnConnectionDisconnected();
+ virtual void OnConnectTimeout();
+
private:
friend class OpenVPNDriverTest;
FRIEND_TEST(OpenVPNDriverTest, ClaimInterface);
@@ -95,7 +97,6 @@
FRIEND_TEST(OpenVPNDriverTest, InitPKCS11Options);
FRIEND_TEST(OpenVPNDriverTest, Notify);
FRIEND_TEST(OpenVPNDriverTest, NotifyFail);
- FRIEND_TEST(OpenVPNDriverTest, OnConnectionDisconnected);
FRIEND_TEST(OpenVPNDriverTest, OnDefaultServiceChanged);
FRIEND_TEST(OpenVPNDriverTest, OnOpenVPNDied);
FRIEND_TEST(OpenVPNDriverTest, OnReconnecting);
@@ -213,4 +214,4 @@
} // namespace shill
-#endif // SHILL_OPENVPN_DRIVER_
+#endif // SHILL_OPENVPN_DRIVER_H_
diff --git a/openvpn_driver_unittest.cc b/openvpn_driver_unittest.cc
index 8ad04c1..7a049d1 100644
--- a/openvpn_driver_unittest.cc
+++ b/openvpn_driver_unittest.cc
@@ -122,6 +122,34 @@
return &driver_->sockets_;
}
+ void SetDevice(const VPNRefPtr &device) {
+ driver_->device_ = device;
+ }
+
+ void SetService(const VPNServiceRefPtr &service) {
+ driver_->service_ = service;
+ }
+
+ VPNServiceRefPtr GetService() {
+ return driver_->service_;
+ }
+
+ void OnConnectionDisconnected() {
+ driver_->OnConnectionDisconnected();
+ }
+
+ void OnConnectTimeout() {
+ driver_->OnConnectTimeout();
+ }
+
+ void StartConnectTimeout() {
+ driver_->StartConnectTimeout();
+ }
+
+ bool IsConnectTimeoutStarted() {
+ return driver_->IsConnectTimeoutStarted();
+ }
+
// Used to assert that a flag appears in the options.
void ExpectInFlags(const vector<string> &options, const string &flag,
const string &value);
@@ -857,10 +885,22 @@
}
TEST_F(OpenVPNDriverTest, OnConnectionDisconnected) {
- driver_->service_ = service_;
+ EXPECT_CALL(*management_server_, Restart());
+ SetDevice(device_);
+ SetService(service_);
+ EXPECT_CALL(*device_, OnDisconnected());
+ EXPECT_CALL(*service_, SetState(Service::kStateAssociating));
+ OnConnectionDisconnected();
+ EXPECT_TRUE(IsConnectTimeoutStarted());
+}
+
+TEST_F(OpenVPNDriverTest, OnConnectTimeout) {
+ StartConnectTimeout();
+ SetService(service_);
EXPECT_CALL(*service_, SetState(Service::kStateFailure));
- driver_->OnConnectionDisconnected();
- EXPECT_FALSE(driver_->service_);
+ OnConnectTimeout();
+ EXPECT_FALSE(GetService());
+ EXPECT_FALSE(IsConnectTimeoutStarted());
}
TEST_F(OpenVPNDriverTest, OnReconnecting) {
diff --git a/openvpn_management_server.cc b/openvpn_management_server.cc
index 11724a4..ff1955a 100644
--- a/openvpn_management_server.cc
+++ b/openvpn_management_server.cc
@@ -140,6 +140,11 @@
hold_release_ = false;
}
+void OpenVPNManagementServer::Restart() {
+ LOG(INFO) << "Restart.";
+ SendSignal("SIGUSR1");
+}
+
void OpenVPNManagementServer::OnReady(int fd) {
SLOG(VPN, 2) << __func__ << "(" << fd << ")";
connected_socket_ = sockets_->Accept(fd, NULL, NULL);
@@ -174,7 +179,8 @@
!ProcessNeedPasswordMessage(message) &&
!ProcessFailedPasswordMessage(message) &&
!ProcessStateMessage(message) &&
- !ProcessHoldMessage(message)) {
+ !ProcessHoldMessage(message) &&
+ !ProcessSuccessMessage(message)) {
LOG(WARNING) << "OpenVPN management message ignored: " << message;
}
}
@@ -183,7 +189,7 @@
if (!StartsWithASCII(message, ">INFO:", true)) {
return false;
}
- LOG(INFO) << "Processing info message.";
+ LOG(INFO) << message;
return true;
}
@@ -309,10 +315,10 @@
if (!StartsWithASCII(message, ">STATE:", true)) {
return false;
}
- LOG(INFO) << "Processing state message.";
vector<string> details;
SplitString(message, ',', &details);
if (details.size() > 1) {
+ LOG(INFO) << "Processing state message: " << details[1];
if (details[1] == "RECONNECTING") {
driver_->OnReconnecting();
}
@@ -325,7 +331,7 @@
if (!StartsWithASCII(message, ">HOLD:Waiting for hold release", true)) {
return false;
}
- LOG(INFO) << "Processing hold message.";
+ LOG(INFO) << "Client waiting for hold release.";
hold_waiting_ = true;
if (hold_release_) {
ReleaseHold();
@@ -333,6 +339,14 @@
return true;
}
+bool OpenVPNManagementServer::ProcessSuccessMessage(const string &message) {
+ if (!StartsWithASCII(message, "SUCCESS: ", true)) {
+ return false;
+ }
+ LOG(INFO) << message;
+ return true;
+}
+
// static
string OpenVPNManagementServer::EscapeToQuote(const string &str) {
string escaped;
@@ -371,6 +385,11 @@
EscapeToQuote(password).c_str()));
}
+void OpenVPNManagementServer::SendSignal(const string &signal) {
+ SLOG(VPN, 2) << __func__ << "(" << signal << ")";
+ Send(StringPrintf("signal %s\n", signal.c_str()));
+}
+
void OpenVPNManagementServer::SendHoldRelease() {
SLOG(VPN, 2) << __func__;
Send("hold release\n");
diff --git a/openvpn_management_server.h b/openvpn_management_server.h
index 2d2851a..5e7ef2d 100644
--- a/openvpn_management_server.h
+++ b/openvpn_management_server.h
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef SHILL_OPENVPN_MANAGEMENT_SERVER_
-#define SHILL_OPENVPN_MANAGEMENT_SERVER_
+#ifndef SHILL_OPENVPN_MANAGEMENT_SERVER_H_
+#define SHILL_OPENVPN_MANAGEMENT_SERVER_H_
#include <string>
#include <vector>
@@ -45,6 +45,9 @@
// existing connection, nor sends any commands to the openvpn client.
virtual void Hold();
+ // Restarts openvpn causing a disconnect followed by a reconnect attempt.
+ virtual void Restart();
+
private:
friend class OpenVPNManagementServerTest;
FRIEND_TEST(OpenVPNManagementServerTest, EscapeToQuote);
@@ -86,6 +89,7 @@
void SendUsername(const std::string &tag, const std::string &username);
void SendPassword(const std::string &tag, const std::string &password);
void SendHoldRelease();
+ void SendSignal(const std::string &signal);
void ProcessMessage(const std::string &message);
bool ProcessInfoMessage(const std::string &message);
@@ -93,6 +97,7 @@
bool ProcessFailedPasswordMessage(const std::string &message);
bool ProcessStateMessage(const std::string &message);
bool ProcessHoldMessage(const std::string &message);
+ bool ProcessSuccessMessage(const std::string &message);
void PerformStaticChallenge(const std::string &tag);
void PerformAuthentication(const std::string &tag);
@@ -128,4 +133,4 @@
} // namespace shill
-#endif // SHILL_OPENVPN_MANAGEMENT_SERVER_
+#endif // SHILL_OPENVPN_MANAGEMENT_SERVER_H_
diff --git a/openvpn_management_server_unittest.cc b/openvpn_management_server_unittest.cc
index 59d2f90..c26ffaa 100644
--- a/openvpn_management_server_unittest.cc
+++ b/openvpn_management_server_unittest.cc
@@ -41,6 +41,9 @@
virtual ~OpenVPNManagementServerTest() {}
+ protected:
+ static const int kConnectedSocket;
+
void SetSockets() { server_.sockets_ = &sockets_; }
void SetDispatcher() { server_.dispatcher_ = &dispatcher_; }
void ExpectNotStarted() { EXPECT_FALSE(server_.IsStarted()); }
@@ -84,6 +87,11 @@
ExpectSend("hold release\n");
}
+ void ExpectRestart() {
+ SetConnectedSocket();
+ ExpectSend("signal SIGUSR1\n");
+ }
+
InputData CreateInputDataFromString(const string &str) {
InputData data(
reinterpret_cast<unsigned char *>(const_cast<char *>(str.data())),
@@ -91,8 +99,13 @@
return data;
}
- protected:
- static const int kConnectedSocket;
+ void SendSignal(const string &signal) {
+ server_.SendSignal(signal);
+ }
+
+ bool ProcessSuccessMessage(const string &message) {
+ return server_.ProcessSuccessMessage(message);
+ }
GLib glib_;
MockOpenVPNDriver driver_;
@@ -207,7 +220,8 @@
">PASSWORD:Need 'User-Specific TPM Token FOO' ...\n"
">PASSWORD:Verification Failed: .\n"
">STATE:123,RECONNECTING,detail,...,...\n"
- ">HOLD:Waiting for hold release";
+ ">HOLD:Waiting for hold release\n"
+ "SUCCESS: Hold released.";
InputData data = CreateInputDataFromString(s);
ExpectStaticChallengeResponse();
ExpectPINResponse();
@@ -241,9 +255,14 @@
server_.ProcessMessage(">STATE:123,RECONNECTING,detail,...,...");
}
+TEST_F(OpenVPNManagementServerTest, ProcessSuccessMessage) {
+ EXPECT_FALSE(ProcessSuccessMessage("foo"));
+ EXPECT_TRUE(ProcessSuccessMessage("SUCCESS: foo"));
+}
+
TEST_F(OpenVPNManagementServerTest, ProcessInfoMessage) {
EXPECT_FALSE(server_.ProcessInfoMessage("foo"));
- EXPECT_TRUE(server_.ProcessInfoMessage(">INFO:"));
+ EXPECT_TRUE(server_.ProcessInfoMessage(">INFO:foo"));
}
TEST_F(OpenVPNManagementServerTest, ProcessStateMessage) {
@@ -380,6 +399,17 @@
server_.ProcessFailedPasswordMessage(">PASSWORD:Verification Failed: ."));
}
+TEST_F(OpenVPNManagementServerTest, SendSignal) {
+ SetConnectedSocket();
+ ExpectSend("signal SIGUSR2\n");
+ SendSignal("SIGUSR2");
+}
+
+TEST_F(OpenVPNManagementServerTest, Restart) {
+ ExpectRestart();
+ server_.Restart();
+}
+
TEST_F(OpenVPNManagementServerTest, SendHoldRelease) {
ExpectHoldRelease();
server_.SendHoldRelease();
diff --git a/vpn_driver.cc b/vpn_driver.cc
index 5175187..c6fbf72 100644
--- a/vpn_driver.cc
+++ b/vpn_driver.cc
@@ -181,9 +181,8 @@
}
void VPNDriver::OnConnectTimeout() {
- LOG(ERROR) << "VPN connection timeout.";
+ LOG(INFO) << "VPN connect timeout.";
StopConnectTimeout();
- OnConnectionDisconnected();
}
} // namespace shill
diff --git a/vpn_driver.h b/vpn_driver.h
index f539d0a..4200b7a 100644
--- a/vpn_driver.h
+++ b/vpn_driver.h
@@ -2,8 +2,8 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-#ifndef SHILL_VPN_DRIVER_
-#define SHILL_VPN_DRIVER_
+#ifndef SHILL_VPN_DRIVER_H_
+#define SHILL_VPN_DRIVER_H_
#include <string>
@@ -32,9 +32,11 @@
int interface_index) = 0;
virtual void Connect(const VPNServiceRefPtr &service, Error *error) = 0;
virtual void Disconnect() = 0;
- virtual void OnConnectionDisconnected() = 0;
virtual std::string GetProviderType() const = 0;
+ // Invoked by VPNService when the underlying connection disconnects.
+ virtual void OnConnectionDisconnected() = 0;
+
virtual void InitPropertyStore(PropertyStore *store);
virtual bool Load(StoreInterface *storage, const std::string &storage_id);
@@ -76,6 +78,10 @@
// Returns true if a connect timeout is scheduled, false otherwise.
bool IsConnectTimeoutStarted() const;
+ // Called if a connect timeout scheduled through StartConnectTimeout
+ // fires. Cancels the timeout callback.
+ virtual void OnConnectTimeout();
+
private:
FRIEND_TEST(VPNDriverTest, ConnectTimeout);
@@ -86,10 +92,6 @@
void SetMappedProperty(
const size_t &index, const std::string &value, Error *error);
- // Called if a connect timeout scheduled through StartConnectTimeout
- // fires. Marks the callback as stopped and invokes OnConnectionDisconnected.
- void OnConnectTimeout();
-
base::WeakPtrFactory<VPNDriver> weak_ptr_factory_;
EventDispatcher *dispatcher_;
Manager *manager_;
@@ -105,4 +107,4 @@
} // namespace shill
-#endif // SHILL_VPN_DRIVER_
+#endif // SHILL_VPN_DRIVER_H_
diff --git a/vpn_driver_unittest.cc b/vpn_driver_unittest.cc
index fe221f3..af52ebb 100644
--- a/vpn_driver_unittest.cc
+++ b/vpn_driver_unittest.cc
@@ -291,7 +291,6 @@
EXPECT_TRUE(driver_.IsConnectTimeoutStarted());
driver_.dispatcher_ = NULL;
driver_.StartConnectTimeout(); // Expect no crash.
- EXPECT_CALL(driver_, OnConnectionDisconnected());
dispatcher_.DispatchPendingEvents();
EXPECT_TRUE(driver_.connect_timeout_callback_.IsCancelled());
EXPECT_FALSE(driver_.IsConnectTimeoutStarted());