/* Standard C headers */
#include <stdint.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <assert.h>

/* POSIX headers */
#include <pthread.h>
#include <unistd.h>

/* Library header */
#include <pthreadpool.h>

#define PTHREADPOOL_CACHELINE_SIZE 64
#define PTHREADPOOL_CACHELINE_ALIGNED __attribute__((__aligned__(PTHREADPOOL_CACHELINE_SIZE)))

#if defined(__clang__)
	#if __has_extension(c_static_assert) || __has_feature(c_static_assert)
		#define PTHREADPOOL_STATIC_ASSERT(predicate, message) _Static_assert((predicate), message)
	#else
		#define PTHREADPOOL_STATIC_ASSERT(predicate, message)
	#endif
#elif defined(__GNUC__) && ((__GNUC__ > 4) || (__GNUC__ == 4) && (__GNUC_MINOR__ >= 6))
	/* Static assert is supported by gcc >= 4.6 */
	#define PTHREADPOOL_STATIC_ASSERT(predicate, message) _Static_assert((predicate), message)
#else
	#define PTHREADPOOL_STATIC_ASSERT(predicate, message)
#endif

/* Compute (a * b) / d with a double-width intermediate so that a * b cannot overflow size_t.
   Note: the cast to size_t must apply to the quotient, not to the product, or the widening is lost. */
static inline size_t multiply_divide(size_t a, size_t b, size_t d) {
	#if defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 4)
		return (size_t) ((((uint64_t) a) * ((uint64_t) b)) / ((uint64_t) d));
	#elif defined(__SIZEOF_SIZE_T__) && (__SIZEOF_SIZE_T__ == 8)
		return (size_t) ((((__uint128_t) a) * ((__uint128_t) b)) / ((__uint128_t) d));
	#else
		#error "Unsupported platform"
	#endif
}
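
/*
 * Worked example of the range split this helper enables: for range = 10 and
 * threads_count = 3, multiply_divide(10, tid, 3) gives start indices 0, 3, 6
 * and multiply_divide(10, tid + 1, 3) gives end indices 3, 6, 10, i.e. the
 * per-thread ranges [0,3), [3,6), [6,10) exactly cover all 10 items, while
 * the double-width intermediate keeps range * tid from overflowing size_t.
 */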

static inline size_t divide_round_up(size_t dividend, size_t divisor) {
	if (dividend % divisor == 0) {
		return dividend / divisor;
	} else {
		return dividend / divisor + 1;
	}
}

static inline size_t min(size_t a, size_t b) {
	return a < b ? a : b;
}

enum thread_state {
	thread_state_idle,
	thread_state_compute_1d,
	thread_state_shutdown,
};

struct PTHREADPOOL_CACHELINE_ALIGNED thread_info {
	/**
	 * Index of the first element in the work range.
	 * Before processing a new element, the owning worker thread increments this value.
	 */
	volatile size_t range_start;
	/**
	 * Index of the element after the last element of the work range.
	 * Before processing a new element, the stealing worker thread decrements this value.
	 */
	volatile size_t range_end;
	/**
	 * The number of elements in the work range.
	 * Due to race conditions, range_length <= range_end - range_start.
	 * The owning worker thread must decrement this value before incrementing @a range_start.
	 * The stealing worker thread must decrement this value before decrementing @a range_end.
	 */
	volatile size_t range_length;
	/**
	 * The active state of the thread.
	 */
	volatile enum thread_state state;
	/**
	 * Thread number in the 0..threads_count-1 range.
	 */
	size_t thread_number;
	/**
	 * The pthread object corresponding to the thread.
	 */
	pthread_t thread_object;
	/**
	 * Condition variable used to wake up the thread.
	 * When the thread is idle, it waits on this condition variable.
	 */
	pthread_cond_t wakeup_condvar;
};
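
/*
 * Each thread_info is cache-line aligned and padded to a whole number of
 * cache lines (checked below) so that the frequently-updated range_start,
 * range_end, and range_length fields of different workers never share a
 * cache line, which avoids false sharing during work stealing.
 */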

PTHREADPOOL_STATIC_ASSERT(sizeof(struct thread_info) % PTHREADPOOL_CACHELINE_SIZE == 0, "thread_info structure must occupy an integer number of cache lines (64 bytes)");

struct PTHREADPOOL_CACHELINE_ALIGNED pthreadpool {
	/**
	 * The number of threads that signalled completion of an operation.
	 */
	volatile size_t checkedin_threads;
	/**
	 * The function to call for each item.
	 * Declared as a volatile pointer (matching @a argument below), not as a pointer to volatile.
	 */
	void *volatile function;
	/**
	 * The first argument to the item processing function.
	 */
	void *volatile argument;
	/**
	 * Serializes concurrent calls to @a pthreadpool_compute_* from different threads.
	 */
	pthread_mutex_t execution_mutex;
	/**
	 * Guards access to the @a checkedin_threads variable.
	 */
	pthread_mutex_t barrier_mutex;
	/**
	 * Condition variable to wait until all threads check in.
	 */
	pthread_cond_t barrier_condvar;
	/**
	 * Guards access to the @a state variables.
	 */
	pthread_mutex_t state_mutex;
	/**
	 * Condition variable to wait for a change of the @a state variables.
	 */
	pthread_cond_t state_condvar;
	/**
	 * The number of threads in the thread pool. Never changes after initialization.
	 */
	size_t threads_count;
	/**
	 * Thread information structures that immediately follow this structure.
	 */
	struct thread_info threads[];
};

PTHREADPOOL_STATIC_ASSERT(sizeof(struct pthreadpool) % PTHREADPOOL_CACHELINE_SIZE == 0, "pthreadpool structure must occupy an integer number of cache lines (64 bytes)");

/* Record that the calling worker finished the current operation; the last thread to check in signals the master. */
static void checkin_worker_thread(struct pthreadpool* threadpool) {
	pthread_mutex_lock(&threadpool->barrier_mutex);
	const size_t checkedin_threads = threadpool->checkedin_threads + 1;
	threadpool->checkedin_threads = checkedin_threads;
	if (checkedin_threads == threadpool->threads_count) {
		pthread_cond_signal(&threadpool->barrier_condvar);
	}
	pthread_mutex_unlock(&threadpool->barrier_mutex);
}

/* Block the master thread until every worker has checked in. */
static void wait_worker_threads(struct pthreadpool* threadpool) {
	if (threadpool->checkedin_threads != threadpool->threads_count) {
		pthread_mutex_lock(&threadpool->barrier_mutex);
		while (threadpool->checkedin_threads != threadpool->threads_count) {
			pthread_cond_wait(&threadpool->barrier_condvar, &threadpool->barrier_mutex);
		}
		pthread_mutex_unlock(&threadpool->barrier_mutex);
	}
}

static void wakeup_worker_threads(struct pthreadpool* threadpool) {
	pthread_mutex_lock(&threadpool->state_mutex);
	threadpool->checkedin_threads = 0; /* Locking of barrier_mutex not needed: readers are sleeping */
	pthread_cond_broadcast(&threadpool->state_condvar);
	pthread_mutex_unlock(&threadpool->state_mutex); /* Do wake up */
}

/* Atomically decrement *value unless it is zero; return true iff one unit of work was claimed. */
inline static bool atomic_decrement(volatile size_t* value) {
	size_t actual_value = *value;
	if (actual_value != 0) {
		size_t expected_value;
		do {
			expected_value = actual_value;
			const size_t new_value = actual_value - 1;
			actual_value = __sync_val_compare_and_swap(value, expected_value, new_value);
		} while ((actual_value != expected_value) && (actual_value != 0));
	}
	return actual_value != 0;
}
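
/*
 * The compare-and-swap loop above retries only while another thread modified
 * the value and it is still nonzero, so the counter can never go below zero
 * and the function never blocks. A minimal sketch of the same operation using
 * C11 <stdatomic.h> instead of the legacy GCC __sync builtins (assuming the
 * counter were declared atomic_size_t, which this file does not do):
 *
 *   static bool atomic_decrement_c11(atomic_size_t* value) {
 *       size_t actual = atomic_load(value);
 *       while (actual != 0) {
 *           if (atomic_compare_exchange_weak(value, &actual, actual - 1)) {
 *               return true;
 *           }
 *       }
 *       return false;
 *   }
 */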

static void thread_compute_1d(struct pthreadpool* threadpool, struct thread_info* thread) {
	const pthreadpool_function_1d_t function = (pthreadpool_function_1d_t) threadpool->function;
	void *const argument = threadpool->argument;
	/* Process thread's own range of items */
	size_t range_start = thread->range_start;
	while (atomic_decrement(&thread->range_length)) {
		function(argument, range_start++);
	}
	/* Done, now look for other threads' items to steal */
	const size_t thread_number = thread->thread_number;
	const size_t threads_count = threadpool->threads_count;
	for (size_t tid = (thread_number + 1) % threads_count; tid != thread_number; tid = (tid + 1) % threads_count) {
		struct thread_info* other_thread = &threadpool->threads[tid];
		if (other_thread->state != thread_state_idle) {
			while (atomic_decrement(&other_thread->range_length)) {
				const size_t item_id = __sync_sub_and_fetch(&other_thread->range_end, 1);
				function(argument, item_id);
			}
		}
	}
}
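
/*
 * Work-stealing scheme: the owning thread consumes its range from the front
 * (advancing a private copy of range_start), while stealing threads consume
 * from the back (atomically decrementing range_end). Both sides must first
 * claim an item by atomically decrementing range_length, so the total number
 * of items processed never exceeds the range length and the front and back
 * regions cannot overlap, even when the owner and a thief race.
 */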

static void* thread_main(void* arg) {
	struct thread_info* thread = (struct thread_info*) arg;
	/* The threads[] flexible array immediately follows struct pthreadpool, so stepping back
	   thread_number elements and then one whole structure recovers the pool pointer. */
	struct pthreadpool* threadpool = ((struct pthreadpool*) (thread - thread->thread_number)) - 1;

	/* Check in */
	checkin_worker_thread(threadpool);

	/* Monitor the state changes and act accordingly */
	for (;;) {
		/* Lock the state mutex */
		pthread_mutex_lock(&threadpool->state_mutex);
		/* Read the state */
		enum thread_state state;
		while ((state = thread->state) == thread_state_idle) {
			/* Wait for a state change */
			pthread_cond_wait(&threadpool->state_condvar, &threadpool->state_mutex);
		}
		/* Read non-idle state */
		pthread_mutex_unlock(&threadpool->state_mutex);
		switch (state) {
			case thread_state_compute_1d:
				thread_compute_1d(threadpool, thread);
				break;
			case thread_state_shutdown:
				return NULL;
			case thread_state_idle:
				/* To inhibit compiler warning */
				break;
		}
		/* Notify the master thread that we finished processing */
		thread->state = thread_state_idle;
		checkin_worker_thread(threadpool);
	}
}

struct pthreadpool* pthreadpool_create(size_t threads_count) {
	if (threads_count == 0) {
		threads_count = (size_t) sysconf(_SC_NPROCESSORS_ONLN);
	}
	struct pthreadpool* threadpool = NULL;
	if (posix_memalign((void**) &threadpool, PTHREADPOOL_CACHELINE_SIZE, sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info)) != 0) {
		return NULL;
	}
	memset(threadpool, 0, sizeof(struct pthreadpool) + threads_count * sizeof(struct thread_info));
	threadpool->threads_count = threads_count;
	pthread_mutex_init(&threadpool->execution_mutex, NULL);
	pthread_mutex_init(&threadpool->barrier_mutex, NULL);
	pthread_cond_init(&threadpool->barrier_condvar, NULL);
	pthread_mutex_init(&threadpool->state_mutex, NULL);
	pthread_cond_init(&threadpool->state_condvar, NULL);

	for (size_t tid = 0; tid < threads_count; tid++) {
		threadpool->threads[tid].thread_number = tid;
		pthread_create(&threadpool->threads[tid].thread_object, NULL, &thread_main, &threadpool->threads[tid]);
	}

	/* Wait until all threads initialize */
	wait_worker_threads(threadpool);
	return threadpool;
}
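
/*
 * Note: pthreadpool_create() blocks in wait_worker_threads() until every
 * worker has started and checked in; each worker then sleeps on state_condvar
 * until wakeup_worker_threads() announces new work, so the returned pool is
 * immediately ready for pthreadpool_compute_* calls.
 */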

size_t pthreadpool_get_threads_count(struct pthreadpool* threadpool) {
	return threadpool->threads_count;
}

void pthreadpool_compute_1d(
	struct pthreadpool* threadpool,
	pthreadpool_function_1d_t function,
	void* argument,
	size_t range)
{
	if (threadpool == NULL) {
		/* No thread pool provided: execute the function sequentially on the calling thread */
		for (size_t i = 0; i < range; i++) {
			function(argument, i);
		}
	} else {
		/* Protect the global threadpool structures */
		pthread_mutex_lock(&threadpool->execution_mutex);

		/* Spread the work between threads */
		for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
			struct thread_info* thread = &threadpool->threads[tid];
			thread->range_start = multiply_divide(range, tid, threadpool->threads_count);
			thread->range_end = multiply_divide(range, tid + 1, threadpool->threads_count);
			thread->range_length = thread->range_end - thread->range_start;
			thread->state = thread_state_compute_1d;
		}

		/* Set up the global arguments */
		threadpool->function = function;
		threadpool->argument = argument;

		/* Wake up the threads */
		wakeup_worker_threads(threadpool);

		/* Wait until the threads finish computation */
		wait_worker_threads(threadpool);

		/* Unprotect the global threadpool structures */
		pthread_mutex_unlock(&threadpool->execution_mutex);
	}
}
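
/*
 * Minimal usage sketch (illustrative only: the callback, array, and sizes
 * below are hypothetical, not part of this library):
 *
 *   static void square_item(void* argument, size_t i) {
 *       double* array = (double*) argument;
 *       array[i] *= array[i];
 *   }
 *
 *   double data[1000];
 *   pthreadpool_t pool = pthreadpool_create(0); // one thread per online core
 *   pthreadpool_compute_1d(pool, square_item, data, 1000);
 *   pthreadpool_destroy(pool);
 *
 * The call returns only after all items are processed; concurrent calls from
 * different application threads are serialized on execution_mutex.
 */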

struct compute_1d_tiled_context {
	pthreadpool_function_1d_tiled_t function;
	void* argument;
	size_t range;
	size_t tile;
};

static void compute_1d_tiled(const struct compute_1d_tiled_context* context, size_t linear_index) {
	const size_t tile_index = linear_index;
	const size_t index = tile_index * context->tile;
	const size_t tile = min(context->tile, context->range - index);
	context->function(context->argument, index, tile);
}

void pthreadpool_compute_1d_tiled(
	pthreadpool_t threadpool,
	pthreadpool_function_1d_tiled_t function,
	void* argument,
	size_t range,
	size_t tile)
{
	const size_t tile_range = divide_round_up(range, tile);
	struct compute_1d_tiled_context context = {
		.function = function,
		.argument = argument,
		.range = range,
		.tile = tile
	};
	pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_1d_tiled, &context, tile_range);
}
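
/*
 * Tiling example: range = 10 with tile = 4 produces tile_range = 3 calls:
 * function(argument, 0, 4), function(argument, 4, 4), function(argument, 8, 2).
 * The min() in compute_1d_tiled trims the final partial tile.
 */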

struct compute_2d_context {
	pthreadpool_function_2d_t function;
	void* argument;
	size_t range_j;
};

static void compute_2d(const struct compute_2d_context* context, size_t linear_index) {
	const size_t range_j = context->range_j;
	context->function(context->argument, linear_index / range_j, linear_index % range_j);
}

void pthreadpool_compute_2d(
	struct pthreadpool* threadpool,
	pthreadpool_function_2d_t function,
	void* argument,
	size_t range_i,
	size_t range_j)
{
	struct compute_2d_context context = {
		.function = function,
		.argument = argument,
		.range_j = range_j
	};
	pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d, &context, range_i * range_j);
}
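
/*
 * The 2D range is flattened row-major: item (i, j) maps to linear index
 * i * range_j + j, and compute_2d recovers i = linear_index / range_j and
 * j = linear_index % range_j. Load balancing and work stealing thus operate
 * on a single flattened range of range_i * range_j items.
 */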

struct compute_2d_tiled_context {
	pthreadpool_function_2d_tiled_t function;
	void* argument;
	size_t tile_range_j;
	size_t range_i;
	size_t range_j;
	size_t tile_i;
	size_t tile_j;
};

static void compute_2d_tiled(const struct compute_2d_tiled_context* context, size_t linear_index) {
	const size_t tile_index_i = linear_index / context->tile_range_j;
	const size_t tile_index_j = linear_index % context->tile_range_j;
	const size_t index_i = tile_index_i * context->tile_i;
	const size_t index_j = tile_index_j * context->tile_j;
	const size_t tile_i = min(context->tile_i, context->range_i - index_i);
	const size_t tile_j = min(context->tile_j, context->range_j - index_j);
	context->function(context->argument, index_i, index_j, tile_i, tile_j);
}

void pthreadpool_compute_2d_tiled(
	pthreadpool_t threadpool,
	pthreadpool_function_2d_tiled_t function,
	void* argument,
	size_t range_i,
	size_t range_j,
	size_t tile_i,
	size_t tile_j)
{
	const size_t tile_range_i = divide_round_up(range_i, tile_i);
	const size_t tile_range_j = divide_round_up(range_j, tile_j);
	struct compute_2d_tiled_context context = {
		.function = function,
		.argument = argument,
		.tile_range_j = tile_range_j,
		.range_i = range_i,
		.range_j = range_j,
		.tile_i = tile_i,
		.tile_j = tile_j
	};
	pthreadpool_compute_1d(threadpool, (pthreadpool_function_1d_t) compute_2d_tiled, &context, tile_range_i * tile_range_j);
}

void pthreadpool_destroy(struct pthreadpool* threadpool) {
	/* Update threads' states */
	for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
		threadpool->threads[tid].state = thread_state_shutdown;
	}

	/* Wake up the threads */
	wakeup_worker_threads(threadpool);

	/* Wait until all threads return */
	for (size_t tid = 0; tid < threadpool->threads_count; tid++) {
		pthread_join(threadpool->threads[tid].thread_object, NULL);
	}

	/* Release resources */
	pthread_mutex_destroy(&threadpool->execution_mutex);
	pthread_mutex_destroy(&threadpool->barrier_mutex);
	pthread_cond_destroy(&threadpool->barrier_condvar);
	pthread_mutex_destroy(&threadpool->state_mutex);
	pthread_cond_destroy(&threadpool->state_condvar);
	free(threadpool);
}