/* Standard C headers */
#include <stdint.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>

/* POSIX headers */
#include <pthread.h>
#include <unistd.h>

/* Library header */
#include <pthreadpool.h>

#define PTHREADPOOL_CACHELINE_SIZE 64
#define PTHREADPOOL_CACHELINE_ALIGNED __attribute__((__aligned__(PTHREADPOOL_CACHELINE_SIZE)))

#if defined(__clang__)
	#if __has_extension(c_static_assert) || __has_feature(c_static_assert)
		#define PTHREADPOOL_STATIC_ASSERT(predicate, message) _Static_assert((predicate), message)
	#else
		#define PTHREADPOOL_STATIC_ASSERT(predicate, message)
	#endif
#elif defined(__GNUC__) && ((__GNUC__ > 4) || (__GNUC__ == 4) && (__GNUC_MINOR__ >= 6))
	/* Static assert is supported by gcc >= 4.6 */
	#define PTHREADPOOL_STATIC_ASSERT(predicate, message) _Static_assert((predicate), message)
#else
	#define PTHREADPOOL_STATIC_ASSERT(predicate, message)
#endif

enum thread_state {
	thread_state_idle,
	thread_state_compute_1d,
	thread_state_shutdown,
};

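/*
 * Per-thread work descriptor (summary added for clarity). Each worker owns a
 * contiguous range [range_start, range_end) of items: the owner consumes items
 * from the front of its range, other workers may steal items from the back,
 * and range_length acts as the shared counter of still-unclaimed items.
 */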
struct PTHREADPOOL_CACHELINE_ALIGNED thread_info {
	/**
	 * Index of the first element in the work range.
	 * Before processing a new element the owning worker thread increments this value.
	 */
	volatile size_t range_start;
	/**
	 * Index of the element after the last element of the work range.
	 * Before processing a new element the stealing worker thread decrements this value.
	 */
	volatile size_t range_end;
	/**
	 * The number of elements in the work range.
	 * Due to race conditions range_length <= range_end - range_start.
	 * The owning worker thread must decrement this value before incrementing @a range_start.
	 * The stealing worker thread must decrement this value before decrementing @a range_end.
	 */
	volatile size_t range_length;
	/**
	 * The active state of the thread.
	 */
	volatile enum thread_state state;
	/**
	 * Thread number in the 0..threads_count-1 range.
	 */
	size_t thread_number;
	/**
	 * The pthread object corresponding to the thread.
	 */
	pthread_t thread_object;
	/**
	 * Condition variable used to wake up the thread.
	 * When the thread is idle, it waits on this condition variable.
	 */
	pthread_cond_t wakeup_condvar;
};

PTHREADPOOL_STATIC_ASSERT(sizeof(struct thread_info) % PTHREADPOOL_CACHELINE_SIZE == 0, "thread_info structure must occupy an integer number of cache lines (64 bytes)");

struct PTHREADPOOL_CACHELINE_ALIGNED pthreadpool {
	/**
	 * The number of threads that signalled completion of an operation.
	 */
	volatile size_t checkedin_threads;
	/**
	 * The function to call for each item.
	 */
	volatile pthreadpool_function_1d_t function;
	/**
	 * The first argument to the item processing function.
	 */
	void *volatile argument;
	/**
	 * Serializes concurrent calls to @a pthreadpool_compute_* from different threads.
	 */
	pthread_mutex_t execution_mutex;
	/**
	 * Guards access to the @a checkedin_threads variable.
	 */
	pthread_mutex_t barrier_mutex;
	/**
	 * Condition variable to wait until all threads check in.
	 */
	pthread_cond_t barrier_condvar;
	/**
	 * Guards access to the @a state variables.
	 */
	pthread_mutex_t state_mutex;
	/**
	 * Condition variable to wait for change of @a state variable.
	 */
	pthread_cond_t state_condvar;
	/**
	 * The number of threads in the thread pool. Never changes after initialization.
	 */
	size_t threads_count;
	/**
	 * Thread information structures that immediately follow this structure.
	 */
	struct thread_info threads[];
};

PTHREADPOOL_STATIC_ASSERT(sizeof(struct pthreadpool) % PTHREADPOOL_CACHELINE_SIZE == 0, "pthreadpool structure must occupy an integer number of cache lines (64 bytes)");

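/*
 * Check-in/wake-up protocol between the master and the workers (summary added
 * for clarity): each worker calls checkin_worker_thread() once it has no more
 * work, the master blocks in wait_worker_threads() until all workers have
 * checked in, and wakeup_worker_threads() resets the counter and broadcasts
 * on the state condition variable to start the next operation.
 */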
static void checkin_worker_thread(struct pthreadpool* threadpool) {
	pthread_mutex_lock(&threadpool->barrier_mutex);
	const size_t checkedin_threads = threadpool->checkedin_threads + 1;
	threadpool->checkedin_threads = checkedin_threads;
	if (checkedin_threads == threadpool->threads_count) {
		pthread_cond_signal(&threadpool->barrier_condvar);
	}
	pthread_mutex_unlock(&threadpool->barrier_mutex);
}

static void wait_worker_threads(struct pthreadpool* threadpool) {
	if (threadpool->checkedin_threads != threadpool->threads_count) {
		pthread_mutex_lock(&threadpool->barrier_mutex);
		while (threadpool->checkedin_threads != threadpool->threads_count) {
			pthread_cond_wait(&threadpool->barrier_condvar, &threadpool->barrier_mutex);
		}
		pthread_mutex_unlock(&threadpool->barrier_mutex);
	}
}

static void wakeup_worker_threads(struct pthreadpool* threadpool) {
	pthread_mutex_lock(&threadpool->state_mutex);
	threadpool->checkedin_threads = 0; /* Locking of barrier_mutex not needed: readers are sleeping */
	pthread_cond_broadcast(&threadpool->state_condvar);
	pthread_mutex_unlock(&threadpool->state_mutex); /* Do wake up */
}

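/*
 * Atomically decrement the value if it is non-zero.
 * Returns true if the value was decremented (i.e. an item was claimed)
 * and false if the value was already zero.
 * Implemented as a compare-and-swap loop on the GCC __sync builtins.
 */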
inline static bool atomic_decrement(volatile size_t* value) {
	size_t actual_value = *value;
	if (actual_value != 0) {
		size_t expected_value;
		do {
			expected_value = actual_value;
			const size_t new_value = actual_value - 1;
			actual_value = __sync_val_compare_and_swap(value, expected_value, new_value);
		} while ((actual_value != expected_value) && (actual_value != 0));
	}
	return actual_value != 0;
}

static void thread_compute_1d(struct pthreadpool* threadpool, struct thread_info* thread) {
	const pthreadpool_function_1d_t function = threadpool->function;
	void *const argument = threadpool->argument;
	/* Process thread's own range of items */
	size_t range_start = thread->range_start;
	while (atomic_decrement(&thread->range_length)) {
		function(argument, range_start++);
	}
	/* Done, now look for other threads' items to steal */
	const size_t thread_number = thread->thread_number;
	const size_t threads_count = threadpool->threads_count;
	for (size_t tid = (thread_number + 1) % threads_count; tid != thread_number; tid = (tid + 1) % threads_count) {
		struct thread_info* other_thread = &threadpool->threads[tid];
		if (other_thread->state != thread_state_idle) {
			while (atomic_decrement(&other_thread->range_length)) {
				const size_t item_id = __sync_sub_and_fetch(&other_thread->range_end, 1);
				function(argument, item_id);
			}
		}
	}
}

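/*
 * Worker thread entry point (summary added for clarity). The argument is the
 * thread's own thread_info structure; the pool pointer is recovered from it by
 * stepping back thread_number elements to the start of the flexible threads[]
 * array and then one struct pthreadpool before it. The worker checks in, then
 * loops: sleep until the state changes, run the requested computation (or exit
 * on shutdown), mark itself idle, and check in again.
 */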
static void* thread_main(void* arg) {
	struct thread_info* thread = (struct thread_info*) arg;
	struct pthreadpool* threadpool = ((struct pthreadpool*) (thread - thread->thread_number)) - 1;

	/* Check in */
	checkin_worker_thread(threadpool);

	/* Monitor the state changes and act accordingly */
	for (;;) {
		/* Lock the state mutex */
		pthread_mutex_lock(&threadpool->state_mutex);
		/* Read the state */
		enum thread_state state;
		while ((state = thread->state) == thread_state_idle) {
			/* Wait for state change */
			pthread_cond_wait(&threadpool->state_condvar, &threadpool->state_mutex);
		}
		/* Read non-idle state */
		pthread_mutex_unlock(&threadpool->state_mutex);
		switch (state) {
			case thread_state_compute_1d:
				thread_compute_1d(threadpool, thread);
				break;
			case thread_state_shutdown:
				return NULL;
			case thread_state_idle:
				/* To inhibit compiler warning */
				break;
		}
		/* Notify the master thread that we finished processing */
		thread->state = thread_state_idle;
		checkin_worker_thread(threadpool);
	}
}

struct pthreadpool* pthreadpool_create(size_t threads_count) {
	if (threads_count == 0) {
		threads_count = (size_t) sysconf(_SC_NPROCESSORS_ONLN);
	}
	struct pthreadpool* threadpool = NULL;
	posix_memalign((void**) &threadpool, PTHREADPOOL_CACHELINE_SIZE, sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info));
	memset(threadpool, 0, sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info));
	threadpool->threads_count = threads_count;
	pthread_mutex_init(&threadpool->execution_mutex, NULL);
	pthread_mutex_init(&threadpool->barrier_mutex, NULL);
	pthread_cond_init(&threadpool->barrier_condvar, NULL);
	pthread_mutex_init(&threadpool->state_mutex, NULL);
	pthread_cond_init(&threadpool->state_condvar, NULL);

	for (size_t tid = 0; tid < threads_count; tid++) {
		threadpool->threads[tid].thread_number = tid;
		pthread_create(&threadpool->threads[tid].thread_object, NULL, &thread_main, &threadpool->threads[tid]);
	}

	/* Wait until all threads initialize */
	wait_worker_threads(threadpool);
	return threadpool;
}

size_t pthreadpool_get_threads_count(struct pthreadpool* threadpool) {
	return threadpool->threads_count;
}

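/*
 * Compute (a * b) / d using a wider intermediate type so that the product
 * a * b cannot overflow size_t. Used below to split the item range evenly
 * across the worker threads.
 */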
static inline size_t multiply_divide(size_t a, size_t b, size_t d) {
	#if defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 4)
		return (size_t) ((((uint64_t) a) * ((uint64_t) b)) / ((uint64_t) d));
	#elif defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 8)
		return (size_t) ((((__uint128_t) a) * ((__uint128_t) b)) / ((__uint128_t) d));
	#else
		#error "Unsupported platform"
	#endif
}

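/*
 * Parallel 1D computation (summary added for clarity): partition the
 * [0, items) range evenly across the workers, publish the function and
 * argument, wake the workers up, and block until all workers have checked
 * back in.
 */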
void pthreadpool_compute_1d(
	struct pthreadpool* threadpool,
	pthreadpool_function_1d_t function,
	void* argument,
	size_t items)
{
	/* Protect the global threadpool structures */
	pthread_mutex_lock(&threadpool->execution_mutex);

	/* Spread the work between threads */
	for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
		struct thread_info* thread = &threadpool->threads[tid];
		thread->range_start = multiply_divide(items, tid, threadpool->threads_count);
		thread->range_end = multiply_divide(items, tid + 1, threadpool->threads_count);
		thread->range_length = thread->range_end - thread->range_start;
		thread->state = thread_state_compute_1d;
	}

	/* Setup global arguments */
	threadpool->function = function;
	threadpool->argument = argument;

	/* Wake up the threads */
	wakeup_worker_threads(threadpool);

	/* Wait until the threads finish computation */
	wait_worker_threads(threadpool);

	/* Unprotect the global threadpool structures */
	pthread_mutex_unlock(&threadpool->execution_mutex);
}

void pthreadpool_destroy(struct pthreadpool* threadpool) {
	/* Update threads' states */
	for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
		threadpool->threads[tid].state = thread_state_shutdown;
	}

	/* Wake up the threads */
	wakeup_worker_threads(threadpool);

	/* Wait until all threads return */
	for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
		pthread_join(threadpool->threads[tid].thread_object, NULL);
	}

	/* Release resources */
	pthread_mutex_destroy(&threadpool->execution_mutex);
	pthread_mutex_destroy(&threadpool->barrier_mutex);
	pthread_cond_destroy(&threadpool->barrier_condvar);
	pthread_mutex_destroy(&threadpool->state_mutex);
	pthread_cond_destroy(&threadpool->state_condvar);
	free(threadpool);
}
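
/*
 * Example usage (an illustrative sketch, not part of the library): the
 * square_item callback and the data array below are hypothetical names, shown
 * only to demonstrate the create/compute_1d/destroy call sequence defined in
 * this file.
 *
 *   #include <stddef.h>
 *   #include <pthreadpool.h>
 *
 *   static void square_item(void* context, size_t i) {
 *       double* data = (double*) context;
 *       data[i] *= data[i];
 *   }
 *
 *   int main(void) {
 *       double data[1000];
 *       for (size_t i = 0; i < 1000; i++) {
 *           data[i] = (double) i;
 *       }
 *       // 0 threads means one worker per online processor
 *       struct pthreadpool* pool = pthreadpool_create(0);
 *       // Square all 1000 elements in parallel
 *       pthreadpool_compute_1d(pool, square_item, data, 1000);
 *       pthreadpool_destroy(pool);
 *       return 0;
 *   }
 */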