shill: Add message-specific callbacks to the nl80211 code.

This allows users to create a one-off nl80211 message, install a
callback for that message, and then send the message to the kernel.
When the response is issued, the callback will be called and then
removed from the system.

BUG=chromium-os:35129
TEST=Manual and unit tests (including new ones).

Change-Id: I06bf8d16629f3eac226209b49827289e7a9cdca3
Reviewed-on: https://gerrit.chromium.org/gerrit/36904
Reviewed-by: Christopher Wiley <wiley@chromium.org>
Commit-Ready: Wade Guthrie <wdg@chromium.org>
Tested-by: Wade Guthrie <wdg@chromium.org>
diff --git a/callback80211_object.cc b/callback80211_object.cc
index 6c4d617..3133100 100644
--- a/callback80211_object.cc
+++ b/callback80211_object.cc
@@ -24,7 +24,10 @@
 namespace shill {
 
 Callback80211Object::Callback80211Object(Config80211 *config80211)
-    : config80211_(config80211), weak_ptr_factory_(this) {
+    : weak_ptr_factory_(this),
+      callback_(Bind(&Callback80211Object::Config80211MessageCallback,
+                     weak_ptr_factory_.GetWeakPtr())),
+      config80211_(config80211) {
 }
 
 Callback80211Object::~Callback80211Object() {
@@ -54,8 +57,6 @@
 
 bool Callback80211Object::InstallAsBroadcastCallback() {
   if (config80211_) {
-    callback_ = Bind(&Callback80211Object::Config80211MessageCallback,
-                     weak_ptr_factory_.GetWeakPtr());
     return config80211_->AddBroadcastCallback(callback_);
   }
   return false;
diff --git a/callback80211_object.h b/callback80211_object.h
index 0037e98..85a7875 100644
--- a/callback80211_object.h
+++ b/callback80211_object.h
@@ -39,12 +39,6 @@
   void SetName(std::string name) { name_ = name;}
   const std::string &GetName() { return name_; }
 
- protected:
-  // This is the closure that contains *|this| and a pointer to the message
-  // handling callback, below.  It is used in |DeinstallAsCallback|.
-  Config80211::Callback callback_;
-  Config80211 *config80211_;
-
  private:
   // TODO(wdg): remove debug code:
   std::string name_;
@@ -57,6 +51,13 @@
   // callback.
   base::WeakPtrFactory<Callback80211Object> weak_ptr_factory_;
 
+ protected:
+  // This is the closure that contains *|this| and a pointer to the message
+  // handling callback, below.  It is used in |DeinstallAsCallback|.
+  Config80211::Callback callback_;
+  Config80211 *config80211_;
+
+ private:
   DISALLOW_COPY_AND_ASSIGN(Callback80211Object);
 };
 
diff --git a/config80211.cc b/config80211.cc
index 4dc9975..fa47642 100644
--- a/config80211.cc
+++ b/config80211.cc
@@ -136,6 +136,38 @@
   broadcast_callbacks_.clear();
 }
 
+bool Config80211::SetMessageCallback(const KernelBoundNlMessage &message,
+                                     const Callback &callback) {
+  LOG(INFO) << "Setting callback for message " << message.GetId();
+  uint32_t message_id = message.GetId();
+  if (message_id == 0) {
+    LOG(ERROR) << "Message ID 0 is reserved for broadcast callbacks.";
+    return false;
+  }
+
+  if (ContainsKey(message_callbacks_, message_id)) {
+    LOG(ERROR) << "Already a callback assigned for id " << message_id;
+    return false;
+  }
+
+  if (callback.is_null()) {
+    LOG(ERROR) << "Trying to add a NULL callback for id " << message_id;
+    return false;
+  }
+
+  message_callbacks_[message_id] = callback;
+  return true;
+}
+
+bool Config80211::UnsetMessageCallbackById(uint32_t message_id) {
+  if (!ContainsKey(message_callbacks_, message_id)) {
+    LOG(WARNING) << "No callback assigned for id " << message_id;
+    return false;
+  }
+  message_callbacks_.erase(message_id);
+  return true;
+}
+
 // static
 bool Config80211::GetEventTypeString(EventType type, string *value) {
   if (!value) {
@@ -243,20 +275,29 @@
     SLOG(WiFi, 3) << __func__ << "(msg:NULL)";
   } else {
     SLOG(WiFi, 3) << __func__ << "(msg:" << msg->nlmsg_seq << ")";
-    list<Callback>::iterator i = broadcast_callbacks_.begin();
-    while (i != broadcast_callbacks_.end()) {
-      SLOG(WiFi, 3) << "    " << __func__ << " - found callback";
-      if (i->is_null()) {
-        // How did this get in here?
-        LOG(WARNING) << "Removing NULL callback from list";
-        list<Callback>::iterator temp = i;
-        ++temp;
-        broadcast_callbacks_.erase(i);
-        i = temp;
+    // Call (then erase) any message-specific callback.
+    if (ContainsKey(message_callbacks_, message->GetId())) {
+      SLOG(WiFi, 3) << "found message-specific callback";
+      if (message_callbacks_[message->GetId()].is_null()) {
+        LOG(ERROR) << "Callback exists but is NULL for ID " << message->GetId();
       } else {
-        SLOG(WiFi, 3) << "      " << __func__ << " - calling callback";
-        i->Run(*message);
-        ++i;
+        message_callbacks_[message->GetId()].Run(*message);
+      }
+      UnsetMessageCallbackById(message->GetId());
+    } else {
+      list<Callback>::iterator i = broadcast_callbacks_.begin();
+      while (i != broadcast_callbacks_.end()) {
+        SLOG(WiFi, 3) << "found a broadcast callback";
+        if (i->is_null()) {
+          list<Callback>::iterator temp = i;
+          ++temp;
+          broadcast_callbacks_.erase(i);
+          i = temp;
+        } else {
+          SLOG(WiFi, 3) << "      " << __func__ << " - calling callback";
+          i->Run(*message);
+          ++i;
+        }
       }
     }
   }
diff --git a/config80211.h b/config80211.h
index 4e13aa5..f32be1a 100644
--- a/config80211.h
+++ b/config80211.h
@@ -61,6 +61,8 @@
 #ifndef SHILL_CONFIG80211_H_
 #define SHILL_CONFIG80211_H_
 
+#include <gtest/gtest_prod.h>  // for FRIEND_TEST
+
 #include <iomanip>
 #include <list>
 #include <map>
@@ -132,10 +134,16 @@
   // Uninstall all Config80211 broadcast Callbacks.
   void ClearBroadcastCallbacks();
 
-  // TODO(wdg): Add 'SendMessage(KernelBoundNlMessage *message,
-  //                             Config80211::Callback *callback);
-  // Config80211 needs to handle out-of-order responses using a
-  // map <sequence_number, callback> to match callback with message.
+  // Install a Config80211 Callback to handle the response to a specific
+  // message.
+  // TODO(wdg): Eventually, this should also include a timeout and a callback
+  // to call in case of timeout.
+  bool SetMessageCallback(const KernelBoundNlMessage &message,
+                          const Callback &callback);
+
+  // Uninstall a Config80211 Callback for a specific message using the
+  // message's sequence number.
+  bool UnsetMessageCallbackById(uint32_t sequence_number);
 
   // Return a string corresponding to the passed-in EventType.
   static bool GetEventTypeString(EventType type, std::string *value);
@@ -155,8 +163,11 @@
 
  private:
   friend class Config80211Test;
+  FRIEND_TEST(Config80211Test, BroadcastCallbackTest);
+  FRIEND_TEST(Config80211Test, MessageCallbackTest);
   typedef std::map<EventType, std::string> EventTypeStrings;
   typedef std::set<EventType> SubscribedEvents;
+  typedef std::map<uint32_t, Callback> MessageCallbacks;
 
   // Sign-up to receive and log multicast events of a specific type (assumes
   // wifi is up).
@@ -187,8 +198,8 @@
   // User-supplied callback object when _it_ gets called to read libnl data.
   std::list<Callback> broadcast_callbacks_;
 
-  // TODO(wdg): implement the following.
-  // std::map<uint32_t, Callback> message_callback_;
+  // Message-specific callbacks, mapped by message ID.
+  MessageCallbacks  message_callbacks_;
 
   static EventTypeStrings *event_types_;
 
diff --git a/config80211_unittest.cc b/config80211_unittest.cc
index 4bcf6f2..9bae303 100644
--- a/config80211_unittest.cc
+++ b/config80211_unittest.cc
@@ -18,9 +18,13 @@
 #include <base/bind.h>
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include <net/if.h>
 #include <netlink/attr.h>
+#include <netlink/genl/genl.h>
+#include <netlink/msg.h>
 #include <netlink/netlink.h>
 
+#include "shill/kernel_bound_nlmessage.h"
 #include "shill/mock_callback80211_object.h"
 #include "shill/mock_nl80211_socket.h"
 #include "shill/nl80211_socket.h"
@@ -32,6 +36,7 @@
 using std::string;
 using std::vector;
 using testing::_;
+using testing::Invoke;
 using testing::Return;
 using testing::Test;
 
@@ -339,6 +344,8 @@
 
 }  // namespace
 
+unsigned int MockNl80211Socket::number_ = 0;
+
 class Config80211Test : public Test {
  public:
   Config80211Test() : config80211_(Config80211::GetInstance()) {}
@@ -390,6 +397,8 @@
   EXPECT_CALL(socket_, AddGroupMembership(_)).Times(0);
   EXPECT_CALL(socket_, DisableSequenceChecking()).Times(0);
   EXPECT_CALL(socket_, SetNetlinkCallback(_, _)).Times(0);
+  EXPECT_CALL(socket_, GetSequenceNumber())
+      .WillRepeatedly(Invoke(&MockNl80211Socket::GetNextNumber));
 
   EXPECT_TRUE(config80211_->AddBroadcastCallback(
       callback_object.GetCallback()));
@@ -439,6 +448,129 @@
   config80211_->SetWifiState(Config80211::kWifiUp);
 }
 
+TEST_F(Config80211Test, BroadcastCallbackTest) {
+  SetupConfig80211Object();
+
+  nlmsghdr *message = const_cast<nlmsghdr *>(
+        reinterpret_cast<const nlmsghdr *>(kNL80211_CMD_DISCONNECT));
+
+  MockCallback80211 callback1(config80211_);
+  MockCallback80211 callback2(config80211_);
+  EXPECT_CALL(socket_, GetSequenceNumber())
+      .WillRepeatedly(Invoke(&MockNl80211Socket::GetNextNumber));
+
+  // Simple, 1 callback, case.
+  EXPECT_CALL(callback1, Config80211MessageCallback(_)).Times(1);
+  EXPECT_TRUE(callback1.InstallAsBroadcastCallback());
+  config80211_->OnNlMessageReceived(message);
+
+  // Add a second callback.
+  EXPECT_CALL(callback1, Config80211MessageCallback(_)).Times(1);
+  EXPECT_CALL(callback2, Config80211MessageCallback(_)).Times(1);
+  EXPECT_TRUE(callback2.InstallAsBroadcastCallback());
+  config80211_->OnNlMessageReceived(message);
+
+  // Verify that a callback can't be added twice.
+  EXPECT_CALL(callback1, Config80211MessageCallback(_)).Times(1);
+  EXPECT_CALL(callback2, Config80211MessageCallback(_)).Times(1);
+  EXPECT_FALSE(callback1.InstallAsBroadcastCallback());
+  config80211_->OnNlMessageReceived(message);
+
+  // Check that we can remove a callback.
+  EXPECT_CALL(callback1, Config80211MessageCallback(_)).Times(0);
+  EXPECT_CALL(callback2, Config80211MessageCallback(_)).Times(1);
+  EXPECT_TRUE(callback1.DeinstallAsCallback());
+  config80211_->OnNlMessageReceived(message);
+
+  // Check that re-adding the callback goes smoothly.
+  EXPECT_CALL(callback1, Config80211MessageCallback(_)).Times(1);
+  EXPECT_CALL(callback2, Config80211MessageCallback(_)).Times(1);
+  EXPECT_TRUE(callback1.InstallAsBroadcastCallback());
+  config80211_->OnNlMessageReceived(message);
+
+  // Check that ClearBroadcastCallbacks works.
+  config80211_->ClearBroadcastCallbacks();
+  EXPECT_CALL(callback1, Config80211MessageCallback(_)).Times(0);
+  EXPECT_CALL(callback2, Config80211MessageCallback(_)).Times(0);
+  config80211_->OnNlMessageReceived(message);
+}
+
+TEST_F(Config80211Test, MessageCallbackTest) {
+  // Setup.
+  SetupConfig80211Object();
+
+  EXPECT_CALL(socket_, GetSequenceNumber())
+      .WillRepeatedly(Invoke(&MockNl80211Socket::GetNextNumber));
+
+  MockCallback80211 callback_broadcast(config80211_);
+  EXPECT_TRUE(callback_broadcast.InstallAsBroadcastCallback());
+
+  KernelBoundNlMessage sent_message_1;
+  MockCallback80211 callback_sent_1(config80211_);
+  EXPECT_TRUE(sent_message_1.Init());
+  EXPECT_TRUE(sent_message_1.AddNetlinkHeader(&socket_, 0, NL_AUTO_SEQ, 0, 0, 0,
+                                              CTRL_CMD_GETFAMILY, 0));
+  LOG(INFO) << "Message 1 id:" << sent_message_1.GetId();
+
+  KernelBoundNlMessage sent_message_2;
+  MockCallback80211 callback_sent_2(config80211_);
+  EXPECT_TRUE(sent_message_2.Init());
+  EXPECT_TRUE(sent_message_2.AddNetlinkHeader(&socket_, 0, NL_AUTO_SEQ, 0, 0, 0,
+                                              CTRL_CMD_GETFAMILY, 0));
+  LOG(INFO) << "Message 2 id:" << sent_message_2.GetId();
+
+  // This is more testing the test code than the code, itself.
+  EXPECT_NE(sent_message_1.GetId(), sent_message_2.GetId());
+
+  // Set up the received message as a response to sent_message_1.
+  scoped_array<unsigned char> message_memory(
+      new unsigned char[sizeof(kNL80211_CMD_DISCONNECT)]);
+  memcpy(message_memory.get(), kNL80211_CMD_DISCONNECT,
+         sizeof(kNL80211_CMD_DISCONNECT));
+  nlmsghdr *received_message =
+        reinterpret_cast<nlmsghdr *>(message_memory.get());
+  received_message->nlmsg_seq = sent_message_1.GetId();
+
+  // Now, we can start the actual test...
+
+  // Verify that generic callback gets called for a message when no
+  // message-specific callback has been installed.
+  EXPECT_CALL(callback_broadcast, Config80211MessageCallback(_)).Times(1);
+  config80211_->OnNlMessageReceived(received_message);
+
+  // Install message-based callback; verify that message callback gets called.
+  EXPECT_TRUE(config80211_->SetMessageCallback(sent_message_1,
+                                               callback_sent_1.GetCallback()));
+  EXPECT_CALL(callback_sent_1, Config80211MessageCallback(_)).Times(1);
+  config80211_->OnNlMessageReceived(received_message);
+
+  // Verify that broadcast callback is called for the message after the
+  // message-specific callback is called once.
+  EXPECT_CALL(callback_broadcast, Config80211MessageCallback(_)).Times(1);
+  config80211_->OnNlMessageReceived(received_message);
+
+  // Install and then uninstall message-specific callback; verify broadcast
+  // callback is called on message receipt.
+  EXPECT_TRUE(config80211_->SetMessageCallback(sent_message_1,
+                                               callback_sent_1.GetCallback()));
+  EXPECT_TRUE(config80211_->UnsetMessageCallbackById(sent_message_1.GetId()));
+  EXPECT_CALL(callback_broadcast, Config80211MessageCallback(_)).Times(1);
+  config80211_->OnNlMessageReceived(received_message);
+
+  // Install callback for different message; verify that broadcast callback is
+  // called for _this_ message.
+  EXPECT_TRUE(config80211_->SetMessageCallback(sent_message_2,
+                                               callback_sent_2.GetCallback()));
+  EXPECT_CALL(callback_broadcast, Config80211MessageCallback(_)).Times(1);
+  config80211_->OnNlMessageReceived(received_message);
+
+  // Change the ID for the message to that of the second callback; verify that
+  // the appropriate callback is called for _that_ message.
+  received_message->nlmsg_seq = sent_message_2.GetId();
+  EXPECT_CALL(callback_sent_2, Config80211MessageCallback(_)).Times(1);
+  config80211_->OnNlMessageReceived(received_message);
+}
+
 TEST_F(Config80211Test, NL80211_CMD_TRIGGER_SCAN) {
   UserBoundNlMessage *message = UserBoundNlMessageFactory::CreateMessage(
       const_cast<nlmsghdr *>(
diff --git a/kernel_bound_nlmessage.cc b/kernel_bound_nlmessage.cc
index 9b35ff3..ae8d8ad 100644
--- a/kernel_bound_nlmessage.cc
+++ b/kernel_bound_nlmessage.cc
@@ -50,7 +50,8 @@
   return header->nlmsg_seq;
 }
 
-bool KernelBoundNlMessage::AddNetlinkHeader(uint32_t port, uint32_t seq,
+bool KernelBoundNlMessage::AddNetlinkHeader(NetlinkSocket *socket,
+                                            uint32_t port, uint32_t seq,
                                             int family_id, int hdrlen,
                                             int flags, uint8_t cmd,
                                             uint8_t version) {
@@ -76,6 +77,12 @@
     return false;
   }
 
+  // Manually set the sequence number if it's zero.
+  struct nlmsghdr *header = nlmsg_hdr(message_);
+  if (header != 0 && seq == NL_AUTO_SEQ && header->nlmsg_seq == 0) {
+    header->nlmsg_seq = socket->GetSequenceNumber();
+  }
+
   return true;
 }
 
@@ -102,12 +109,6 @@
     return -1;
   }
 
-  // Manually set the sequence number -- seems to work.
-  struct nlmsghdr *header = nlmsg_hdr(message_);
-  if (header != 0) {
-    header->nlmsg_seq = socket->GetSequenceNumber();
-  }
-
   // Complete AND SEND a message.
   int result = nl_send_auto_complete(
       const_cast<struct nl_sock *>(socket->GetConstNlSock()), message_);
diff --git a/kernel_bound_nlmessage.h b/kernel_bound_nlmessage.h
index b303f51..d3469f2 100644
--- a/kernel_bound_nlmessage.h
+++ b/kernel_bound_nlmessage.h
@@ -51,8 +51,9 @@
   uint32_t GetId() const;
 
   // Add a netlink header to the message.
-  bool AddNetlinkHeader(uint32_t port, uint32_t seq, int family_id, int hdrlen,
-                        int flags, uint8_t cmd, uint8_t version);
+  bool AddNetlinkHeader(NetlinkSocket *socket, uint32_t port, uint32_t seq,
+                        int family_id, int hdrlen, int flags, uint8_t cmd,
+                        uint8_t version);
 
   // Add a netlink attribute to the message.
   int AddAttribute(int attrtype, int attrlen, const void *data);
diff --git a/mock_callback80211_object.h b/mock_callback80211_object.h
index 067e6e5..d10c6a5 100644
--- a/mock_callback80211_object.h
+++ b/mock_callback80211_object.h
@@ -8,6 +8,7 @@
 #include <gmock/gmock.h>
 
 #include "shill/callback80211_object.h"
+#include "shill/config80211.h"
 
 
 namespace shill {
@@ -17,6 +18,8 @@
 
 class MockCallback80211 : public Callback80211Object {
  public:
+  Config80211::Callback &GetCallback() { return callback_; }
+
   explicit MockCallback80211(Config80211 *config80211)
       : Callback80211Object(config80211) {}
   MOCK_METHOD1(Config80211MessageCallback, void(const UserBoundNlMessage &msg));
diff --git a/mock_nl80211_socket.h b/mock_nl80211_socket.h
index 1da26d1..6c1b0f9 100644
--- a/mock_nl80211_socket.h
+++ b/mock_nl80211_socket.h
@@ -29,7 +29,17 @@
   using Nl80211Socket::SetNetlinkCallback;
   MOCK_METHOD2(SetNetlinkCallback, bool(nl_recvmsg_msg_cb_t on_netlink_data,
                                         void *callback_parameter));
+  MOCK_METHOD0(GetSequenceNumber, unsigned int());
+
+  static unsigned int GetNextNumber() {
+    if (++number_ == 0)
+      number_ = 1;
+    return number_;
+  }
+
  private:
+  static unsigned int number_;
+
   DISALLOW_COPY_AND_ASSIGN(MockNl80211Socket);
 };
 
diff --git a/netlink_socket.cc b/netlink_socket.cc
index 9d8caa3..15fa99c 100644
--- a/netlink_socket.cc
+++ b/netlink_socket.cc
@@ -168,6 +168,18 @@
   return true;
 }
 
+unsigned int NetlinkSocket::GetSequenceNumber() {
+  unsigned int number = nl_socket_use_seq(nl_sock_);
+  if (number == 0) {
+    number = nl_socket_use_seq(nl_sock_);
+  }
+  if (number == 0) {
+    LOG(WARNING) << "Couldn't get non-zero sequence number";
+    number = 1;
+  }
+  return number;
+}
+
 bool NetlinkSocket::SetNetlinkCallback(nl_recvmsg_msg_cb_t on_netlink_data,
                                        void *callback_parameter) {
   if (!nl_sock_) {
diff --git a/netlink_socket.h b/netlink_socket.h
index 0db5c3a..6f0d60e 100644
--- a/netlink_socket.h
+++ b/netlink_socket.h
@@ -116,9 +116,9 @@
   // function if NULL).
   virtual bool GetMessagesUsingCallback(NetlinkSocket::Callback *callback);
 
-  virtual unsigned int GetSequenceNumber() {
-    return nl_socket_use_seq(nl_sock_);
-  }
+  // Get the next message sequence number for this socket.  Disallow zero so
+  // that we can use that as the 'broadcast' sequence number.
+  virtual unsigned int GetSequenceNumber();
 
   // This method is called |callback_function| to differentiate it from the
   // 'callback' method in KernelBoundNlMessage since they return different
diff --git a/nl80211_socket.cc b/nl80211_socket.cc
index 1f9409c..47126b2 100644
--- a/nl80211_socket.cc
+++ b/nl80211_socket.cc
@@ -181,8 +181,8 @@
     return -1;
   }
 
-  if (!message.AddNetlinkHeader(NL_AUTO_PID, NL_AUTO_SEQ, GetFamilyId(), 0, 0,
-                                CTRL_CMD_GETFAMILY, 0))
+  if (!message.AddNetlinkHeader(this, NL_AUTO_PID, NL_AUTO_SEQ, GetFamilyId(),
+                                0, 0, CTRL_CMD_GETFAMILY, 0))
     return -1;
 
   int result = message.AddAttribute(CTRL_ATTR_FAMILY_NAME,