shill: vpn: Claim interface from DeviceInfo

Maintain a list in the VPNProvider of all services created by it.
When DeviceInfo alerts it of a new tunnel interface, VPNProvider
will find a service->driver() to accept this interface.  If not
found, delete the interface (cleaning up after crash).  Also,
in OpenVPNDriver, create a tunnel device and hold enough state
to be able to claim it later.

BUG=chromium-os:26841,chromium-os:27156,chromium-os:27158
TEST=New unit tests.  Manual: Ensure tunnel devices get cleaned up on a
real system.

Change-Id: Iaaa44dc26830a2e8bf5dfea00d165ab8c034e6e9
Reviewed-on: https://gerrit.chromium.org/gerrit/17191
Reviewed-by: Darin Petkov <petkov@chromium.org>
Commit-Ready: Paul Stewart <pstew@chromium.org>
Tested-by: Paul Stewart <pstew@chromium.org>
diff --git a/device_info.cc b/device_info.cc
index 75f4734..31d6c39 100644
--- a/device_info.cc
+++ b/device_info.cc
@@ -265,7 +265,13 @@
         // Tunnel devices are managed by the VPN code.
         VLOG(2) << "Tunnel link " << link_name << " at index " << dev_index
                 << " -- notifying VPNProvider.";
-        // TODO(pstew): Notify VPNProvider once that method exists.
+        if (manager_->vpn_provider()->OnDeviceInfoAvailable(link_name,
+                                                             dev_index)) {
+          // VPN does not know anything about this tunnel, it is probably
+          // left over from a previous instance and should not exist.
+          VLOG(2) << "Tunnel link is unused.  Deleting.";
+          DeleteInterface(dev_index);
+        }
         return;
       default:
         device = new DeviceStub(control_interface_, dispatcher_, metrics_,
@@ -342,7 +348,7 @@
   return true;
 }
 
-bool DeviceInfo::CreateTunnelInterface(string *interface_name) {
+bool DeviceInfo::CreateTunnelInterface(string *interface_name) const {
   int fd = HANDLE_EINTR(open(kTunDeviceName, O_RDWR));
   if (fd < 0) {
     PLOG(ERROR) << "failed to open " << kTunDeviceName;
@@ -368,7 +374,7 @@
   return true;
 }
 
-bool DeviceInfo::DeleteInterface(int interface_index) {
+bool DeviceInfo::DeleteInterface(int interface_index) const {
   return rtnl_handler_->RemoveInterface(interface_index);
 }
 
diff --git a/device_info.h b/device_info.h
index 3911a75..7bae36f 100644
--- a/device_info.h
+++ b/device_info.h
@@ -68,8 +68,8 @@
   virtual bool GetAddresses(int interface_index,
                             std::vector<AddressData> *addresses) const;
   virtual void FlushAddresses(int interface_index) const;
-  virtual bool CreateTunnelInterface(std::string *interface_name);
-  virtual bool DeleteInterface(int interface_index);
+  virtual bool CreateTunnelInterface(std::string *interface_name) const;
+  virtual bool DeleteInterface(int interface_index) const;
 
  private:
   friend class DeviceInfoTest;
diff --git a/manager.h b/manager.h
index d24db88..c76732f 100644
--- a/manager.h
+++ b/manager.h
@@ -128,6 +128,7 @@
 
   virtual DeviceInfo *device_info() { return &device_info_; }
   ModemInfo *modem_info() { return &modem_info_; }
+  VPNProvider *vpn_provider() { return &vpn_provider_; }
   PropertyStore *mutable_store() { return &store_; }
   virtual const PropertyStore &store() const { return store_; }
 
diff --git a/mock_device_info.h b/mock_device_info.h
index a6190c3..72352f7 100644
--- a/mock_device_info.h
+++ b/mock_device_info.h
@@ -36,6 +36,8 @@
   MOCK_CONST_METHOD2(GetAddresses, bool(int interface_index,
                                         std::vector<AddressData>* addresses));
   MOCK_CONST_METHOD1(FlushAddresses, void(int interface_index));
+  MOCK_CONST_METHOD1(CreateTunnelInterface,  bool(std::string *interface_name));
+  MOCK_CONST_METHOD1(DeleteInterface, bool(int interface_index));
 
  private:
   DISALLOW_COPY_AND_ASSIGN(MockDeviceInfo);
diff --git a/mock_vpn_driver.h b/mock_vpn_driver.h
index 3dd5e35..7b1421b 100644
--- a/mock_vpn_driver.h
+++ b/mock_vpn_driver.h
@@ -16,6 +16,8 @@
   MockVPNDriver();
   virtual ~MockVPNDriver();
 
+  MOCK_METHOD2(ClaimInterface, bool(const std::string &link_name,
+                                    int interface_index));
   MOCK_METHOD1(Connect, void(Error *error));
 };
 
diff --git a/openvpn_driver.cc b/openvpn_driver.cc
index 166e79c..e90ef20 100644
--- a/openvpn_driver.cc
+++ b/openvpn_driver.cc
@@ -7,6 +7,7 @@
 #include <base/logging.h>
 #include <chromeos/dbus/service_constants.h>
 
+#include "shill/device_info.h"
 #include "shill/error.h"
 #include "shill/rpc_task.h"
 
@@ -20,12 +21,28 @@
 }  // namespace {}
 
 OpenVPNDriver::OpenVPNDriver(ControlInterface *control,
+                             DeviceInfo *device_info,
                              const KeyValueStore &args)
     : control_(control),
-      args_(args) {}
+      device_info_(device_info),
+      args_(args),
+      interface_index_(-1) {}
 
 OpenVPNDriver::~OpenVPNDriver() {}
 
+bool OpenVPNDriver::ClaimInterface(const string &link_name,
+                                   int interface_index) {
+  if (link_name != tunnel_interface_) {
+    return false;
+  }
+
+  VLOG(2) << "Claiming " << link_name << " for OpenVPN tunnel";
+
+  // TODO(petkov): Could create a VPNDevice or DeviceStub here instead.
+  interface_index_ = interface_index;
+  return true;
+}
+
 void OpenVPNDriver::Connect(Error *error) {
   // TODO(petkov): Allocate rpc_task_.
   error->Populate(Error::kNotSupported);
@@ -49,8 +66,15 @@
   options->push_back("--persist-key");
   options->push_back("--persist-tun");
 
-  // TODO(petkov): Add "--dev <interface_name>". For OpenVPN, the interface will
-  // be the tunnel device (crosbug.com/26841).
+  if (tunnel_interface_.empty() &&
+      !device_info_->CreateTunnelInterface(&tunnel_interface_)) {
+    Error::PopulateAndLog(
+        error, Error::kInternalError, "Could not create tunnel interface.");
+    return;
+  }
+
+  options->push_back("--dev");
+  options->push_back(tunnel_interface_);
   options->push_back("--dev-type");
   options->push_back("tun");
   options->push_back("--syslog");
diff --git a/openvpn_driver.h b/openvpn_driver.h
index 9a07652..57d2272 100644
--- a/openvpn_driver.h
+++ b/openvpn_driver.h
@@ -17,21 +17,28 @@
 namespace shill {
 
 class ControlInterface;
+class DeviceInfo;
 class Error;
 class RPCTask;
+class DeviceStub;
 
 class OpenVPNDriver : public VPNDriver {
  public:
-  OpenVPNDriver(ControlInterface *control, const KeyValueStore &args);
+  OpenVPNDriver(ControlInterface *control,
+                DeviceInfo *device_info,
+                const KeyValueStore &args);
   virtual ~OpenVPNDriver();
 
   // Inherited from VPNDriver.
+  virtual bool ClaimInterface(const std::string &link_name,
+                              int interface_index);
   virtual void Connect(Error *error);
 
  private:
   friend class OpenVPNDriverTest;
   FRIEND_TEST(OpenVPNDriverTest, AppendFlag);
   FRIEND_TEST(OpenVPNDriverTest, AppendValueOption);
+  FRIEND_TEST(OpenVPNDriverTest, ClaimInterface);
   FRIEND_TEST(OpenVPNDriverTest, InitOptions);
   FRIEND_TEST(OpenVPNDriverTest, InitOptionsNoHost);
 
@@ -45,8 +52,11 @@
                   std::vector<std::string> *options);
 
   ControlInterface *control_;
+  DeviceInfo *device_info_;
   KeyValueStore args_;
   scoped_ptr<RPCTask> rpc_task_;
+  std::string tunnel_interface_;
+  int interface_index_;
 
   DISALLOW_COPY_AND_ASSIGN(OpenVPNDriver);
 };
diff --git a/openvpn_driver_unittest.cc b/openvpn_driver_unittest.cc
index cedc0f7..6c5800e 100644
--- a/openvpn_driver_unittest.cc
+++ b/openvpn_driver_unittest.cc
@@ -11,12 +11,17 @@
 
 #include "shill/error.h"
 #include "shill/mock_adaptors.h"
+#include "shill/mock_device_info.h"
 #include "shill/nice_mock_control.h"
 #include "shill/rpc_task.h"
 
 using std::map;
 using std::string;
 using std::vector;
+using testing::_;
+using testing::DoAll;
+using testing::Return;
+using testing::SetArgumentPointee;
 
 namespace shill {
 
@@ -24,7 +29,8 @@
                           public RPCTaskDelegate {
  public:
   OpenVPNDriverTest()
-      : driver_(&control_, args_) {}
+      : device_info_(&control_, NULL, NULL, NULL),
+        driver_(&control_, &device_info_, args_) {}
 
   virtual ~OpenVPNDriverTest() {}
 
@@ -43,7 +49,13 @@
     driver_.args_ = args_;
   }
 
+  // Used to assert that a flag appears in the options.
+  void ExpectInFlags(const vector<string> &options, const string &flag,
+                     const string &value);
+
+
   NiceMockControl control_;
+  MockDeviceInfo device_info_;
   KeyValueStore args_;
   OpenVPNDriver driver_;
 };
@@ -58,6 +70,23 @@
 void OpenVPNDriverTest::Notify(const string &/*reason*/,
                                const map<string, string> &/*dict*/) {}
 
+void OpenVPNDriverTest::ExpectInFlags(const vector<string> &options,
+                                      const string &flag,
+                                      const string &value) {
+  vector<string>::const_iterator it =
+      std::find(options.begin(), options.end(), flag);
+
+  EXPECT_TRUE(it != options.end());
+  if (it != options.end())
+    return;                             // Don't crash below.
+  it++;
+  EXPECT_TRUE(it != options.end());
+  if (it != options.end())
+    return;                             // Don't crash below.
+  EXPECT_EQ(value, *it);
+}
+
+
 TEST_F(OpenVPNDriverTest, Connect) {
   Error error;
   driver_.Connect(&error);
@@ -79,15 +108,17 @@
   driver_.rpc_task_.reset(new RPCTask(&control_, this));
   Error error;
   vector<string> options;
+  const string kInterfaceName("tun0");
+  EXPECT_CALL(device_info_, CreateTunnelInterface(_))
+      .WillOnce(DoAll(SetArgumentPointee<0>(kInterfaceName), Return(true)));
   driver_.InitOptions(&options, &error);
   EXPECT_TRUE(error.IsSuccess());
   EXPECT_EQ("--client", options[0]);
-  EXPECT_EQ("--remote", options[2]);
-  EXPECT_EQ(kHost, options[3]);
+  ExpectInFlags(options, "--remote", kHost);
+  ExpectInFlags(options, "CONNMAN_PATH", RPCTaskMockAdaptor::kRpcId);
+  ExpectInFlags(options, "--dev", kInterfaceName);
   EXPECT_EQ("openvpn", options.back());
-  EXPECT_TRUE(
-      std::find(options.begin(), options.end(), RPCTaskMockAdaptor::kRpcId) !=
-      options.end());
+  EXPECT_EQ(kInterfaceName, driver_.tunnel_interface_);
 }
 
 TEST_F(OpenVPNDriverTest, AppendValueOption) {
@@ -127,4 +158,15 @@
   EXPECT_EQ(kOption2, options[1]);
 }
 
+TEST_F(OpenVPNDriverTest, ClaimInterface) {
+  const string kInterfaceName("tun0");
+  driver_.tunnel_interface_ = kInterfaceName;
+  const int kInterfaceIndex = 1122;
+  EXPECT_FALSE(driver_.ClaimInterface(kInterfaceName + "XXX", kInterfaceIndex));
+  EXPECT_EQ(-1, driver_.interface_index_);
+
+  EXPECT_TRUE(driver_.ClaimInterface(kInterfaceName, kInterfaceIndex));
+  EXPECT_EQ(kInterfaceIndex, driver_.interface_index_);
+}
+
 }  // namespace shill
diff --git a/vpn_driver.h b/vpn_driver.h
index 95cacbc..27fcfae 100644
--- a/vpn_driver.h
+++ b/vpn_driver.h
@@ -5,6 +5,8 @@
 #ifndef SHILL_VPN_DRIVER_
 #define SHILL_VPN_DRIVER_
 
+#include <string>
+
 #include <base/basictypes.h>
 
 namespace shill {
@@ -15,6 +17,8 @@
  public:
   virtual ~VPNDriver() {}
 
+  virtual bool ClaimInterface(const std::string &link_name,
+                              int interface_index) = 0;
   virtual void Connect(Error *error) = 0;
 };
 
diff --git a/vpn_provider.cc b/vpn_provider.cc
index 84903f7..0adffcb 100644
--- a/vpn_provider.cc
+++ b/vpn_provider.cc
@@ -8,10 +8,12 @@
 #include <chromeos/dbus/service_constants.h>
 
 #include "shill/error.h"
+#include "shill/manager.h"
 #include "shill/openvpn_driver.h"
 #include "shill/vpn_service.h"
 
 using std::string;
+using std::vector;
 
 namespace shill {
 
@@ -38,17 +40,38 @@
         error, Error::kNotSupported, "Missing VPN type property.");
     return NULL;
   }
+
   const string &type = args.GetString(flimflam::kProviderTypeProperty);
   scoped_ptr<VPNDriver> driver;
   if (type == flimflam::kProviderOpenVpn) {
-    driver.reset(new OpenVPNDriver(control_interface_, args));
+    driver.reset(new OpenVPNDriver(control_interface_,
+                                   manager_->device_info(), args));
   } else {
     Error::PopulateAndLog(
         error, Error::kNotSupported, "Unsupported VPN type: " + type);
     return NULL;
   }
-  return new VPNService(
-      control_interface_, dispatcher_, metrics_, manager_, driver.release());
+
+  services_.push_back(
+      new VPNService(
+          control_interface_, dispatcher_, metrics_, manager_,
+          driver.release()));
+
+  return services_.back();
+
+}
+
+bool VPNProvider::OnDeviceInfoAvailable(const string &link_name,
+                                        int interface_index) {
+  for (vector<VPNServiceRefPtr>::const_iterator it = services_.begin();
+       it != services_.end();
+       ++it) {
+    if ((*it)->driver()->ClaimInterface(link_name, interface_index)) {
+      return true;
+    }
+  }
+
+  return false;
 }
 
 }  // namespace shill
diff --git a/vpn_provider.h b/vpn_provider.h
index 2d65b0c..f4b532c 100644
--- a/vpn_provider.h
+++ b/vpn_provider.h
@@ -5,7 +5,11 @@
 #ifndef SHILL_VPN_PROVIDER_
 #define SHILL_VPN_PROVIDER_
 
+#include <string>
+#include <vector>
+
 #include <base/basictypes.h>
+#include <gtest/gtest_prod.h>  // for FRIEND_TEST
 
 #include "shill/refptr_types.h"
 
@@ -31,11 +35,18 @@
 
   VPNServiceRefPtr GetService(const KeyValueStore &args, Error *error);
 
+  // Offers an unclaimed interface to VPN services.  Returns true if this
+  // device has been accepted by a service.
+  bool OnDeviceInfoAvailable(const std::string &link_name, int interface_index);
+
  private:
+  FRIEND_TEST(VPNProviderTest, OnDeviceInfoAvailable);
+
   ControlInterface *control_interface_;
   EventDispatcher *dispatcher_;
   Metrics *metrics_;
   Manager *manager_;
+  std::vector<VPNServiceRefPtr> services_;
 
   DISALLOW_COPY_AND_ASSIGN(VPNProvider);
 };
diff --git a/vpn_provider_unittest.cc b/vpn_provider_unittest.cc
index e8f96b4..e35f5cf 100644
--- a/vpn_provider_unittest.cc
+++ b/vpn_provider_unittest.cc
@@ -10,21 +10,29 @@
 #include "shill/error.h"
 #include "shill/nice_mock_control.h"
 #include "shill/mock_adaptors.h"
+#include "shill/mock_manager.h"
 #include "shill/mock_metrics.h"
+#include "shill/mock_vpn_driver.h"
 #include "shill/vpn_service.h"
 
+using std::string;
+using testing::_;
+using testing::Return;
+
 namespace shill {
 
 class VPNProviderTest : public testing::Test {
  public:
   VPNProviderTest()
-      : provider_(&control_, NULL, &metrics_, NULL) {}
+      : manager_(&control_, NULL, &metrics_, NULL),
+        provider_(&control_, NULL, &metrics_, &manager_) {}
 
   virtual ~VPNProviderTest() {}
 
  protected:
   NiceMockControl control_;
   MockMetrics metrics_;
+  MockManager manager_;
   VPNProvider provider_;
 };
 
@@ -57,4 +65,34 @@
   EXPECT_TRUE(service);
 }
 
+TEST_F(VPNProviderTest, OnDeviceInfoAvailable) {
+  const string kInterfaceName("tun0");
+  const int kInterfaceIndex = 1;
+
+  scoped_ptr<MockVPNDriver> bad_driver(new MockVPNDriver());
+  EXPECT_CALL(*bad_driver.get(), ClaimInterface(_, _))
+      .Times(2)
+      .WillRepeatedly(Return(false));
+  provider_.services_.push_back(
+      new VPNService(&control_, NULL, &metrics_, NULL, bad_driver.release()));
+
+  EXPECT_FALSE(provider_.OnDeviceInfoAvailable(kInterfaceName,
+                                               kInterfaceIndex));
+
+  scoped_ptr<MockVPNDriver> good_driver(new MockVPNDriver());
+  EXPECT_CALL(*good_driver.get(), ClaimInterface(_, _))
+      .WillOnce(Return(true));
+  provider_.services_.push_back(
+      new VPNService(&control_, NULL, &metrics_, NULL, good_driver.release()));
+
+  scoped_ptr<MockVPNDriver> dup_driver(new MockVPNDriver());
+  EXPECT_CALL(*dup_driver.get(), ClaimInterface(_, _))
+      .Times(0);
+  provider_.services_.push_back(
+      new VPNService(&control_, NULL, &metrics_, NULL, dup_driver.release()));
+
+  EXPECT_TRUE(provider_.OnDeviceInfoAvailable(kInterfaceName, kInterfaceIndex));
+  provider_.services_.clear();
+}
+
 }  // namespace shill
diff --git a/vpn_service.h b/vpn_service.h
index 485867a..b768763 100644
--- a/vpn_service.h
+++ b/vpn_service.h
@@ -25,8 +25,8 @@
 
   // Inherited from Service.
   virtual void Connect(Error *error);
-
   virtual std::string GetStorageIdentifier() const;
+  VPNDriver *driver() { return driver_.get(); }
 
  private:
   FRIEND_TEST(VPNServiceTest, GetDeviceRpcId);