Paul Stewart | ac1328e | 2012-07-20 11:55:40 -0700 | [diff] [blame] | 1 | // Copyright (c) 2012 The Chromium OS Authors. All rights reserved. |
| 2 | // Use of this source code is governed by a BSD-style license that can be |
| 3 | // found in the LICENSE file. |
| 4 | |
| 5 | #include "shill/arp_client.h" |
| 6 | |
| 7 | #include <linux/if_packet.h> |
| 8 | #include <net/ethernet.h> |
| 9 | #include <net/if_arp.h> |
| 10 | #include <netinet/in.h> |
| 11 | #include <string.h> |
| 12 | |
Paul Stewart | ac1328e | 2012-07-20 11:55:40 -0700 | [diff] [blame] | 13 | #include "shill/arp_packet.h" |
| 14 | #include "shill/byte_string.h" |
Christopher Wiley | b691efd | 2012-08-09 13:51:51 -0700 | [diff] [blame] | 15 | #include "shill/logging.h" |
Paul Stewart | ac1328e | 2012-07-20 11:55:40 -0700 | [diff] [blame] | 16 | #include "shill/sockets.h" |
| 17 | |
| 18 | namespace shill { |
| 19 | |
| 20 | // ARP opcode is the last uint16 in the ARP header. |
| 21 | const size_t ArpClient::kArpOpOffset = sizeof(arphdr) - sizeof(uint16); |
| 22 | |
| 23 | // The largest packet we expect is one with IPv6 addresses in it. |
| 24 | const size_t ArpClient::kMaxArpPacketLength = |
| 25 | sizeof(arphdr) + sizeof(in6_addr) * 2 + ETH_ALEN * 2; |
| 26 | |
| 27 | ArpClient::ArpClient(int interface_index) |
| 28 | : interface_index_(interface_index), |
| 29 | sockets_(new Sockets()), |
| 30 | socket_(-1) {} |
| 31 | |
| 32 | ArpClient::~ArpClient() {} |
| 33 | |
| 34 | bool ArpClient::Start() { |
| 35 | if (!CreateSocket()) { |
| 36 | LOG(ERROR) << "Could not open ARP socket."; |
| 37 | Stop(); |
| 38 | return false; |
| 39 | } |
| 40 | return true; |
| 41 | } |
| 42 | |
| 43 | void ArpClient::Stop() { |
| 44 | socket_closer_.reset(); |
| 45 | } |
| 46 | |
| 47 | |
| 48 | bool ArpClient::CreateSocket() { |
| 49 | int socket = sockets_->Socket(PF_PACKET, SOCK_DGRAM, htons(ETHERTYPE_ARP)); |
| 50 | if (socket == -1) { |
| 51 | PLOG(ERROR) << "Could not create ARP socket"; |
| 52 | return false; |
| 53 | } |
| 54 | socket_ = socket; |
| 55 | socket_closer_.reset(new ScopedSocketCloser(sockets_.get(), socket_)); |
| 56 | |
| 57 | // Create a packet filter incoming ARP replies. |
| 58 | static const sock_filter arp_reply_filter[] = { |
| 59 | // If we a packet contains ARPOP_REPLY as the ARP opcode... |
| 60 | BPF_STMT(BPF_LD | BPF_H | BPF_ABS, kArpOpOffset), |
| 61 | BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, ARPOP_REPLY, 0, 1), |
| 62 | // Return the the packet (up to largest expected packet size). |
| 63 | BPF_STMT(BPF_RET | BPF_K, kMaxArpPacketLength), |
| 64 | // Otherwise, drop it. |
| 65 | BPF_STMT(BPF_RET | BPF_K, 0), |
| 66 | }; |
| 67 | |
| 68 | sock_fprog pf; |
| 69 | pf.filter = const_cast<sock_filter *>(arp_reply_filter); |
| 70 | pf.len = arraysize(arp_reply_filter); |
| 71 | if (sockets_->AttachFilter(socket_, &pf) != 0) { |
| 72 | PLOG(ERROR) << "Could not attach packet filter"; |
| 73 | return false; |
| 74 | } |
| 75 | |
| 76 | if (sockets_->SetNonBlocking(socket_) != 0) { |
| 77 | PLOG(ERROR) << "Could not set socket to be non-blocking"; |
| 78 | return false; |
| 79 | } |
| 80 | |
| 81 | sockaddr_ll socket_address; |
| 82 | memset(&socket_address, 0, sizeof(socket_address)); |
| 83 | socket_address.sll_family = AF_PACKET; |
| 84 | socket_address.sll_protocol = htons(ETHERTYPE_ARP); |
| 85 | socket_address.sll_ifindex = interface_index_; |
| 86 | |
| 87 | if (sockets_->Bind(socket_, |
| 88 | reinterpret_cast<struct sockaddr *>(&socket_address), |
| 89 | sizeof(socket_address)) != 0) { |
| 90 | PLOG(ERROR) << "Could not bind socket to interface"; |
| 91 | return false; |
| 92 | } |
| 93 | |
| 94 | return true; |
| 95 | } |
| 96 | |
| 97 | bool ArpClient::ReceiveReply(ArpPacket *packet, ByteString *sender) const { |
| 98 | ByteString payload(kMaxArpPacketLength); |
| 99 | sockaddr_ll socket_address; |
| 100 | memset(&socket_address, 0, sizeof(socket_address)); |
| 101 | socklen_t socklen = sizeof(socket_address); |
| 102 | int result = sockets_->RecvFrom( |
| 103 | socket_, |
| 104 | payload.GetData(), |
| 105 | payload.GetLength(), |
| 106 | 0, |
| 107 | reinterpret_cast<struct sockaddr *>(&socket_address), |
| 108 | &socklen); |
| 109 | if (result < 0) { |
| 110 | PLOG(ERROR) << "Socket recvfrom failed"; |
| 111 | return false; |
| 112 | } |
| 113 | |
| 114 | payload.Resize(result); |
| 115 | if (!packet->ParseReply(payload)) { |
| 116 | LOG(ERROR) << "Failed to parse ARP reply."; |
| 117 | return false; |
| 118 | } |
| 119 | |
| 120 | // The socket address returned may only be big enough to contain |
| 121 | // the hardware address of the sender. |
| 122 | CHECK(socklen >= |
| 123 | sizeof(socket_address) - sizeof(socket_address.sll_addr) + ETH_ALEN); |
| 124 | CHECK(socket_address.sll_halen == ETH_ALEN); |
| 125 | *sender = ByteString( |
| 126 | reinterpret_cast<const unsigned char *>(&socket_address.sll_addr), |
| 127 | socket_address.sll_halen); |
| 128 | return true; |
| 129 | } |
| 130 | |
| 131 | bool ArpClient::TransmitRequest(const ArpPacket &packet) const { |
| 132 | ByteString payload; |
| 133 | if (!packet.FormatRequest(&payload)) { |
| 134 | return false; |
| 135 | } |
| 136 | |
| 137 | sockaddr_ll socket_address; |
| 138 | memset(&socket_address, 0, sizeof(socket_address)); |
| 139 | socket_address.sll_family = AF_PACKET; |
| 140 | socket_address.sll_protocol = htons(ETHERTYPE_ARP); |
| 141 | socket_address.sll_hatype = ARPHRD_ETHER; |
| 142 | socket_address.sll_halen = ETH_ALEN; |
| 143 | socket_address.sll_ifindex = interface_index_; |
| 144 | |
| 145 | ByteString remote_address = packet.remote_mac_address(); |
| 146 | CHECK(sizeof(socket_address.sll_addr) >= remote_address.GetLength()); |
| 147 | if (remote_address.IsZero()) { |
| 148 | // If the destination MAC address is unspecified, send the packet |
| 149 | // to the broadcast (all-ones) address. |
| 150 | remote_address.BitwiseInvert(); |
| 151 | } |
| 152 | memcpy(&socket_address.sll_addr, remote_address.GetConstData(), |
| 153 | remote_address.GetLength()); |
| 154 | |
| 155 | int result = sockets_->SendTo( |
| 156 | socket_, |
| 157 | payload.GetConstData(), |
| 158 | payload.GetLength(), |
| 159 | 0, |
| 160 | reinterpret_cast<struct sockaddr *>(&socket_address), |
| 161 | sizeof(socket_address)); |
| 162 | const int expected_result = static_cast<int>(payload.GetLength()); |
| 163 | if (result != expected_result) { |
| 164 | if (result < 0) { |
| 165 | PLOG(ERROR) << "Socket sendto failed"; |
| 166 | } else if (result < static_cast<int>(payload.GetLength())) { |
| 167 | LOG(ERROR) << "Socket sendto returned " |
| 168 | << result |
| 169 | << " which is different from expected result " |
| 170 | << expected_result; |
| 171 | } |
| 172 | return false; |
| 173 | } |
| 174 | |
| 175 | return true; |
| 176 | } |
| 177 | |
| 178 | } // namespace shill |