blob: a00d3aae4837f80c1201f030889c3ef9ebad0c0a [file] [log] [blame]
jeremya@chromium.orgf14bfab2013-03-13 13:23:10 +09001// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <sys/socket.h>
6
7#include "base/bind.h"
8#include "base/file_util.h"
9#include "base/files/file_path.h"
10#include "base/path_service.h"
11#include "base/synchronization/waitable_event.h"
12#include "base/threading/thread.h"
13#include "base/threading/thread_restrictions.h"
14#include "ipc/unix_domain_socket_util.h"
15#include "testing/gtest/include/gtest/gtest.h"
16
17namespace {
18
19class SocketAcceptor : public MessageLoopForIO::Watcher {
20 public:
21 SocketAcceptor(int fd, base::MessageLoopProxy* target_thread)
22 : server_fd_(-1),
23 target_thread_(target_thread),
24 started_watching_event_(false, false),
25 accepted_event_(false, false) {
26 target_thread->PostTask(FROM_HERE,
27 base::Bind(&SocketAcceptor::StartWatching, base::Unretained(this), fd));
28 }
29
30 virtual ~SocketAcceptor() {
31 Close();
32 }
33
34 int server_fd() const { return server_fd_; }
35
36 void WaitUntilReady() {
37 started_watching_event_.Wait();
38 }
39
40 void WaitForAccept() {
41 accepted_event_.Wait();
42 }
43
44 void Close() {
45 if (watcher_.get()) {
46 target_thread_->PostTask(FROM_HERE,
47 base::Bind(&SocketAcceptor::StopWatching, base::Unretained(this),
48 watcher_.release()));
49 }
50 }
51
52 private:
53 void StartWatching(int fd) {
54 watcher_.reset(new MessageLoopForIO::FileDescriptorWatcher);
55 MessageLoopForIO::current()->WatchFileDescriptor(
56 fd,
57 true,
58 MessageLoopForIO::WATCH_READ,
59 watcher_.get(),
60 this);
61 started_watching_event_.Signal();
62 }
63 void StopWatching(MessageLoopForIO::FileDescriptorWatcher* watcher) {
64 watcher->StopWatchingFileDescriptor();
65 delete watcher;
66 }
67 virtual void OnFileCanReadWithoutBlocking(int fd) OVERRIDE {
68 ASSERT_EQ(-1, server_fd_);
69 IPC::ServerAcceptConnection(fd, &server_fd_);
70 watcher_->StopWatchingFileDescriptor();
71 accepted_event_.Signal();
72 }
73 virtual void OnFileCanWriteWithoutBlocking(int fd) OVERRIDE {}
74
75 int server_fd_;
76 base::MessageLoopProxy* target_thread_;
77 scoped_ptr<MessageLoopForIO::FileDescriptorWatcher> watcher_;
78 base::WaitableEvent started_watching_event_;
79 base::WaitableEvent accepted_event_;
80
81 DISALLOW_COPY_AND_ASSIGN(SocketAcceptor);
82};
83
84const base::FilePath GetChannelDir() {
85#if defined(OS_ANDROID)
86 base::FilePath tmp_dir;
87 PathService::Get(base::DIR_CACHE, &tmp_dir);
88 return tmp_dir;
89#else
90 return base::FilePath("/var/tmp");
91#endif
92}
93
94class TestUnixSocketConnection {
95 public:
96 TestUnixSocketConnection()
97 : worker_("WorkerThread"),
98 server_listen_fd_(-1),
99 server_fd_(-1),
100 client_fd_(-1) {
101 socket_name_ = GetChannelDir().Append("TestSocket");
102 base::Thread::Options options;
103 options.message_loop_type = MessageLoop::TYPE_IO;
104 worker_.StartWithOptions(options);
105 }
106
107 bool CreateServerSocket() {
108 IPC::CreateServerUnixDomainSocket(socket_name_, &server_listen_fd_);
109 if (server_listen_fd_ < 0)
110 return false;
111 struct stat socket_stat;
112 stat(socket_name_.value().c_str(), &socket_stat);
113 EXPECT_TRUE(S_ISSOCK(socket_stat.st_mode));
114 acceptor_.reset(new SocketAcceptor(server_listen_fd_,
115 worker_.message_loop_proxy()));
116 acceptor_->WaitUntilReady();
117 return true;
118 }
119
120 bool CreateClientSocket() {
121 DCHECK(server_listen_fd_ >= 0);
122 IPC::CreateClientUnixDomainSocket(socket_name_, &client_fd_);
123 if (client_fd_ < 0)
124 return false;
125 acceptor_->WaitForAccept();
126 server_fd_ = acceptor_->server_fd();
127 return server_fd_ >= 0;
128 }
129
130 virtual ~TestUnixSocketConnection() {
131 if (client_fd_ >= 0)
132 close(client_fd_);
133 if (server_fd_ >= 0)
134 close(server_fd_);
135 if (server_listen_fd_ >= 0) {
136 close(server_listen_fd_);
137 unlink(socket_name_.value().c_str());
138 }
139 }
140
141 int client_fd() const { return client_fd_; }
142 int server_fd() const { return server_fd_; }
143
144 private:
145 base::Thread worker_;
146 base::FilePath socket_name_;
147 int server_listen_fd_;
148 int server_fd_;
149 int client_fd_;
150 scoped_ptr<SocketAcceptor> acceptor_;
151};
152
153// Ensure that IPC::CreateServerUnixDomainSocket creates a socket that
154// IPC::CreateClientUnixDomainSocket can successfully connect to.
155TEST(UnixDomainSocketUtil, Connect) {
156 TestUnixSocketConnection connection;
157 ASSERT_TRUE(connection.CreateServerSocket());
158 ASSERT_TRUE(connection.CreateClientSocket());
159}
160
161// Ensure that messages can be sent across the resulting socket.
162TEST(UnixDomainSocketUtil, SendReceive) {
163 TestUnixSocketConnection connection;
164 ASSERT_TRUE(connection.CreateServerSocket());
165 ASSERT_TRUE(connection.CreateClientSocket());
166
167 const char buffer[] = "Hello, server!";
168 size_t buf_len = sizeof(buffer);
169 size_t sent_bytes =
170 HANDLE_EINTR(send(connection.client_fd(), buffer, buf_len, 0));
171 ASSERT_EQ(buf_len, sent_bytes);
172 char recv_buf[sizeof(buffer)];
173 size_t received_bytes =
174 HANDLE_EINTR(recv(connection.server_fd(), recv_buf, buf_len, 0));
175 ASSERT_EQ(buf_len, received_bytes);
176 ASSERT_EQ(0, memcmp(recv_buf, buffer, buf_len));
177}
178
179} // namespace