shill: openvpn: Allow automatic reconnect when underlying connection reconnects.

When the underlying connection drops, restart the openvpn client, thus
resetting the connection and initiating a reconnect attempt. The
reconnect will be held until a new carrier connection is established,
up to the connect timeout of one minute.

Also, explicitly process SUCCESS management API messages so that
they're not logged as WARNINGs.

BUG=chromium-os:32416
TEST=unit tests; while connected to corp OpenVPN, switch between APs
and observe VPN reconnects; while connected to corp OpenVPN,
disconnect or disable WiFi, reconnect or re-enable it in a couple of
seconds and observe VPN reconnects.

Change-Id: I672db885d589fa020419d0badb480aee9bcc851a
Reviewed-on: https://gerrit.chromium.org/gerrit/42616
Tested-by: Darin Petkov <petkov@chromium.org>
Reviewed-by: Paul Stewart <pstew@chromium.org>
Commit-Queue: Darin Petkov <petkov@chromium.org>
diff --git a/l2tp_ipsec_driver.cc b/l2tp_ipsec_driver.cc
index 89d9a3c..fdd18d4 100644
--- a/l2tp_ipsec_driver.cc
+++ b/l2tp_ipsec_driver.cc
@@ -111,7 +111,12 @@
 }
 
 void L2TPIPSecDriver::OnConnectionDisconnected() {
-  LOG(ERROR) << "VPN connection disconnected.";
+  LOG(INFO) << "Underlying connection disconnected.";
+  Cleanup(Service::kStateIdle);
+}
+
+void L2TPIPSecDriver::OnConnectTimeout() {
+  VPNDriver::OnConnectTimeout();
   Cleanup(Service::kStateFailure);
 }
 
diff --git a/l2tp_ipsec_driver.h b/l2tp_ipsec_driver.h
index e7231b8..60722cf 100644
--- a/l2tp_ipsec_driver.h
+++ b/l2tp_ipsec_driver.h
@@ -2,8 +2,8 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
-#ifndef SHILL_L2TP_IPSEC_DRIVER_
-#define SHILL_L2TP_IPSEC_DRIVER_
+#ifndef SHILL_L2TP_IPSEC_DRIVER_H_
+#define SHILL_L2TP_IPSEC_DRIVER_H_
 
 #include <vector>
 
@@ -48,13 +48,15 @@
                   GLib *glib);
   virtual ~L2TPIPSecDriver();
 
+ protected:
   // Inherited from VPNDriver.
   virtual bool ClaimInterface(const std::string &link_name,
                               int interface_index);
   virtual void Connect(const VPNServiceRefPtr &service, Error *error);
   virtual void Disconnect();
-  virtual void OnConnectionDisconnected();
   virtual std::string GetProviderType() const;
+  virtual void OnConnectionDisconnected();
+  virtual void OnConnectTimeout();
 
  private:
   friend class L2TPIPSecDriverTest;
@@ -147,4 +149,4 @@
 
 }  // namespace shill
 
-#endif  // SHILL_L2TP_IPSEC_DRIVER_
+#endif  // SHILL_L2TP_IPSEC_DRIVER_H_
diff --git a/l2tp_ipsec_driver_unittest.cc b/l2tp_ipsec_driver_unittest.cc
index 9c20ce4..b27b6c2 100644
--- a/l2tp_ipsec_driver_unittest.cc
+++ b/l2tp_ipsec_driver_unittest.cc
@@ -83,6 +83,30 @@
     return driver_->args();
   }
 
+  string GetProviderType() {
+    return driver_->GetProviderType();
+  }
+
+  void SetService(const VPNServiceRefPtr &service) {
+    driver_->service_ = service;
+  }
+
+  VPNServiceRefPtr GetService() {
+    return driver_->service_;
+  }
+
+  void OnConnectTimeout() {
+    driver_->OnConnectTimeout();
+  }
+
+  void StartConnectTimeout() {
+    driver_->StartConnectTimeout();
+  }
+
+  bool IsConnectTimeoutStarted() {
+    return driver_->IsConnectTimeoutStarted();
+  }
+
   // Used to assert that a flag appears in the options.
   void ExpectInFlags(const vector<string> &options, const string &flag,
                      const string &value);
@@ -165,7 +189,7 @@
 }
 
 TEST_F(L2TPIPSecDriverTest, GetProviderType) {
-  EXPECT_EQ(flimflam::kProviderL2tpIpsec, driver_->GetProviderType());
+  EXPECT_EQ(flimflam::kProviderL2tpIpsec, GetProviderType());
 }
 
 TEST_F(L2TPIPSecDriverTest, Cleanup) {
@@ -452,11 +476,20 @@
 
 TEST_F(L2TPIPSecDriverTest, OnConnectionDisconnected) {
   driver_->service_ = service_;
-  EXPECT_CALL(*service_, SetState(Service::kStateFailure));
+  EXPECT_CALL(*service_, SetState(Service::kStateIdle));
   driver_->OnConnectionDisconnected();
   EXPECT_FALSE(driver_->service_);
 }
 
+TEST_F(L2TPIPSecDriverTest, OnConnectTimeout) {
+  StartConnectTimeout();
+  SetService(service_);
+  EXPECT_CALL(*service_, SetState(Service::kStateFailure));
+  OnConnectTimeout();
+  EXPECT_FALSE(GetService());
+  EXPECT_FALSE(IsConnectTimeoutStarted());
+}
+
 TEST_F(L2TPIPSecDriverTest, InitPropertyStore) {
   // Sanity test property store initialization.
   PropertyStore store;
diff --git a/mock_openvpn_management_server.h b/mock_openvpn_management_server.h
index f2ddaba..16e3053 100644
--- a/mock_openvpn_management_server.h
+++ b/mock_openvpn_management_server.h
@@ -2,8 +2,8 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
-#ifndef SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_
-#define SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_
+#ifndef SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_H_
+#define SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_H_
 
 #include <gmock/gmock.h>
 
@@ -22,6 +22,7 @@
   MOCK_METHOD0(Stop, void());
   MOCK_METHOD0(ReleaseHold, void());
   MOCK_METHOD0(Hold, void());
+  MOCK_METHOD0(Restart, void());
 
  private:
   DISALLOW_COPY_AND_ASSIGN(MockOpenVPNManagementServer);
@@ -29,4 +30,4 @@
 
 }  // namespace shill
 
-#endif  // SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_
+#endif  // SHILL_MOCK_OPENVPN_MANAGEMENT_SERVER_H_
diff --git a/openvpn_driver.cc b/openvpn_driver.cc
index ca8aea4..5551baf 100644
--- a/openvpn_driver.cc
+++ b/openvpn_driver.cc
@@ -737,7 +737,19 @@
 }
 
 void OpenVPNDriver::OnConnectionDisconnected() {
-  LOG(ERROR) << "VPN connection disconnected.";
+  LOG(INFO) << "Underlying connection disconnected.";
+  // Restart the OpenVPN client forcing a reconnect attempt.
+  management_server_->Restart();
+  // Indicate reconnect state right away to drop the VPN connection and start
+  // the connect timeout. This ensures that any miscommunication between shill
+  // and openvpn will not lead to a permanently stale connectivity state. Note
+  // that a subsequent invocation of OnReconnecting due to a RECONNECTING
+  // message will essentially be a no-op.
+  OnReconnecting();
+}
+
+void OpenVPNDriver::OnConnectTimeout() {
+  VPNDriver::OnConnectTimeout();
   Cleanup(Service::kStateFailure);
 }
 
diff --git a/openvpn_driver.h b/openvpn_driver.h
index 2fcc2cd..dc85576 100644
--- a/openvpn_driver.h
+++ b/openvpn_driver.h
@@ -2,8 +2,8 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
-#ifndef SHILL_OPENVPN_DRIVER_
-#define SHILL_OPENVPN_DRIVER_
+#ifndef SHILL_OPENVPN_DRIVER_H_
+#define SHILL_OPENVPN_DRIVER_H_
 
 #include <map>
 #include <string>
@@ -49,19 +49,6 @@
                 GLib *glib);
   virtual ~OpenVPNDriver();
 
-  // Inherited from VPNDriver. |Connect| initiates the VPN connection by
-  // creating a tunnel device. When the device index becomes available, this
-  // instance is notified through |ClaimInterface| and resumes the connection
-  // process by setting up and spawning an external 'openvpn' process. IP
-  // configuration settings are passed back from the external process through
-  // the |Notify| RPC service method.
-  virtual void Connect(const VPNServiceRefPtr &service, Error *error);
-  virtual bool ClaimInterface(const std::string &link_name,
-                              int interface_index);
-  virtual void Disconnect();
-  virtual void OnConnectionDisconnected();
-  virtual std::string GetProviderType() const;
-
   virtual void OnReconnecting();
 
   virtual void Cleanup(Service::ConnectState state);
@@ -76,6 +63,21 @@
                   const std::string &option,
                   std::vector<std::string> *options);
 
+ protected:
+  // Inherited from VPNDriver. |Connect| initiates the VPN connection by
+  // creating a tunnel device. When the device index becomes available, this
+  // instance is notified through |ClaimInterface| and resumes the connection
+  // process by setting up and spawning an external 'openvpn' process. IP
+  // configuration settings are passed back from the external process through
+  // the |Notify| RPC service method.
+  virtual void Connect(const VPNServiceRefPtr &service, Error *error);
+  virtual bool ClaimInterface(const std::string &link_name,
+                              int interface_index);
+  virtual void Disconnect();
+  virtual std::string GetProviderType() const;
+  virtual void OnConnectionDisconnected();
+  virtual void OnConnectTimeout();
+
  private:
   friend class OpenVPNDriverTest;
   FRIEND_TEST(OpenVPNDriverTest, ClaimInterface);
@@ -95,7 +97,6 @@
   FRIEND_TEST(OpenVPNDriverTest, InitPKCS11Options);
   FRIEND_TEST(OpenVPNDriverTest, Notify);
   FRIEND_TEST(OpenVPNDriverTest, NotifyFail);
-  FRIEND_TEST(OpenVPNDriverTest, OnConnectionDisconnected);
   FRIEND_TEST(OpenVPNDriverTest, OnDefaultServiceChanged);
   FRIEND_TEST(OpenVPNDriverTest, OnOpenVPNDied);
   FRIEND_TEST(OpenVPNDriverTest, OnReconnecting);
@@ -213,4 +214,4 @@
 
 }  // namespace shill
 
-#endif  // SHILL_OPENVPN_DRIVER_
+#endif  // SHILL_OPENVPN_DRIVER_H_
diff --git a/openvpn_driver_unittest.cc b/openvpn_driver_unittest.cc
index 8ad04c1..7a049d1 100644
--- a/openvpn_driver_unittest.cc
+++ b/openvpn_driver_unittest.cc
@@ -122,6 +122,34 @@
     return &driver_->sockets_;
   }
 
+  void SetDevice(const VPNRefPtr &device) {
+    driver_->device_ = device;
+  }
+
+  void SetService(const VPNServiceRefPtr &service) {
+    driver_->service_ = service;
+  }
+
+  VPNServiceRefPtr GetService() {
+    return driver_->service_;
+  }
+
+  void OnConnectionDisconnected() {
+    driver_->OnConnectionDisconnected();
+  }
+
+  void OnConnectTimeout() {
+    driver_->OnConnectTimeout();
+  }
+
+  void StartConnectTimeout() {
+    driver_->StartConnectTimeout();
+  }
+
+  bool IsConnectTimeoutStarted() {
+    return driver_->IsConnectTimeoutStarted();
+  }
+
   // Used to assert that a flag appears in the options.
   void ExpectInFlags(const vector<string> &options, const string &flag,
                      const string &value);
@@ -857,10 +885,22 @@
 }
 
 TEST_F(OpenVPNDriverTest, OnConnectionDisconnected) {
-  driver_->service_ = service_;
+  EXPECT_CALL(*management_server_, Restart());
+  SetDevice(device_);
+  SetService(service_);
+  EXPECT_CALL(*device_, OnDisconnected());
+  EXPECT_CALL(*service_, SetState(Service::kStateAssociating));
+  OnConnectionDisconnected();
+  EXPECT_TRUE(IsConnectTimeoutStarted());
+}
+
+TEST_F(OpenVPNDriverTest, OnConnectTimeout) {
+  StartConnectTimeout();
+  SetService(service_);
   EXPECT_CALL(*service_, SetState(Service::kStateFailure));
-  driver_->OnConnectionDisconnected();
-  EXPECT_FALSE(driver_->service_);
+  OnConnectTimeout();
+  EXPECT_FALSE(GetService());
+  EXPECT_FALSE(IsConnectTimeoutStarted());
 }
 
 TEST_F(OpenVPNDriverTest, OnReconnecting) {
diff --git a/openvpn_management_server.cc b/openvpn_management_server.cc
index 11724a4..ff1955a 100644
--- a/openvpn_management_server.cc
+++ b/openvpn_management_server.cc
@@ -140,6 +140,11 @@
   hold_release_ = false;
 }
 
+void OpenVPNManagementServer::Restart() {
+  LOG(INFO) << "Restart.";
+  SendSignal("SIGUSR1");
+}
+
 void OpenVPNManagementServer::OnReady(int fd) {
   SLOG(VPN, 2) << __func__ << "(" << fd << ")";
   connected_socket_ = sockets_->Accept(fd, NULL, NULL);
@@ -174,7 +179,8 @@
       !ProcessNeedPasswordMessage(message) &&
       !ProcessFailedPasswordMessage(message) &&
       !ProcessStateMessage(message) &&
-      !ProcessHoldMessage(message)) {
+      !ProcessHoldMessage(message) &&
+      !ProcessSuccessMessage(message)) {
     LOG(WARNING) << "OpenVPN management message ignored: " << message;
   }
 }
@@ -183,7 +189,7 @@
   if (!StartsWithASCII(message, ">INFO:", true)) {
     return false;
   }
-  LOG(INFO) << "Processing info message.";
+  LOG(INFO) << message;
   return true;
 }
 
@@ -309,10 +315,10 @@
   if (!StartsWithASCII(message, ">STATE:", true)) {
     return false;
   }
-  LOG(INFO) << "Processing state message.";
   vector<string> details;
   SplitString(message, ',', &details);
   if (details.size() > 1) {
+    LOG(INFO) << "Processing state message: " << details[1];
     if (details[1] == "RECONNECTING") {
       driver_->OnReconnecting();
     }
@@ -325,7 +331,7 @@
   if (!StartsWithASCII(message, ">HOLD:Waiting for hold release", true)) {
     return false;
   }
-  LOG(INFO) << "Processing hold message.";
+  LOG(INFO) << "Client waiting for hold release.";
   hold_waiting_ = true;
   if (hold_release_) {
     ReleaseHold();
@@ -333,6 +339,14 @@
   return true;
 }
 
+bool OpenVPNManagementServer::ProcessSuccessMessage(const string &message) {
+  if (!StartsWithASCII(message, "SUCCESS: ", true)) {
+    return false;
+  }
+  LOG(INFO) << message;
+  return true;
+}
+
 // static
 string OpenVPNManagementServer::EscapeToQuote(const string &str) {
   string escaped;
@@ -371,6 +385,11 @@
                     EscapeToQuote(password).c_str()));
 }
 
+void OpenVPNManagementServer::SendSignal(const string &signal) {
+  SLOG(VPN, 2) << __func__ << "(" << signal << ")";
+  Send(StringPrintf("signal %s\n", signal.c_str()));
+}
+
 void OpenVPNManagementServer::SendHoldRelease() {
   SLOG(VPN, 2) << __func__;
   Send("hold release\n");
diff --git a/openvpn_management_server.h b/openvpn_management_server.h
index 2d2851a..5e7ef2d 100644
--- a/openvpn_management_server.h
+++ b/openvpn_management_server.h
@@ -2,8 +2,8 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
-#ifndef SHILL_OPENVPN_MANAGEMENT_SERVER_
-#define SHILL_OPENVPN_MANAGEMENT_SERVER_
+#ifndef SHILL_OPENVPN_MANAGEMENT_SERVER_H_
+#define SHILL_OPENVPN_MANAGEMENT_SERVER_H_
 
 #include <string>
 #include <vector>
@@ -45,6 +45,9 @@
   // existing connection, nor sends any commands to the openvpn client.
   virtual void Hold();
 
+  // Restarts openvpn causing a disconnect followed by a reconnect attempt.
+  virtual void Restart();
+
  private:
   friend class OpenVPNManagementServerTest;
   FRIEND_TEST(OpenVPNManagementServerTest, EscapeToQuote);
@@ -86,6 +89,7 @@
   void SendUsername(const std::string &tag, const std::string &username);
   void SendPassword(const std::string &tag, const std::string &password);
   void SendHoldRelease();
+  void SendSignal(const std::string &signal);
 
   void ProcessMessage(const std::string &message);
   bool ProcessInfoMessage(const std::string &message);
@@ -93,6 +97,7 @@
   bool ProcessFailedPasswordMessage(const std::string &message);
   bool ProcessStateMessage(const std::string &message);
   bool ProcessHoldMessage(const std::string &message);
+  bool ProcessSuccessMessage(const std::string &message);
 
   void PerformStaticChallenge(const std::string &tag);
   void PerformAuthentication(const std::string &tag);
@@ -128,4 +133,4 @@
 
 }  // namespace shill
 
-#endif  // SHILL_OPENVPN_MANAGEMENT_SERVER_
+#endif  // SHILL_OPENVPN_MANAGEMENT_SERVER_H_
diff --git a/openvpn_management_server_unittest.cc b/openvpn_management_server_unittest.cc
index 59d2f90..c26ffaa 100644
--- a/openvpn_management_server_unittest.cc
+++ b/openvpn_management_server_unittest.cc
@@ -41,6 +41,9 @@
 
   virtual ~OpenVPNManagementServerTest() {}
 
+ protected:
+  static const int kConnectedSocket;
+
   void SetSockets() { server_.sockets_ = &sockets_; }
   void SetDispatcher() { server_.dispatcher_ = &dispatcher_; }
   void ExpectNotStarted() { EXPECT_FALSE(server_.IsStarted()); }
@@ -84,6 +87,11 @@
     ExpectSend("hold release\n");
   }
 
+  void ExpectRestart() {
+    SetConnectedSocket();
+    ExpectSend("signal SIGUSR1\n");
+  }
+
   InputData CreateInputDataFromString(const string &str) {
     InputData data(
         reinterpret_cast<unsigned char *>(const_cast<char *>(str.data())),
@@ -91,8 +99,13 @@
     return data;
   }
 
- protected:
-  static const int kConnectedSocket;
+  void SendSignal(const string &signal) {
+    server_.SendSignal(signal);
+  }
+
+  bool ProcessSuccessMessage(const string &message) {
+    return server_.ProcessSuccessMessage(message);
+  }
 
   GLib glib_;
   MockOpenVPNDriver driver_;
@@ -207,7 +220,8 @@
         ">PASSWORD:Need 'User-Specific TPM Token FOO' ...\n"
         ">PASSWORD:Verification Failed: .\n"
         ">STATE:123,RECONNECTING,detail,...,...\n"
-        ">HOLD:Waiting for hold release";
+        ">HOLD:Waiting for hold release\n"
+        "SUCCESS: Hold released.";
     InputData data = CreateInputDataFromString(s);
     ExpectStaticChallengeResponse();
     ExpectPINResponse();
@@ -241,9 +255,14 @@
   server_.ProcessMessage(">STATE:123,RECONNECTING,detail,...,...");
 }
 
+TEST_F(OpenVPNManagementServerTest, ProcessSuccessMessage) {
+  EXPECT_FALSE(ProcessSuccessMessage("foo"));
+  EXPECT_TRUE(ProcessSuccessMessage("SUCCESS: foo"));
+}
+
 TEST_F(OpenVPNManagementServerTest, ProcessInfoMessage) {
   EXPECT_FALSE(server_.ProcessInfoMessage("foo"));
-  EXPECT_TRUE(server_.ProcessInfoMessage(">INFO:"));
+  EXPECT_TRUE(server_.ProcessInfoMessage(">INFO:foo"));
 }
 
 TEST_F(OpenVPNManagementServerTest, ProcessStateMessage) {
@@ -380,6 +399,17 @@
       server_.ProcessFailedPasswordMessage(">PASSWORD:Verification Failed: ."));
 }
 
+TEST_F(OpenVPNManagementServerTest, SendSignal) {
+  SetConnectedSocket();
+  ExpectSend("signal SIGUSR2\n");
+  SendSignal("SIGUSR2");
+}
+
+TEST_F(OpenVPNManagementServerTest, Restart) {
+  ExpectRestart();
+  server_.Restart();
+}
+
 TEST_F(OpenVPNManagementServerTest, SendHoldRelease) {
   ExpectHoldRelease();
   server_.SendHoldRelease();
diff --git a/vpn_driver.cc b/vpn_driver.cc
index 5175187..c6fbf72 100644
--- a/vpn_driver.cc
+++ b/vpn_driver.cc
@@ -181,9 +181,8 @@
 }
 
 void VPNDriver::OnConnectTimeout() {
-  LOG(ERROR) << "VPN connection timeout.";
+  LOG(INFO) << "VPN connect timeout.";
   StopConnectTimeout();
-  OnConnectionDisconnected();
 }
 
 }  // namespace shill
diff --git a/vpn_driver.h b/vpn_driver.h
index f539d0a..4200b7a 100644
--- a/vpn_driver.h
+++ b/vpn_driver.h
@@ -2,8 +2,8 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
-#ifndef SHILL_VPN_DRIVER_
-#define SHILL_VPN_DRIVER_
+#ifndef SHILL_VPN_DRIVER_H_
+#define SHILL_VPN_DRIVER_H_
 
 #include <string>
 
@@ -32,9 +32,11 @@
                               int interface_index) = 0;
   virtual void Connect(const VPNServiceRefPtr &service, Error *error) = 0;
   virtual void Disconnect() = 0;
-  virtual void OnConnectionDisconnected() = 0;
   virtual std::string GetProviderType() const = 0;
 
+  // Invoked by VPNService when the underlying connection disconnects.
+  virtual void OnConnectionDisconnected() = 0;
+
   virtual void InitPropertyStore(PropertyStore *store);
 
   virtual bool Load(StoreInterface *storage, const std::string &storage_id);
@@ -76,6 +78,10 @@
   // Returns true if a connect timeout is scheduled, false otherwise.
   bool IsConnectTimeoutStarted() const;
 
+  // Called if a connect timeout scheduled through StartConnectTimeout
+  // fires. Cancels the timeout callback.
+  virtual void OnConnectTimeout();
+
  private:
   FRIEND_TEST(VPNDriverTest, ConnectTimeout);
 
@@ -86,10 +92,6 @@
   void SetMappedProperty(
       const size_t &index, const std::string &value, Error *error);
 
-  // Called if a connect timeout scheduled through StartConnectTimeout
-  // fires. Marks the callback as stopped and invokes OnConnectionDisconnected.
-  void OnConnectTimeout();
-
   base::WeakPtrFactory<VPNDriver> weak_ptr_factory_;
   EventDispatcher *dispatcher_;
   Manager *manager_;
@@ -105,4 +107,4 @@
 
 }  // namespace shill
 
-#endif  // SHILL_VPN_DRIVER_
+#endif  // SHILL_VPN_DRIVER_H_
diff --git a/vpn_driver_unittest.cc b/vpn_driver_unittest.cc
index fe221f3..af52ebb 100644
--- a/vpn_driver_unittest.cc
+++ b/vpn_driver_unittest.cc
@@ -291,7 +291,6 @@
   EXPECT_TRUE(driver_.IsConnectTimeoutStarted());
   driver_.dispatcher_ = NULL;
   driver_.StartConnectTimeout();  // Expect no crash.
-  EXPECT_CALL(driver_, OnConnectionDisconnected());
   dispatcher_.DispatchPendingEvents();
   EXPECT_TRUE(driver_.connect_timeout_callback_.IsCancelled());
   EXPECT_FALSE(driver_.IsConnectTimeoutStarted());