shill: Hide netlink sequence numbers and headers from users

Hide details of the netlink protocol from creators of nl80211 messages
to simplify the process of sending messages to the kernel.  This changes
the process to send a message and set a callback from:

KernelBoundNlMessage message;
// ... populate message with fields
message.AddNetlinkHeader(&socket_, 0, NL_AUTO_SEQ, 0, 0, 0,
                         CTRL_CMD_GETFAMILY, 0));
Config80211::GetInstance()->SendMessage(&message, callback);

to:

KernelBoundNlMessage message(CTRL_CMD_GETFAMILY);
// ... populate message with fields
Config80211::GetInstance()->SendMessage(&message, callback);

BUG=chromium-os:36122
TEST=Unit tests

Change-Id: Id0bc931b6952fc0eafdf73499c1b48d550575fce
Reviewed-on: https://gerrit.chromium.org/gerrit/37724
Reviewed-by: Wade Guthrie <wdg@chromium.org>
Commit-Ready: Christopher Wiley <wiley@chromium.org>
Tested-by: Christopher Wiley <wiley@chromium.org>
diff --git a/config80211.cc b/config80211.cc
index 14217fb..320e9cc 100644
--- a/config80211.cc
+++ b/config80211.cc
@@ -152,43 +152,29 @@
     LOG(ERROR) << "Message is NULL.";
     return false;
   }
-  if (!SetMessageCallback(*message, callback)) {
+  uint32 sequence_number = sock_->Send(message);
+  if (!sequence_number) {
+    LOG(ERROR) << "Failed to send nl80211 message.";
     return false;
   }
-
-  message->Send(sock_);
-  return true;
-}
-
-bool Config80211::SetMessageCallback(const KernelBoundNlMessage &message,
-                                     const Callback &callback) {
-  LOG(INFO) << "Setting callback for message " << message.GetId();
-  uint32_t message_id = message.GetId();
-  if (message_id == 0) {
-    LOG(ERROR) << "Message ID 0 is reserved for broadcast callbacks.";
-    return false;
-  }
-
-  if (ContainsKey(message_callbacks_, message_id)) {
-    LOG(ERROR) << "Already a callback assigned for id " << message_id;
-    return false;
-  }
-
   if (callback.is_null()) {
-    LOG(ERROR) << "Trying to add a NULL callback for id " << message_id;
+    LOG(INFO) << "Callback for message was null.";
+    return true;
+  }
+  if (ContainsKey(message_callbacks_, sequence_number)) {
+    LOG(ERROR) << "Sent message, but already had a callback for this message?";
     return false;
   }
-
-  message_callbacks_[message_id] = callback;
+  message_callbacks_[sequence_number] = callback;
+  LOG(INFO) << "Sent nl80211 message with sequence number: " << sequence_number;
   return true;
 }
 
-bool Config80211::UnsetMessageCallbackById(uint32_t message_id) {
-  if (!ContainsKey(message_callbacks_, message_id)) {
-    LOG(WARNING) << "No callback assigned for id " << message_id;
+bool Config80211::RemoveMessageCallback(const KernelBoundNlMessage &message) {
+  if (!ContainsKey(message_callbacks_, message.sequence_number())) {
     return false;
   }
-  message_callbacks_.erase(message_id);
+  message_callbacks_.erase(message.sequence_number());
   return true;
 }
 
@@ -288,37 +274,38 @@
 }
 
 int Config80211::OnNlMessageReceived(nlmsghdr *msg) {
+  if (!msg) {
+    LOG(ERROR) << __func__ << "() called with null header.";
+    return NL_SKIP;
+  }
+  const uint32 sequence_number = msg->nlmsg_seq;
   SLOG(WiFi, 3) << "\t  Entering " << __func__
-                << "( msg:" << msg->nlmsg_seq << ")";
+                << "( msg:" << sequence_number << ")";
   scoped_ptr<UserBoundNlMessage> message(
       UserBoundNlMessageFactory::CreateMessage(msg));
   if (message == NULL) {
     SLOG(WiFi, 3) << __func__ << "(msg:NULL)";
-  } else {
-    SLOG(WiFi, 3) << __func__ << "(msg:" << msg->nlmsg_seq << ")";
-    // Call (then erase) any message-specific callback.
-    if (ContainsKey(message_callbacks_, message->GetId())) {
-      SLOG(WiFi, 3) << "found message-specific callback";
-      if (message_callbacks_[message->GetId()].is_null()) {
-        LOG(ERROR) << "Callback exists but is NULL for ID " << message->GetId();
-      } else {
-        message_callbacks_[message->GetId()].Run(*message);
-      }
-      UnsetMessageCallbackById(message->GetId());
+    return NL_SKIP;  // Skip current message, continue parsing buffer.
+  }
+  // Call (then erase) any message-specific callback.
+  if (ContainsKey(message_callbacks_, sequence_number)) {
+    SLOG(WiFi, 3) << "found message-specific callback";
+    if (message_callbacks_[sequence_number].is_null()) {
+      LOG(ERROR) << "Callback exists but is NULL for ID " << sequence_number;
     } else {
-      list<Callback>::iterator i = broadcast_callbacks_.begin();
-      while (i != broadcast_callbacks_.end()) {
-        SLOG(WiFi, 3) << "found a broadcast callback";
-        if (i->is_null()) {
-          list<Callback>::iterator temp = i;
-          ++temp;
-          broadcast_callbacks_.erase(i);
-          i = temp;
-        } else {
-          SLOG(WiFi, 3) << "      " << __func__ << " - calling callback";
-          i->Run(*message);
-          ++i;
-        }
+      message_callbacks_[sequence_number].Run(*message);
+    }
+    message_callbacks_.erase(sequence_number);
+  } else {
+    list<Callback>::iterator i = broadcast_callbacks_.begin();
+    while (i != broadcast_callbacks_.end()) {
+      SLOG(WiFi, 3) << "found a broadcast callback";
+      if (i->is_null()) {
+        i = broadcast_callbacks_.erase(i);
+      } else {
+        SLOG(WiFi, 3) << "      " << __func__ << " - calling callback";
+        i->Run(*message);
+        ++i;
       }
     }
   }
diff --git a/config80211.h b/config80211.h
index ac07b05..300230a 100644
--- a/config80211.h
+++ b/config80211.h
@@ -140,17 +140,8 @@
   // to call in case of timeout.
   bool SendMessage(KernelBoundNlMessage *message, const Callback &callback);
 
-  // Install a Config80211 Callback to handle the response to a specific
-  // message.
-  // TODO(wdg): Eventually, this should also include a timeout and a callback
-  // to call in case of timeout.
-  // TODO(wdg): maybe this should be private and only accessed by tests.
-  bool SetMessageCallback(const KernelBoundNlMessage &message,
-                          const Callback &callback);
-
-  // Uninstall a Config80211 Callback for a specific message using the
-  // message's sequence number.
-  bool UnsetMessageCallbackById(uint32_t sequence_number);
+  // Uninstall a Config80211 Callback for a specific message.
+  bool RemoveMessageCallback(const KernelBoundNlMessage &message);
 
   // Return a string corresponding to the passed-in EventType.
   static bool GetEventTypeString(EventType type, std::string *value);
diff --git a/config80211_unittest.cc b/config80211_unittest.cc
index c128793..5d4bfa4 100644
--- a/config80211_unittest.cc
+++ b/config80211_unittest.cc
@@ -343,8 +343,17 @@
 };
 
 }  // namespace
-
-unsigned int MockNl80211Socket::number_ = 0;
+uint32 MockNl80211Socket::Send(KernelBoundNlMessage *message) {
+  // Don't need a real family id; this is never sent.
+  const uint32 family_id = 0;
+  uint32 sequence_number = ++sequence_number_;
+  if (genlmsg_put(message->message(), NL_AUTO_PID, sequence_number, family_id,
+                  0, 0, message->command(), 0) == NULL) {
+    LOG(ERROR) << "genlmsg_put returned a NULL pointer.";
+    return 0;
+  }
+  return sequence_number;
+}
 
 class Config80211Test : public Test {
  public:
@@ -396,8 +405,6 @@
   // (shouldn't actually send the subscription request until wifi comes up).
   EXPECT_CALL(socket_, AddGroupMembership(_)).Times(0);
   EXPECT_CALL(socket_, SetNetlinkCallback(_, _)).Times(0);
-  EXPECT_CALL(socket_, GetSequenceNumber())
-      .WillRepeatedly(Invoke(&MockNl80211Socket::GetNextNumber));
 
   EXPECT_TRUE(config80211_->AddBroadcastCallback(
       callback_object.GetCallback()));
@@ -449,8 +456,6 @@
 
   MockCallback80211 callback1(config80211_);
   MockCallback80211 callback2(config80211_);
-  EXPECT_CALL(socket_, GetSequenceNumber())
-      .WillRepeatedly(Invoke(&MockNl80211Socket::GetNextNumber));
 
   // Simple, 1 callback, case.
   EXPECT_CALL(callback1, Config80211MessageCallback(_)).Times(1);
@@ -492,28 +497,16 @@
   // Setup.
   SetupConfig80211Object();
 
-  EXPECT_CALL(socket_, GetSequenceNumber())
-      .WillRepeatedly(Invoke(&MockNl80211Socket::GetNextNumber));
-
   MockCallback80211 callback_broadcast(config80211_);
   EXPECT_TRUE(callback_broadcast.InstallAsBroadcastCallback());
 
-  KernelBoundNlMessage sent_message_1;
+  KernelBoundNlMessage sent_message_1(CTRL_CMD_GETFAMILY);
   MockCallback80211 callback_sent_1(config80211_);
   EXPECT_TRUE(sent_message_1.Init());
-  EXPECT_TRUE(sent_message_1.AddNetlinkHeader(&socket_, 0, NL_AUTO_SEQ, 0, 0, 0,
-                                              CTRL_CMD_GETFAMILY, 0));
-  LOG(INFO) << "Message 1 id:" << sent_message_1.GetId();
 
-  KernelBoundNlMessage sent_message_2;
+  KernelBoundNlMessage sent_message_2(CTRL_CMD_GETFAMILY);
   MockCallback80211 callback_sent_2(config80211_);
   EXPECT_TRUE(sent_message_2.Init());
-  EXPECT_TRUE(sent_message_2.AddNetlinkHeader(&socket_, 0, NL_AUTO_SEQ, 0, 0, 0,
-                                              CTRL_CMD_GETFAMILY, 0));
-  LOG(INFO) << "Message 2 id:" << sent_message_2.GetId();
-
-  // This is more testing the test code than the code, itself.
-  EXPECT_NE(sent_message_1.GetId(), sent_message_2.GetId());
 
   // Set up the received message as a response to sent_message_1.
   scoped_array<unsigned char> message_memory(
@@ -522,7 +515,6 @@
          sizeof(kNL80211_CMD_DISCONNECT));
   nlmsghdr *received_message =
         reinterpret_cast<nlmsghdr *>(message_memory.get());
-  received_message->nlmsg_seq = sent_message_1.GetId();
 
   // Now, we can start the actual test...
 
@@ -531,9 +523,11 @@
   EXPECT_CALL(callback_broadcast, Config80211MessageCallback(_)).Times(1);
   config80211_->OnNlMessageReceived(received_message);
 
-  // Install message-based callback; verify that message callback gets called.
-  EXPECT_TRUE(config80211_->SetMessageCallback(sent_message_1,
-                                               callback_sent_1.GetCallback()));
+  // Send the message and give our callback.  Verify that we get called back.
+  EXPECT_TRUE(config80211_->SendMessage(&sent_message_1,
+                                        callback_sent_1.GetCallback()));
+  // Make it appear that this message is in response to our sent message.
+  received_message->nlmsg_seq = socket_.GetLastSequenceNumber();
   EXPECT_CALL(callback_sent_1, Config80211MessageCallback(_)).Times(1);
   config80211_->OnNlMessageReceived(received_message);
 
@@ -544,22 +538,23 @@
 
   // Install and then uninstall message-specific callback; verify broadcast
   // callback is called on message receipt.
-  EXPECT_TRUE(config80211_->SetMessageCallback(sent_message_1,
-                                               callback_sent_1.GetCallback()));
-  EXPECT_TRUE(config80211_->UnsetMessageCallbackById(sent_message_1.GetId()));
+  EXPECT_TRUE(config80211_->SendMessage(&sent_message_1,
+                                        callback_sent_1.GetCallback()));
+  received_message->nlmsg_seq = socket_.GetLastSequenceNumber();
+  EXPECT_TRUE(config80211_->RemoveMessageCallback(sent_message_1));
   EXPECT_CALL(callback_broadcast, Config80211MessageCallback(_)).Times(1);
   config80211_->OnNlMessageReceived(received_message);
 
   // Install callback for different message; verify that broadcast callback is
   // called for _this_ message.
-  EXPECT_TRUE(config80211_->SetMessageCallback(sent_message_2,
-                                               callback_sent_2.GetCallback()));
+  EXPECT_TRUE(config80211_->SendMessage(&sent_message_2,
+                                        callback_sent_2.GetCallback()));
   EXPECT_CALL(callback_broadcast, Config80211MessageCallback(_)).Times(1);
   config80211_->OnNlMessageReceived(received_message);
 
   // Change the ID for the message to that of the second callback; verify that
   // the appropriate callback is called for _that_ message.
-  received_message->nlmsg_seq = sent_message_2.GetId();
+  received_message->nlmsg_seq = socket_.GetLastSequenceNumber();
   EXPECT_CALL(callback_sent_2, Config80211MessageCallback(_)).Times(1);
   config80211_->OnNlMessageReceived(received_message);
 }
diff --git a/kernel_bound_nlmessage.cc b/kernel_bound_nlmessage.cc
index 4e97abc..e1b8e7c 100644
--- a/kernel_bound_nlmessage.cc
+++ b/kernel_bound_nlmessage.cc
@@ -4,21 +4,12 @@
 
 #include "shill/kernel_bound_nlmessage.h"
 
-#include <net/if.h>
-#include <netlink/genl/genl.h>
 #include <netlink/msg.h>
-#include <netlink/netlink.h>
-
-#include <base/logging.h>
 
 #include "shill/logging.h"
-#include "shill/netlink_socket.h"
-#include "shill/scope_logger.h"
 
 namespace shill {
 
-const uint32_t KernelBoundNlMessage::kIllegalMessage = 0xFFFFFFFF;
-
 KernelBoundNlMessage::~KernelBoundNlMessage() {
   if (message_) {
     nlmsg_free(message_);
@@ -37,55 +28,6 @@
   return true;
 }
 
-uint32_t KernelBoundNlMessage::GetId() const {
-  if (!message_) {
-    LOG(ERROR) << "NULL |message_|";
-    return kIllegalMessage;
-  }
-  struct nlmsghdr *header = nlmsg_hdr(message_);
-  if (!header) {
-    LOG(ERROR) << "Couldn't make header";
-    return kIllegalMessage;
-  }
-  return header->nlmsg_seq;
-}
-
-bool KernelBoundNlMessage::AddNetlinkHeader(NetlinkSocket *socket,
-                                            uint32_t port, uint32_t seq,
-                                            int family_id, int hdrlen,
-                                            int flags, uint8_t cmd,
-                                            uint8_t version) {
-  if (!message_) {
-    LOG(ERROR) << "NULL |message_|";
-    return false;
-  }
-
-  // Parameters to genlmsg_put:
-  //  @message: struct nl_msg *message_.
-  //  @pid: netlink pid the message is addressed to.
-  //  @seq: sequence number (usually the one of the sender).
-  //  @family: generic netlink family.
-  //  @flags netlink message flags.
-  //  @cmd: netlink command.
-  //  @version: version of communication protocol.
-  // genlmsg_put returns a void * pointing to the header but we don't want to
-  // encourage its use outside of this object.
-
-  if (genlmsg_put(message_, port, seq, family_id, hdrlen, flags, cmd, version)
-      == NULL) {
-    LOG(ERROR) << "genlmsg_put returned a NULL pointer.";
-    return false;
-  }
-
-  // Manually set the sequence number if it's zero.
-  struct nlmsghdr *header = nlmsg_hdr(message_);
-  if (header != 0 && seq == NL_AUTO_SEQ && header->nlmsg_seq == 0) {
-    header->nlmsg_seq = socket->GetSequenceNumber();
-  }
-
-  return true;
-}
-
 int KernelBoundNlMessage::AddAttribute(int attrtype, int attrlen,
                                        const void *data) {
   if (!data) {
@@ -99,15 +41,11 @@
   return nla_put(message_, attrtype, attrlen, data);
 }
 
-bool KernelBoundNlMessage::Send(NetlinkSocket *socket) {
-  if (!socket) {
-    LOG(ERROR) << "NULL |socket| parameter";
-    return false;
+uint32 KernelBoundNlMessage::sequence_number() const {
+  if (message_ && nlmsg_hdr(message_)) {
+    return nlmsg_hdr(message_)->nlmsg_seq;
   }
-
-  SLOG(WiFi, 6) << "NL Message " << GetId() << " ===>";
-
-  return socket->Send(message_);
+  return 0;
 }
 
 }  // namespace shill.
diff --git a/kernel_bound_nlmessage.h b/kernel_bound_nlmessage.h
index b51d465..e4595c4 100644
--- a/kernel_bound_nlmessage.h
+++ b/kernel_bound_nlmessage.h
@@ -26,12 +26,10 @@
 #define SHILL_KERNEL_BOUND_NLMESSAGE_H_
 
 #include <base/basictypes.h>
-#include <base/bind.h>
 
 struct nl_msg;
 
 namespace shill {
-struct NetlinkSocket;
 
 // TODO(wdg): eventually, KernelBoundNlMessage and UserBoundNlMessage should
 // be combined into a monolithic NlMessage.
@@ -39,29 +37,29 @@
 // Provides a wrapper around a netlink message destined for kernel-space.
 class KernelBoundNlMessage {
  public:
-  KernelBoundNlMessage() : message_(NULL) {}
+  // |command| is a type of command understood by the kernel, for instance:
+  // CTRL_CMD_GETFAMILY.
+  explicit KernelBoundNlMessage(uint8 command)
+      : command_(command),
+        message_(NULL) {};
   virtual ~KernelBoundNlMessage();
 
   // Non-trivial initialization.
   bool Init();
 
-  // Message ID is equivalent to the message's sequence number.
-  uint32_t GetId() const;
-
-  // Add a netlink header to the message.
-  bool AddNetlinkHeader(NetlinkSocket *socket, uint32_t port, uint32_t seq,
-                        int family_id, int hdrlen, int flags, uint8_t cmd,
-                        uint8_t version);
-
   // Add a netlink attribute to the message.
   int AddAttribute(int attrtype, int attrlen, const void *data);
 
-  // Sends |this| over the netlink socket.
-  virtual bool Send(NetlinkSocket *socket);
+  uint8 command() const { return command_; }
+  // TODO(wiley) It would be better if messages were bags of attributes which
+  //             the socket collapses into binary blobs at send time.
+  struct nl_msg *message() const { return message_; }
+  // Returns 0 when unsent, > 0 otherwise.
+  uint32 sequence_number() const;
 
  private:
-  static const uint32_t kIllegalMessage;
-
+  uint8 command_;
+  // TODO(wiley) Rename to |raw_message_| (message.message() looks silly).
   struct nl_msg *message_;
 
   DISALLOW_COPY_AND_ASSIGN(KernelBoundNlMessage);
diff --git a/mock_nl80211_socket.h b/mock_nl80211_socket.h
index 20a9d70..51081d0 100644
--- a/mock_nl80211_socket.h
+++ b/mock_nl80211_socket.h
@@ -19,7 +19,7 @@
 
 class MockNl80211Socket : public Nl80211Socket {
  public:
-  MockNl80211Socket() {}
+  MockNl80211Socket() : sequence_number_(1) {}
   MOCK_METHOD0(Init, bool());
   MOCK_METHOD1(AddGroupMembership, bool(const std::string &group_name));
   using Nl80211Socket::GetMessages;
@@ -29,14 +29,12 @@
                                         void *callback_parameter));
   MOCK_METHOD0(GetSequenceNumber, unsigned int());
 
-  static unsigned int GetNextNumber() {
-    if (++number_ == 0)
-      number_ = 1;
-    return number_;
-  }
+  virtual uint32 Send(KernelBoundNlMessage *message);
+
+  uint32 GetLastSequenceNumber() const { return sequence_number_; }
 
  private:
-  static unsigned int number_;
+  uint32 sequence_number_;
 
   DISALLOW_COPY_AND_ASSIGN(MockNl80211Socket);
 };
diff --git a/netlink_socket.cc b/netlink_socket.cc
index 6b08920..87b200e 100644
--- a/netlink_socket.cc
+++ b/netlink_socket.cc
@@ -38,9 +38,7 @@
 
 #include <iomanip>
 
-#include <base/logging.h>
-
-#include "shill/scope_logger.h"
+#include "shill/logging.h"
 
 namespace shill {
 
@@ -113,23 +111,46 @@
   return true;
 }
 
-bool NetlinkSocket::Send(struct nl_msg *message) {
+uint32 NetlinkSocket::Send(struct nl_msg *message,
+                           uint8 command,
+                           int32 family_id) {
   if (!message) {
     LOG(ERROR) << "NULL |message|.";
-    return false;
+    return 0;
   }
 
   if (!nl_sock_) {
     LOG(ERROR) << "Need to initialize the socket first.";
-    return false;
+    return 0;
   }
 
+  // Parameters to genlmsg_put:
+  //  @message: a pointer to a struct nl_msg *message.
+  //  @pid: netlink pid the message is addressed to.
+  //  @seq: sequence number.
+  //  @family: netlink socket family (NETLINK_GENERIC for us)
+  //  @flags netlink message flags.
+  //  @hdrlen: Length of a user header (which we don't use)
+  //  @cmd: netlink command.
+  //  @version: version of communication protocol.
+  // genlmsg_put returns a void * pointing to the user header but we don't
+  // want to encourage its use outside of this object.
+
+  uint32 sequence_number = GetSequenceNumber();
+  if (genlmsg_put(message, NL_AUTO_PID, sequence_number, family_id,
+                  0, 0, command, 0) == NULL) {
+    LOG(ERROR) << "genlmsg_put returned a NULL pointer.";
+    return 0;
+  }
+
+  SLOG(WiFi, 6) << "NL Message " << sequence_number << " ===>";
+
   int result = nl_send_auto_complete(nl_sock_, message);
   if (result < 0) {
     LOG(ERROR) << "Failed call to 'nl_send_auto_complete': " << result;
-    return false;
+    return 0;
   }
-  return true;
+  return sequence_number;
 }
 
 
diff --git a/netlink_socket.h b/netlink_socket.h
index 7abc271..aad3d65 100644
--- a/netlink_socket.h
+++ b/netlink_socket.h
@@ -84,9 +84,6 @@
   // Non-trivial initialization.
   bool Init();
 
-  // Send a message.
-  virtual bool Send(struct nl_msg *message);
-
   // Disables sequence checking on the message stream.
   virtual bool DisableSequenceChecking();
 
@@ -117,6 +114,9 @@
   virtual std::string GetSocketFamilyName() const = 0;
 
  protected:
+  // Send a message, returns 0 on error, or the sequence number (> 0).
+  uint32 Send(struct nl_msg *message, uint8 command, int32 family_id);
+
   struct nl_sock *GetNlSock() { return nl_sock_; }
 
  private:
diff --git a/nl80211_socket.cc b/nl80211_socket.cc
index 08d9d59..d2c7fb0 100644
--- a/nl80211_socket.cc
+++ b/nl80211_socket.cc
@@ -39,12 +39,9 @@
 #include <sstream>
 #include <string>
 
-#include <base/logging.h>
-
 #include "shill/kernel_bound_nlmessage.h"
 #include "shill/logging.h"
 #include "shill/netlink_socket.h"
-#include "shill/scope_logger.h"
 #include "shill/user_bound_nlmessage.h"
 
 using std::string;
@@ -86,4 +83,11 @@
   return true;
 }
 
+uint32 Nl80211Socket::Send(KernelBoundNlMessage *message) {
+  CHECK(message);
+  return NetlinkSocket::Send(message->message(),
+                             message->command(),
+                             GetFamilyId());
+}
+
 }  // namespace shill
diff --git a/nl80211_socket.h b/nl80211_socket.h
index cc8823b..f7d4408 100644
--- a/nl80211_socket.h
+++ b/nl80211_socket.h
@@ -34,6 +34,7 @@
 #include <base/bind.h>
 
 #include "shill/netlink_socket.h"
+#include "shill/kernel_bound_nlmessage.h"
 
 struct nl_msg;
 struct sockaddr_nl;
@@ -65,6 +66,8 @@
     return Nl80211Socket::kSocketFamilyName;
   }
 
+  virtual uint32 Send(KernelBoundNlMessage *message);
+
  private:
   // The family name of this particular netlink socket.
   static const char kSocketFamilyName[];