blob: 14dc4225754c0caddb15f1a807dc43aa03877fe5 [file] [log] [blame]
// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "shill/netlink_manager.h"
#include <netlink/netlink.h>
#include <sys/select.h>
#include <sys/time.h>
#include <map>
#include <base/memory/weak_ptr.h>
#include <base/stl_util.h>
#include "shill/attribute_list.h"
#include "shill/error.h"
#include "shill/event_dispatcher.h"
#include "shill/generic_netlink_message.h"
#include "shill/io_handler.h"
#include "shill/logging.h"
#include "shill/netlink_socket.h"
#include "shill/netlink_message.h"
#include "shill/scope_logger.h"
#include "shill/shill_time.h"
using base::Bind;
using base::LazyInstance;
using std::list;
using std::map;
using std::string;
namespace shill {
namespace {
LazyInstance<NetlinkManager> g_netlink_manager = LAZY_INSTANCE_INITIALIZER;
} // namespace
const char NetlinkManager::kEventTypeConfig[] = "config";
const char NetlinkManager::kEventTypeScan[] = "scan";
const char NetlinkManager::kEventTypeRegulatory[] = "regulatory";
const char NetlinkManager::kEventTypeMlme[] = "mlme";
const long NetlinkManager::kMaximumNewFamilyWaitSeconds = 1;
const long NetlinkManager::kMaximumNewFamilyWaitMicroSeconds = 0;
NetlinkManager::MessageType::MessageType() :
family_id(NetlinkMessage::kIllegalMessageType) {}
NetlinkManager::NetlinkManager()
: dispatcher_(NULL),
weak_ptr_factory_(this),
dispatcher_callback_(Bind(&NetlinkManager::OnRawNlMessageReceived,
weak_ptr_factory_.GetWeakPtr())),
sock_(NULL) {}
NetlinkManager *NetlinkManager::GetInstance() {
return g_netlink_manager.Pointer();
}
void NetlinkManager::Reset(bool full) {
ClearBroadcastHandlers();
message_types_.clear();
if (full) {
dispatcher_ = NULL;
delete sock_;
sock_ = NULL;
}
}
void NetlinkManager::OnNewFamilyMessage(const NetlinkMessage &raw_message) {
uint16_t family_id;
string family_name;
if (raw_message.message_type() == ErrorAckMessage::kMessageType) {
const ErrorAckMessage *error_ack_message =
reinterpret_cast<const ErrorAckMessage *>(&raw_message);
if (error_ack_message->error()) {
LOG(ERROR) << __func__ << ": Message (seq: "
<< raw_message.sequence_number() << ") failed: "
<< error_ack_message->ToString();
} else {
SLOG(WiFi, 6) << __func__ << ": Message (seq: "
<< raw_message.sequence_number() << ") ACKed";
}
return;
}
if (raw_message.message_type() != ControlNetlinkMessage::kMessageType) {
LOG(ERROR) << "Received unexpected message type: "
<< raw_message.message_type();
return;
}
const ControlNetlinkMessage *message =
reinterpret_cast<const ControlNetlinkMessage *>(&raw_message);
if (!message->const_attributes()->GetU16AttributeValue(CTRL_ATTR_FAMILY_ID,
&family_id)) {
LOG(ERROR) << __func__ << ": Couldn't get family_id attribute";
return;
}
if (!message->const_attributes()->GetStringAttributeValue(
CTRL_ATTR_FAMILY_NAME, &family_name)) {
LOG(ERROR) << __func__ << ": Couldn't get family_name attribute";
return;
}
SLOG(WiFi, 3) << "Socket family '" << family_name << "' has id=" << family_id;
// Extract the available multicast groups from the message.
AttributeListConstRefPtr multicast_groups;
if (message->const_attributes()->ConstGetNestedAttributeList(
CTRL_ATTR_MCAST_GROUPS, &multicast_groups)) {
AttributeListConstRefPtr current_group;
for (int i = 1;
multicast_groups->ConstGetNestedAttributeList(i, &current_group);
++i) {
string group_name;
uint32_t group_id;
if (!current_group->GetStringAttributeValue(CTRL_ATTR_MCAST_GRP_NAME,
&group_name)) {
LOG(WARNING) << "Expected CTRL_ATTR_MCAST_GRP_NAME, found none";
continue;
}
if (!current_group->GetU32AttributeValue(CTRL_ATTR_MCAST_GRP_ID,
&group_id)) {
LOG(WARNING) << "Expected CTRL_ATTR_MCAST_GRP_ID, found none";
continue;
}
SLOG(WiFi, 3) << " Adding group '" << group_name << "' = " << group_id;
message_types_[family_name].groups[group_name] = group_id;
}
}
message_types_[family_name].family_id = family_id;
}
bool NetlinkManager::Init() {
// Install message factory for control class of messages, which has
// statically-known message type.
message_factory_.AddFactoryMethod(
ControlNetlinkMessage::kMessageType,
Bind(&ControlNetlinkMessage::CreateMessage));
if (!sock_) {
sock_ = new NetlinkSocket;
if (!sock_) {
LOG(ERROR) << "No memory";
return false;
}
if (!sock_->Init()) {
return false;
}
}
return true;
}
void NetlinkManager::Start(EventDispatcher *dispatcher) {
dispatcher_ = dispatcher;
CHECK(dispatcher_);
// Install ourselves in the shill mainloop so we receive messages on the
// netlink socket.
dispatcher_handler_.reset(dispatcher_->CreateInputHandler(
file_descriptor(),
dispatcher_callback_,
Bind(&NetlinkManager::OnReadError, weak_ptr_factory_.GetWeakPtr())));
}
int NetlinkManager::file_descriptor() const {
return (sock_ ? sock_->file_descriptor() : -1);
}
uint16_t NetlinkManager::GetFamily(string name,
const NetlinkMessageFactory::FactoryMethod &message_factory) {
MessageType &message_type = message_types_[name];
if (message_type.family_id != NetlinkMessage::kIllegalMessageType) {
return message_type.family_id;
}
if (!sock_) {
LOG(FATAL) << "Must call |Init| before this method.";
return false;
}
GetFamilyMessage msg;
if (!msg.attributes()->CreateStringAttribute(CTRL_ATTR_FAMILY_NAME,
"CTRL_ATTR_FAMILY_NAME")) {
LOG(ERROR) << "Couldn't create string attribute";
return false;
}
if (!msg.attributes()->SetStringAttributeValue(CTRL_ATTR_FAMILY_NAME, name)) {
LOG(ERROR) << "Couldn't set string attribute";
return false;
}
SendMessage(&msg, Bind(&NetlinkManager::OnNewFamilyMessage,
weak_ptr_factory_.GetWeakPtr()));
// Wait for a response. The code absolutely needs family_ids for its
// message types so we do a synchronous wait. It's OK to do this because
// a) libnl does a synchronous wait (so there's prior art), b) waiting
// asynchronously would add significant and unnecessary complexity to the
// code that deals with pending messages that could, potentially, be waiting
// for a message type, and c) it really doesn't take very long for the
// GETFAMILY / NEWFAMILY transaction to transpire (this transaction was timed
// over 20 times and found a maximum duration of 11.1 microseconds and an
// average of 4.0 microseconds).
struct timeval start_time, now, end_time;
struct timeval maximum_wait_duration = {kMaximumNewFamilyWaitSeconds,
kMaximumNewFamilyWaitMicroSeconds};
Time *time = Time::GetInstance();
time->GetTimeMonotonic(&start_time);
now = start_time;
timeradd(&start_time, &maximum_wait_duration, &end_time);
do {
// Wait with timeout for a message from the netlink socket.
fd_set read_fds;
FD_ZERO(&read_fds);
FD_SET(file_descriptor(), &read_fds);
struct timeval wait_duration;
timersub(&end_time, &now, &wait_duration);
int result = select(file_descriptor() + 1, &read_fds, NULL, NULL,
&wait_duration);
if (result < 0) {
PLOG(ERROR) << "Select failed";
return NetlinkMessage::kIllegalMessageType;
}
if (result == 0) {
LOG(WARNING) << "Timed out waiting for family_id for family '"
<< name << "'.";
return NetlinkMessage::kIllegalMessageType;
}
// Read and process any messages.
ByteString received;
sock_->RecvMessage(&received);
InputData input_data(received.GetData(), received.GetLength());
OnRawNlMessageReceived(&input_data);
if (message_type.family_id != NetlinkMessage::kIllegalMessageType) {
uint16_t family_id = message_type.family_id;
if (family_id != NetlinkMessage::kIllegalMessageType) {
message_factory_.AddFactoryMethod(family_id, message_factory);
}
time->GetTimeMonotonic(&now);
timersub(&now, &start_time, &wait_duration);
SLOG(WiFi, 5) << "Found id " << message_type.family_id
<< " for name '" << name << "' in "
<< wait_duration.tv_sec << " sec, "
<< wait_duration.tv_usec << " usec.";
return message_type.family_id;
}
time->GetTimeMonotonic(&now);
} while (timercmp(&now, &end_time, <));
LOG(ERROR) << "Timed out waiting for family_id for family '" << name << "'.";
return NetlinkMessage::kIllegalMessageType;
}
uint16_t NetlinkManager::GetMessageType(string name) const {
map<const string, MessageType>::const_iterator family =
message_types_.find(name);
if (family == message_types_.end()) {
LOG(WARNING) << "Family '" << name << "' is not in list.";
return NetlinkMessage::kIllegalMessageType;
}
return family->second.family_id;
}
bool NetlinkManager::AddBroadcastHandler(const NetlinkMessageHandler &handler) {
list<NetlinkMessageHandler>::iterator i;
if (FindBroadcastHandler(handler)) {
LOG(WARNING) << "Trying to re-add a handler";
return false; // Should only be one copy in the list.
}
if (handler.is_null()) {
LOG(WARNING) << "Trying to add a NULL handler";
return false;
}
// And add the handler to the list.
SLOG(WiFi, 3) << "NetlinkManager::" << __func__ << " - adding handler";
broadcast_handlers_.push_back(handler);
return true;
}
bool NetlinkManager::RemoveBroadcastHandler(
const NetlinkMessageHandler &handler) {
list<NetlinkMessageHandler>::iterator i;
for (i = broadcast_handlers_.begin(); i != broadcast_handlers_.end(); ++i) {
if ((*i).Equals(handler)) {
broadcast_handlers_.erase(i);
// Should only be one copy in the list so we don't have to continue
// looking for another one.
return true;
}
}
LOG(WARNING) << "NetlinkMessageHandler not found.";
return false;
}
bool NetlinkManager::FindBroadcastHandler(const NetlinkMessageHandler &handler)
const {
list<NetlinkMessageHandler>::const_iterator i;
for (i = broadcast_handlers_.begin(); i != broadcast_handlers_.end(); ++i) {
if ((*i).Equals(handler)) {
return true;
}
}
return false;
}
void NetlinkManager::ClearBroadcastHandlers() {
broadcast_handlers_.clear();
}
bool NetlinkManager::SendMessage(NetlinkMessage *message,
const NetlinkMessageHandler &handler) {
if (!message) {
LOG(ERROR) << "Message is NULL.";
return false;
}
ByteString message_string = message->Encode(this->GetSequenceNumber());
if (handler.is_null()) {
SLOG(WiFi, 3) << "Handler for message was null.";
} else if (ContainsKey(message_handlers_, message->sequence_number())) {
LOG(ERROR) << "A handler already existed for sequence: "
<< message->sequence_number();
return false;
} else {
message_handlers_[message->sequence_number()] = handler;
}
SLOG(WiFi, 6) << "NL Message " << message->sequence_number()
<< " Sending (" << message_string.GetLength()
<< " bytes) ===>";
message->Print(6);
NetlinkMessage::PrintBytes(6, message_string.GetConstData(),
message_string.GetLength());
if (!sock_->SendMessage(message_string)) {
LOG(ERROR) << "Failed to send Netlink message.";
return false;
}
return true;
}
bool NetlinkManager::RemoveMessageHandler(const NetlinkMessage &message) {
if (!ContainsKey(message_handlers_, message.sequence_number())) {
return false;
}
message_handlers_.erase(message.sequence_number());
return true;
}
uint32_t NetlinkManager::GetSequenceNumber() {
return sock_ ?
sock_->GetSequenceNumber() : NetlinkMessage::kBroadcastSequenceNumber;
}
bool NetlinkManager::SubscribeToEvents(const string &family_id,
const string &group_name) {
if (!ContainsKey(message_types_, family_id)) {
LOG(ERROR) << "Family '" << family_id << "' doesn't exist";
return false;
}
if (!ContainsKey(message_types_[family_id].groups, group_name)) {
LOG(ERROR) << "Group '" << group_name << "' doesn't exist in family '"
<< family_id << "'";
return false;
}
uint32_t group_id = message_types_[family_id].groups[group_name];
if (!sock_) {
LOG(FATAL) << "Need to call |Init| first.";
}
return sock_->SubscribeToEvents(group_id);
}
void NetlinkManager::OnRawNlMessageReceived(InputData *data) {
if (!data) {
LOG(ERROR) << __func__ << "() called with null header.";
return;
}
unsigned char *buf = data->buf;
unsigned char *end = buf + data->len;
while (buf < end) {
nlmsghdr *msg = reinterpret_cast<nlmsghdr *>(buf);
// Discard the message if there're not enough bytes to a) tell the code how
// much space is in the message (i.e., to access nlmsg_len) or b) to hold
// the entire message. The odd calculation is there to keep the code from
// potentially calculating an illegal address (causes a segfault on some
// architectures).
size_t bytes_left = end - buf;
if (((bytes_left < (offsetof(nlmsghdr, nlmsg_len) +
sizeof(msg->nlmsg_len))) ||
(bytes_left < msg->nlmsg_len))) {
LOG(ERROR) << "Discarding incomplete message.";
return;
}
OnNlMessageReceived(msg);
buf += msg->nlmsg_len;
}
}
void NetlinkManager::OnNlMessageReceived(nlmsghdr *msg) {
if (!msg) {
LOG(ERROR) << __func__ << "() called with null header.";
return;
}
const uint32 sequence_number = msg->nlmsg_seq;
scoped_ptr<NetlinkMessage> message(message_factory_.CreateMessage(msg));
if (message == NULL) {
SLOG(WiFi, 3) << "NL Message " << sequence_number << " <===";
SLOG(WiFi, 3) << __func__ << "(msg:NULL)";
return; // Skip current message, continue parsing buffer.
}
SLOG(WiFi, 3) << "NL Message " << sequence_number
<< " Received (" << msg->nlmsg_len << " bytes) <===";
message->Print(6);
NetlinkMessage::PrintBytes(8, reinterpret_cast<const unsigned char *>(msg),
msg->nlmsg_len);
// Call (then erase) any message-specific handler.
if (ContainsKey(message_handlers_, sequence_number)) {
SLOG(WiFi, 3) << "found message-specific handler";
if (message_handlers_[sequence_number].is_null()) {
LOG(ERROR) << "NetlinkMessageHandler exists but is NULL for ID "
<< sequence_number;
} else {
message_handlers_[sequence_number].Run(*message);
}
if (message->message_type() == ErrorAckMessage::kMessageType) {
const ErrorAckMessage *error_ack_message =
reinterpret_cast<const ErrorAckMessage *>(message.get());
if (error_ack_message->error()) {
SLOG(WiFi, 3) << "Removing callback";
message_handlers_.erase(sequence_number);
} else {
SLOG(WiFi, 3) << "ACK message -- not removing callback";
}
} else if ((message->flags() & NLM_F_MULTI) &&
(message->message_type() != NLMSG_DONE)) {
SLOG(WiFi, 3) << "Multi-part message -- not removing callback";
} else {
SLOG(WiFi, 3) << "Removing callback";
message_handlers_.erase(sequence_number);
}
} else {
list<NetlinkMessageHandler>::const_iterator i =
broadcast_handlers_.begin();
while (i != broadcast_handlers_.end()) {
SLOG(WiFi, 3) << __func__ << " - calling broadcast handler";
i->Run(*message);
++i;
}
}
}
void NetlinkManager::OnReadError(const Error &error) {
// TODO(wdg): When netlink_manager is used for scan, et al., this should
// either be LOG(FATAL) or the code should properly deal with errors,
// e.g., dropped messages due to the socket buffer being full.
LOG(ERROR) << "NetlinkManager's netlink Socket read returns error: "
<< error.message();
}
} // namespace shill.