blob: 0b497e642a25cdf00b063dbba7d8ea14a3004b53 [file] [log] [blame]
// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "shill/arp_client.h"
#include <linux/if_packet.h>
#include <net/ethernet.h>
#include <net/if_arp.h>
#include <netinet/in.h>
#include <string.h>
#include "shill/arp_packet.h"
#include "shill/logging.h"
#include "shill/net/byte_string.h"
#include "shill/net/sockets.h"
namespace shill {
// Byte offset of the opcode field inside struct arphdr.  The opcode is the
// trailing 16-bit field of the header, so it sits one uint16_t before the end.
const size_t ArpClient::kArpOpOffset =
    sizeof(arphdr) - sizeof(uint16_t);
// Size cap for received ARP packets: the fixed header, two Ethernet hardware
// addresses, and two protocol addresses sized for IPv6, which is the largest
// packet this client expects to see.
const size_t ArpClient::kMaxArpPacketLength =
    sizeof(arphdr) + 2 * ETH_ALEN + 2 * sizeof(in6_addr);
// Creates an ARP client for the network interface |interface_index|.  The
// constructor opens nothing: socket_ stays -1 until CreateSocket() succeeds.
ArpClient::ArpClient(int interface_index)
    : interface_index_(interface_index),
      sockets_(new Sockets()),
      socket_(-1) {
}
ArpClient::~ArpClient() {
  // No explicit teardown; members (including socket_closer_, if set) are
  // destroyed automatically.
}
// Starts listening for ARP replies (ARPOP_REPLY) on the interface.
// Returns false if the socket could not be set up.
bool ArpClient::StartReplyListener() {
  const uint16_t opcode = ARPOP_REPLY;
  return Start(opcode);
}
// Starts listening for ARP requests (ARPOP_REQUEST) on the interface.
// Returns false if the socket could not be set up.
bool ArpClient::StartRequestListener() {
  const uint16_t opcode = ARPOP_REQUEST;
  return Start(opcode);
}
// Opens and configures a socket that receives ARP packets carrying
// |arp_opcode|.  On failure, logs, tears down any partial state via Stop(),
// and returns false.
bool ArpClient::Start(uint16_t arp_opcode) {
  if (CreateSocket(arp_opcode)) {
    return true;
  }
  LOG(ERROR) << "Could not open ARP socket.";
  Stop();
  return false;
}
void ArpClient::Stop() {
socket_closer_.reset();
}
// Opens a PF_PACKET/SOCK_DGRAM socket for ARP traffic.  Attaches a BPF
// filter that passes only packets whose ARP opcode equals |arp_opcode|,
// makes the socket non-blocking, and binds it to interface_index_.
// On success, socket_ holds the descriptor and socket_closer_ guards it.
// On failure after the socket was created, the descriptor remains owned by
// socket_closer_ (Start() then calls Stop()), and false is returned.
bool ArpClient::CreateSocket(uint16_t arp_opcode) {
  const int fd = sockets_->Socket(PF_PACKET, SOCK_DGRAM, htons(ETHERTYPE_ARP));
  if (fd == -1) {
    PLOG(ERROR) << "Could not create ARP socket";
    return false;
  }
  socket_ = fd;
  socket_closer_.reset(new ScopedSocketCloser(sockets_.get(), socket_));

  // Classic BPF program.  The socket is SOCK_DGRAM, so packet data begins at
  // the ARP header and kArpOpOffset addresses the opcode directly.
  const sock_filter arp_filter[] = {
    // Load the 16-bit ARP opcode.
    BPF_STMT(BPF_LD | BPF_H | BPF_ABS, kArpOpOffset),
    // On a match fall through to "accept"; otherwise skip one to "drop".
    BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, arp_opcode, 0, 1),
    // Accept, keeping at most kMaxArpPacketLength bytes.
    BPF_STMT(BPF_RET | BPF_K, kMaxArpPacketLength),
    // Drop.
    BPF_STMT(BPF_RET | BPF_K, 0),
  };
  sock_fprog program;
  program.len = arraysize(arp_filter);
  // sock_fprog::filter is declared non-const, hence the cast.
  program.filter = const_cast<sock_filter *>(arp_filter);
  if (sockets_->AttachFilter(socket_, &program) != 0) {
    PLOG(ERROR) << "Could not attach packet filter";
    return false;
  }

  if (sockets_->SetNonBlocking(socket_) != 0) {
    PLOG(ERROR) << "Could not set socket to be non-blocking";
    return false;
  }

  // Bind so that only ARP frames from interface_index_ are delivered.
  sockaddr_ll bind_address;
  memset(&bind_address, 0, sizeof(bind_address));
  bind_address.sll_ifindex = interface_index_;
  bind_address.sll_family = AF_PACKET;
  bind_address.sll_protocol = htons(ETHERTYPE_ARP);
  if (sockets_->Bind(socket_,
                     reinterpret_cast<struct sockaddr *>(&bind_address),
                     sizeof(bind_address)) != 0) {
    PLOG(ERROR) << "Could not bind socket to interface";
    return false;
  }
  return true;
}
// Reads one datagram (up to kMaxArpPacketLength bytes) from the ARP socket
// and parses it into |packet|.  Stores the sender's hardware address, taken
// from the sockaddr_ll filled in by RecvFrom(), in |sender|.
// Returns false if the read fails or the payload does not parse.  CHECK-fails
// if the returned address is too short or its hardware address length is
// not ETH_ALEN.
bool ArpClient::ReceivePacket(ArpPacket *packet, ByteString *sender) const {
  sockaddr_ll socket_address;
  memset(&socket_address, 0, sizeof(socket_address));
  socklen_t socklen = sizeof(socket_address);
  ByteString datagram(kMaxArpPacketLength);

  const int bytes_read = sockets_->RecvFrom(
      socket_,
      datagram.GetData(),
      datagram.GetLength(),
      0,
      reinterpret_cast<struct sockaddr *>(&socket_address),
      &socklen);
  if (bytes_read < 0) {
    PLOG(ERROR) << "Socket recvfrom failed";
    return false;
  }

  // Shrink the buffer to what was actually received before parsing it.
  datagram.Resize(bytes_read);
  if (!packet->Parse(datagram)) {
    LOG(ERROR) << "Failed to parse ARP packet.";
    return false;
  }

  // The returned address need only be long enough to hold the sender's
  // hardware address, so validate the length before reading sll_addr.
  CHECK(socklen >=
        sizeof(socket_address) - sizeof(socket_address.sll_addr) + ETH_ALEN);
  CHECK(socket_address.sll_halen == ETH_ALEN);
  const unsigned char *hardware_address =
      reinterpret_cast<const unsigned char *>(&socket_address.sll_addr);
  *sender = ByteString(hardware_address, socket_address.sll_halen);
  return true;
}
// Formats |packet| as an ARP request and sends it on interface_index_.
// The link-layer destination is packet.remote_mac_address(), or the
// all-ones (broadcast) address when that address is all zeros.
// Returns true only if SendTo() reports the full payload length as sent.
// Otherwise logs the failure or short/long send and returns false.
bool ArpClient::TransmitRequest(const ArpPacket &packet) const {
  ByteString payload;
  if (!packet.FormatRequest(&payload)) {
    return false;
  }

  sockaddr_ll socket_address;
  memset(&socket_address, 0, sizeof(socket_address));
  socket_address.sll_family = AF_PACKET;
  socket_address.sll_protocol = htons(ETHERTYPE_ARP);
  socket_address.sll_hatype = ARPHRD_ETHER;
  socket_address.sll_halen = ETH_ALEN;
  socket_address.sll_ifindex = interface_index_;

  ByteString remote_address = packet.remote_mac_address();
  CHECK(sizeof(socket_address.sll_addr) >= remote_address.GetLength());
  if (remote_address.IsZero()) {
    // If the destination MAC address is unspecified, send the packet
    // to the broadcast (all-ones) address.
    remote_address.BitwiseInvert();
  }
  memcpy(&socket_address.sll_addr, remote_address.GetConstData(),
         remote_address.GetLength());

  // Computed before the send so nothing runs between SendTo() and PLOG
  // (which reports errno).
  const int expected_result = static_cast<int>(payload.GetLength());
  const int result = sockets_->SendTo(
      socket_,
      payload.GetConstData(),
      payload.GetLength(),
      0,
      reinterpret_cast<struct sockaddr *>(&socket_address),
      sizeof(socket_address));
  if (result == expected_result) {
    return true;
  }
  if (result < 0) {
    PLOG(ERROR) << "Socket sendto failed";
  } else {
    // Any non-negative mismatch (short or otherwise) is reported; previously
    // a result larger than expected returned false without logging.
    LOG(ERROR) << "Socket sendto returned "
               << result
               << " which is different from expected result "
               << expected_result;
  }
  return false;
}
} // namespace shill