blob: 0dc861be3abc7abd9c399aced553cace3f8df455 [file] [log] [blame]
// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
4
5use std::{
6 mem,
7 ops::Deref,
8 os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
9 ptr,
10 time::Duration,
11};
12
13use libc::{c_void, eventfd, read, write, POLLIN};
14use serde::{Deserialize, Serialize};
15
16use super::{
17 duration_to_timespec, errno_result, AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor,
18 RawDescriptor, Result, SafeDescriptor,
19};
20use crate::generate_scoped_event;
21
22/// A safe wrapper around a Linux eventfd (man 2 eventfd).
23///
24/// An eventfd is useful because it is sendable across processes and can be used for signaling in
25/// and out of the KVM API. They can also be polled like any other file descriptor.
26#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
27#[serde(transparent)]
28pub struct EventFd {
29 event_handle: SafeDescriptor,
30}
31
/// Outcome of a timed read on an [`EventFd`].
///
/// Distinguishes between obtaining a valid count of how many times the eventfd was written to
/// and giving up because the count stayed at zero for the whole timeout.
#[derive(Debug, PartialEq, Eq)]
pub enum EventReadResult {
    /// The count that was read from the eventfd.
    Count(u64),
    /// The wait timed out before the count became non-zero.
    Timeout,
}
40
41impl EventFd {
42 /// Creates a new blocking EventFd with an initial value of 0.
43 pub fn new() -> Result<EventFd> {
44 // This is safe because eventfd merely allocated an eventfd for our process and we handle
45 // the error case.
46 let ret = unsafe { eventfd(0, 0) };
47 if ret < 0 {
48 return errno_result();
49 }
50 // This is safe because we checked ret for success and know the kernel gave us an fd that we
51 // own.
52 Ok(EventFd {
53 event_handle: unsafe { SafeDescriptor::from_raw_descriptor(ret) },
54 })
55 }
56
57 /// Adds `v` to the eventfd's count, blocking until this won't overflow the count.
58 pub fn write(&self, v: u64) -> Result<()> {
59 // This is safe because we made this fd and the pointer we pass can not overflow because we
60 // give the syscall's size parameter properly.
61 let ret = unsafe {
62 write(
63 self.as_raw_fd(),
64 &v as *const u64 as *const c_void,
65 mem::size_of::<u64>(),
66 )
67 };
68 if ret <= 0 {
69 return errno_result();
70 }
71 Ok(())
72 }
73
74 /// Blocks until the the eventfd's count is non-zero, then resets the count to zero.
75 pub fn read(&self) -> Result<u64> {
76 let mut buf: u64 = 0;
77 let ret = unsafe {
78 // This is safe because we made this fd and the pointer we pass can not overflow because
79 // we give the syscall's size parameter properly.
80 read(
81 self.as_raw_fd(),
82 &mut buf as *mut u64 as *mut c_void,
83 mem::size_of::<u64>(),
84 )
85 };
86 if ret <= 0 {
87 return errno_result();
88 }
89 Ok(buf)
90 }
91
92 /// Blocks for a maximum of `timeout` duration until the the eventfd's count is non-zero. If
93 /// a timeout does not occur then the count is returned as a EventReadResult::Count(count),
94 /// and the count is reset to 0. If a timeout does occur then this function will return
95 /// EventReadResult::Timeout.
Noah Goldc2867722022-03-18 16:04:25 -070096 pub fn read_timeout(&self, timeout: Duration) -> Result<EventReadResult> {
Dennis Kempinb65b67d2022-03-18 12:38:09 -070097 let mut pfd = libc::pollfd {
98 fd: self.as_raw_descriptor(),
99 events: POLLIN,
100 revents: 0,
101 };
102 let timeoutspec: libc::timespec = duration_to_timespec(timeout);
103 // Safe because this only modifies |pfd| and we check the return value
104 let ret = unsafe {
105 libc::ppoll(
106 &mut pfd as *mut libc::pollfd,
107 1,
108 &timeoutspec,
109 ptr::null_mut(),
110 )
111 };
112 if ret < 0 {
113 return errno_result();
114 }
115
116 // no return events (revents) means we got a timeout
117 if pfd.revents == 0 {
118 return Ok(EventReadResult::Timeout);
119 }
120
121 let mut buf = 0u64;
122 // This is safe because we made this fd and the pointer we pass can not overflow because
123 // we give the syscall's size parameter properly.
124 let ret = unsafe {
125 libc::read(
126 self.as_raw_descriptor(),
127 &mut buf as *mut _ as *mut c_void,
128 mem::size_of::<u64>(),
129 )
130 };
131 if ret < 0 {
132 return errno_result();
133 }
134 Ok(EventReadResult::Count(buf))
135 }
136
137 /// Clones this EventFd, internally creating a new file descriptor. The new EventFd will share
138 /// the same underlying count within the kernel.
139 pub fn try_clone(&self) -> Result<EventFd> {
140 self.event_handle
141 .try_clone()
142 .map(|event_handle| EventFd { event_handle })
143 }
144}
145
146impl AsRawFd for EventFd {
147 fn as_raw_fd(&self) -> RawFd {
148 self.event_handle.as_raw_fd()
149 }
150}
151
152impl AsRawDescriptor for EventFd {
153 fn as_raw_descriptor(&self) -> RawDescriptor {
154 self.event_handle.as_raw_descriptor()
155 }
156}
157
158impl FromRawFd for EventFd {
159 unsafe fn from_raw_fd(fd: RawFd) -> Self {
160 EventFd {
161 event_handle: SafeDescriptor::from_raw_descriptor(fd),
162 }
163 }
164}
165
166impl IntoRawFd for EventFd {
167 fn into_raw_fd(self) -> RawFd {
168 self.event_handle.into_raw_descriptor()
169 }
170}
171
172impl From<EventFd> for SafeDescriptor {
173 fn from(evt: EventFd) -> Self {
174 evt.event_handle
175 }
176}
177
178generate_scoped_event!(EventFd);
179
#[cfg(test)]
mod tests {
    use super::*;

    // Creating an eventfd must succeed.
    #[test]
    fn new() {
        let created = EventFd::new();
        created.unwrap();
    }

    // A written value is returned by the next read.
    #[test]
    fn read_write() {
        let evt = EventFd::new().unwrap();
        evt.write(55).unwrap();
        let count = evt.read();
        assert_eq!(count, Ok(55));
    }

    // A clone observes writes made through the original.
    #[test]
    fn clone() {
        let original = EventFd::new().unwrap();
        let duplicate = original.try_clone().unwrap();
        original.write(923).unwrap();
        let count = duplicate.read();
        assert_eq!(count, Ok(923));
    }

    // Dropping a scoped event leaves a count of 1 for a clone to read.
    #[test]
    fn scoped_event() {
        let scoped = ScopedEvent::new().unwrap();
        let observer: EventFd = scoped.try_clone().unwrap();
        drop(scoped);
        let count = observer.read();
        assert_eq!(count, Ok(1));
    }

    // A scoped event converts into a usable plain EventFd.
    #[test]
    fn eventfd_from_scoped_event() {
        let scoped = ScopedEvent::new().unwrap();
        let evt: EventFd = scoped.into();
        evt.write(1).unwrap();
    }

    // Reading with a timeout from an eventfd nobody wrote to reports a timeout.
    #[test]
    fn timeout() {
        let evt = EventFd::new().expect("failed to create eventfd");
        let outcome = evt
            .read_timeout(Duration::from_millis(1))
            .expect("failed to read from eventfd with timeout");
        assert_eq!(outcome, EventReadResult::Timeout);
    }
}
228}