blob: 6ce09fd251c56fedd38197ce21f5cc422f364e7e [file] [log] [blame]
henrike@webrtc.org0e118e72013-07-10 00:45:36 +00001/*
2 * libjingle
3 * Copyright 2004--2005, Google Inc.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 * 1. Redistributions of source code must retain the above copyright notice,
9 * this list of conditions and the following disclaimer.
10 * 2. Redistributions in binary form must reproduce the above copyright notice,
11 * this list of conditions and the following disclaimer in the documentation
12 * and/or other materials provided with the distribution.
13 * 3. The name of the author may not be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include "talk/base/natsocketfactory.h"
29
30#include "talk/base/logging.h"
31#include "talk/base/natserver.h"
32#include "talk/base/virtualsocketserver.h"
33
34namespace talk_base {
35
36// Packs the given socketaddress into the buffer in buf, in the quasi-STUN
37// format that the natserver uses.
38// Returns 0 if an invalid address is passed.
39size_t PackAddressForNAT(char* buf, size_t buf_size,
40 const SocketAddress& remote_addr) {
41 const IPAddress& ip = remote_addr.ipaddr();
42 int family = ip.family();
43 buf[0] = 0;
44 buf[1] = family;
45 // Writes the port.
46 *(reinterpret_cast<uint16*>(&buf[2])) = HostToNetwork16(remote_addr.port());
47 if (family == AF_INET) {
48 ASSERT(buf_size >= kNATEncodedIPv4AddressSize);
49 in_addr v4addr = ip.ipv4_address();
pbos@webrtc.orgb9518272014-03-07 15:22:04 +000050 memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
henrike@webrtc.org0e118e72013-07-10 00:45:36 +000051 return kNATEncodedIPv4AddressSize;
52 } else if (family == AF_INET6) {
53 ASSERT(buf_size >= kNATEncodedIPv6AddressSize);
54 in6_addr v6addr = ip.ipv6_address();
pbos@webrtc.orgb9518272014-03-07 15:22:04 +000055 memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
henrike@webrtc.org0e118e72013-07-10 00:45:36 +000056 return kNATEncodedIPv6AddressSize;
57 }
58 return 0U;
59}
60
61// Decodes the remote address from a packet that has been encoded with the nat's
62// quasi-STUN format. Returns the length of the address (i.e., the offset into
63// data where the original packet starts).
64size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
65 SocketAddress* remote_addr) {
66 ASSERT(buf_size >= 8);
67 ASSERT(buf[0] == 0);
68 int family = buf[1];
69 uint16 port = NetworkToHost16(*(reinterpret_cast<const uint16*>(&buf[2])));
70 if (family == AF_INET) {
71 const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
72 *remote_addr = SocketAddress(IPAddress(*v4addr), port);
73 return kNATEncodedIPv4AddressSize;
74 } else if (family == AF_INET6) {
75 ASSERT(buf_size >= 20);
76 const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
77 *remote_addr = SocketAddress(IPAddress(*v6addr), port);
78 return kNATEncodedIPv6AddressSize;
79 }
80 return 0U;
81}
82
83
84// NATSocket
85class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
86 public:
87 explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
mallinath@webrtc.org93e7d6e2013-09-30 18:59:08 +000088 : sf_(sf), family_(family), type_(type), connected_(false),
henrike@webrtc.org0e118e72013-07-10 00:45:36 +000089 socket_(NULL), buf_(NULL), size_(0) {
90 }
91
92 virtual ~NATSocket() {
93 delete socket_;
94 delete[] buf_;
95 }
96
97 virtual SocketAddress GetLocalAddress() const {
98 return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
99 }
100
101 virtual SocketAddress GetRemoteAddress() const {
102 return remote_addr_; // will be NIL if not connected
103 }
104
105 virtual int Bind(const SocketAddress& addr) {
106 if (socket_) { // already bound, bubble up error
107 return -1;
108 }
109
110 int result;
111 socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
112 result = (socket_) ? socket_->Bind(addr) : -1;
113 if (result >= 0) {
114 socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
115 socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
116 socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
117 socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
118 } else {
119 server_addr_.Clear();
120 delete socket_;
121 socket_ = NULL;
122 }
123
124 return result;
125 }
126
127 virtual int Connect(const SocketAddress& addr) {
128 if (!socket_) { // socket must be bound, for now
129 return -1;
130 }
131
132 int result = 0;
133 if (type_ == SOCK_STREAM) {
134 result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
135 } else {
136 connected_ = true;
137 }
138
139 if (result >= 0) {
140 remote_addr_ = addr;
141 }
142
143 return result;
144 }
145
146 virtual int Send(const void* data, size_t size) {
147 ASSERT(connected_);
148 return SendTo(data, size, remote_addr_);
149 }
150
151 virtual int SendTo(const void* data, size_t size, const SocketAddress& addr) {
152 ASSERT(!connected_ || addr == remote_addr_);
153 if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
154 return socket_->SendTo(data, size, addr);
155 }
156 // This array will be too large for IPv4 packets, but only by 12 bytes.
wu@webrtc.org5c9dd592013-10-25 21:18:33 +0000157 scoped_ptr<char[]> buf(new char[size + kNATEncodedIPv6AddressSize]);
henrike@webrtc.org0e118e72013-07-10 00:45:36 +0000158 size_t addrlength = PackAddressForNAT(buf.get(),
159 size + kNATEncodedIPv6AddressSize,
160 addr);
161 size_t encoded_size = size + addrlength;
pbos@webrtc.orgb9518272014-03-07 15:22:04 +0000162 memcpy(buf.get() + addrlength, data, size);
henrike@webrtc.org0e118e72013-07-10 00:45:36 +0000163 int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
164 if (result >= 0) {
165 ASSERT(result == static_cast<int>(encoded_size));
166 result = result - static_cast<int>(addrlength);
167 }
168 return result;
169 }
170
171 virtual int Recv(void* data, size_t size) {
172 SocketAddress addr;
173 return RecvFrom(data, size, &addr);
174 }
175
176 virtual int RecvFrom(void* data, size_t size, SocketAddress *out_addr) {
177 if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
178 return socket_->RecvFrom(data, size, out_addr);
179 }
180 // Make sure we have enough room to read the requested amount plus the
181 // largest possible header address.
182 SocketAddress remote_addr;
183 Grow(size + kNATEncodedIPv6AddressSize);
184
185 // Read the packet from the socket.
186 int result = socket_->RecvFrom(buf_, size_, &remote_addr);
187 if (result >= 0) {
188 ASSERT(remote_addr == server_addr_);
189
190 // TODO: we need better framing so we know how many bytes we can
191 // return before we need to read the next address. For UDP, this will be
192 // fine as long as the reader always reads everything in the packet.
193 ASSERT((size_t)result < size_);
194
195 // Decode the wire packet into the actual results.
196 SocketAddress real_remote_addr;
197 size_t addrlength =
198 UnpackAddressFromNAT(buf_, result, &real_remote_addr);
pbos@webrtc.orgb9518272014-03-07 15:22:04 +0000199 memcpy(data, buf_ + addrlength, result - addrlength);
henrike@webrtc.org0e118e72013-07-10 00:45:36 +0000200
201 // Make sure this packet should be delivered before returning it.
202 if (!connected_ || (real_remote_addr == remote_addr_)) {
203 if (out_addr)
204 *out_addr = real_remote_addr;
205 result = result - static_cast<int>(addrlength);
206 } else {
207 LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
208 << real_remote_addr.ToString();
209 result = 0; // Tell the caller we didn't read anything
210 }
211 }
212
213 return result;
214 }
215
216 virtual int Close() {
217 int result = 0;
218 if (socket_) {
219 result = socket_->Close();
220 if (result >= 0) {
221 connected_ = false;
222 remote_addr_ = SocketAddress();
223 delete socket_;
224 socket_ = NULL;
225 }
226 }
227 return result;
228 }
229
230 virtual int Listen(int backlog) {
231 return socket_->Listen(backlog);
232 }
233 virtual AsyncSocket* Accept(SocketAddress *paddr) {
234 return socket_->Accept(paddr);
235 }
236 virtual int GetError() const {
237 return socket_->GetError();
238 }
239 virtual void SetError(int error) {
240 socket_->SetError(error);
241 }
242 virtual ConnState GetState() const {
243 return connected_ ? CS_CONNECTED : CS_CLOSED;
244 }
245 virtual int EstimateMTU(uint16* mtu) {
246 return socket_->EstimateMTU(mtu);
247 }
248 virtual int GetOption(Option opt, int* value) {
249 return socket_->GetOption(opt, value);
250 }
251 virtual int SetOption(Option opt, int value) {
252 return socket_->SetOption(opt, value);
253 }
254
255 void OnConnectEvent(AsyncSocket* socket) {
256 // If we're NATed, we need to send a request with the real addr to use.
257 ASSERT(socket == socket_);
258 if (server_addr_.IsNil()) {
259 connected_ = true;
260 SignalConnectEvent(this);
261 } else {
262 SendConnectRequest();
263 }
264 }
265 void OnReadEvent(AsyncSocket* socket) {
266 // If we're NATed, we need to process the connect reply.
267 ASSERT(socket == socket_);
268 if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
269 HandleConnectReply();
270 } else {
271 SignalReadEvent(this);
272 }
273 }
274 void OnWriteEvent(AsyncSocket* socket) {
275 ASSERT(socket == socket_);
276 SignalWriteEvent(this);
277 }
278 void OnCloseEvent(AsyncSocket* socket, int error) {
279 ASSERT(socket == socket_);
280 SignalCloseEvent(this, error);
281 }
282
283 private:
284 // Makes sure the buffer is at least the given size.
285 void Grow(size_t new_size) {
286 if (size_ < new_size) {
287 delete[] buf_;
288 size_ = new_size;
289 buf_ = new char[size_];
290 }
291 }
292
293 // Sends the destination address to the server to tell it to connect.
294 void SendConnectRequest() {
295 char buf[256];
296 size_t length = PackAddressForNAT(buf, ARRAY_SIZE(buf), remote_addr_);
297 socket_->Send(buf, length);
298 }
299
300 // Handles the byte sent back from the server and fires the appropriate event.
301 void HandleConnectReply() {
302 char code;
303 socket_->Recv(&code, sizeof(code));
304 if (code == 0) {
305 SignalConnectEvent(this);
306 } else {
307 Close();
308 SignalCloseEvent(this, code);
309 }
310 }
311
312 NATInternalSocketFactory* sf_;
313 int family_;
314 int type_;
henrike@webrtc.org0e118e72013-07-10 00:45:36 +0000315 bool connected_;
316 SocketAddress remote_addr_;
317 SocketAddress server_addr_; // address of the NAT server
318 AsyncSocket* socket_;
319 char* buf_;
320 size_t size_;
321};
322
323// NATSocketFactory
324NATSocketFactory::NATSocketFactory(SocketFactory* factory,
325 const SocketAddress& nat_addr)
326 : factory_(factory), nat_addr_(nat_addr) {
327}
328
329Socket* NATSocketFactory::CreateSocket(int type) {
330 return CreateSocket(AF_INET, type);
331}
332
333Socket* NATSocketFactory::CreateSocket(int family, int type) {
334 return new NATSocket(this, family, type);
335}
336
337AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
338 return CreateAsyncSocket(AF_INET, type);
339}
340
341AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
342 return new NATSocket(this, family, type);
343}
344
345AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
346 const SocketAddress& local_addr, SocketAddress* nat_addr) {
347 *nat_addr = nat_addr_;
348 return factory_->CreateAsyncSocket(family, type);
349}
350
351// NATSocketServer
352NATSocketServer::NATSocketServer(SocketServer* server)
353 : server_(server), msg_queue_(NULL) {
354}
355
356NATSocketServer::Translator* NATSocketServer::GetTranslator(
357 const SocketAddress& ext_ip) {
358 return nats_.Get(ext_ip);
359}
360
361NATSocketServer::Translator* NATSocketServer::AddTranslator(
362 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
363 // Fail if a translator already exists with this extternal address.
364 if (nats_.Get(ext_ip))
365 return NULL;
366
367 return nats_.Add(ext_ip, new Translator(this, type, int_ip, server_, ext_ip));
368}
369
370void NATSocketServer::RemoveTranslator(
371 const SocketAddress& ext_ip) {
372 nats_.Remove(ext_ip);
373}
374
375Socket* NATSocketServer::CreateSocket(int type) {
376 return CreateSocket(AF_INET, type);
377}
378
379Socket* NATSocketServer::CreateSocket(int family, int type) {
380 return new NATSocket(this, family, type);
381}
382
383AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
384 return CreateAsyncSocket(AF_INET, type);
385}
386
387AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
388 return new NATSocket(this, family, type);
389}
390
391AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
392 const SocketAddress& local_addr, SocketAddress* nat_addr) {
393 AsyncSocket* socket = NULL;
394 Translator* nat = nats_.FindClient(local_addr);
395 if (nat) {
396 socket = nat->internal_factory()->CreateAsyncSocket(family, type);
397 *nat_addr = (type == SOCK_STREAM) ?
398 nat->internal_tcp_address() : nat->internal_address();
399 } else {
400 socket = server_->CreateAsyncSocket(family, type);
401 }
402 return socket;
403}
404
405// NATSocketServer::Translator
406NATSocketServer::Translator::Translator(
407 NATSocketServer* server, NATType type, const SocketAddress& int_ip,
408 SocketFactory* ext_factory, const SocketAddress& ext_ip)
409 : server_(server) {
410 // Create a new private network, and a NATServer running on the private
411 // network that bridges to the external network. Also tell the private
412 // network to use the same message queue as us.
413 VirtualSocketServer* internal_server = new VirtualSocketServer(server_);
414 internal_server->SetMessageQueue(server_->queue());
415 internal_factory_.reset(internal_server);
416 nat_server_.reset(new NATServer(type, internal_server, int_ip,
417 ext_factory, ext_ip));
418}
419
420
421NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
422 const SocketAddress& ext_ip) {
423 return nats_.Get(ext_ip);
424}
425
426NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
427 const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
428 // Fail if a translator already exists with this extternal address.
429 if (nats_.Get(ext_ip))
430 return NULL;
431
432 AddClient(ext_ip);
433 return nats_.Add(ext_ip,
434 new Translator(server_, type, int_ip, server_, ext_ip));
435}
436void NATSocketServer::Translator::RemoveTranslator(
437 const SocketAddress& ext_ip) {
438 nats_.Remove(ext_ip);
439 RemoveClient(ext_ip);
440}
441
442bool NATSocketServer::Translator::AddClient(
443 const SocketAddress& int_ip) {
444 // Fail if a client already exists with this internal address.
445 if (clients_.find(int_ip) != clients_.end())
446 return false;
447
448 clients_.insert(int_ip);
449 return true;
450}
451
452void NATSocketServer::Translator::RemoveClient(
453 const SocketAddress& int_ip) {
454 std::set<SocketAddress>::iterator it = clients_.find(int_ip);
455 if (it != clients_.end()) {
456 clients_.erase(it);
457 }
458}
459
460NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
461 const SocketAddress& int_ip) {
462 // See if we have the requested IP, or any of our children do.
463 return (clients_.find(int_ip) != clients_.end()) ?
464 this : nats_.FindClient(int_ip);
465}
466
467// NATSocketServer::TranslatorMap
468NATSocketServer::TranslatorMap::~TranslatorMap() {
469 for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
470 delete it->second;
471 }
472}
473
474NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
475 const SocketAddress& ext_ip) {
476 TranslatorMap::iterator it = find(ext_ip);
477 return (it != end()) ? it->second : NULL;
478}
479
480NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
481 const SocketAddress& ext_ip, Translator* nat) {
482 (*this)[ext_ip] = nat;
483 return nat;
484}
485
486void NATSocketServer::TranslatorMap::Remove(
487 const SocketAddress& ext_ip) {
488 TranslatorMap::iterator it = find(ext_ip);
489 if (it != end()) {
490 delete it->second;
491 erase(it);
492 }
493}
494
495NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
496 const SocketAddress& int_ip) {
497 Translator* nat = NULL;
498 for (TranslatorMap::iterator it = begin(); it != end() && !nat; ++it) {
499 nat = it->second->FindClient(int_ip);
500 }
501 return nat;
502}
503
504} // namespace talk_base