// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// multi_thread_gemm.h: Multi-threaded GEMM entry point.
// Note to readers: to understand this file, it is useful to first
// read and understand the much simpler single_thread_gemm.h.
#ifndef GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
#define GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_
#include <pthread.h>
#include <unistd.h>
#include <vector>
#include "single_thread_gemm.h"
namespace gemmlowp {
// A BlockingCounter lets one thread wait for N events to occur.
// This is how the master thread waits for all the worker threads
// to have finished working.
class BlockingCounter {
public:
BlockingCounter()
: cond_(PTHREAD_COND_INITIALIZER),
mutex_(PTHREAD_MUTEX_INITIALIZER),
count_(0) {}
// Sets/resets the counter; initial_count is the number of
// decrementing events that the Wait() call will be waiting for.
void Reset(int initial_count) {
pthread_mutex_lock(&mutex_);
assert(count_ == 0);
count_ = initial_count;
pthread_mutex_unlock(&mutex_);
}
// Decrements the counter; if the counter hits zero, signals
// the thread that was waiting for that, and returns true.
// Otherwise (if the decremented count is still nonzero),
// returns false.
bool DecrementCount() {
pthread_mutex_lock(&mutex_);
assert(count_ > 0);
count_--;
if (count_ == 0) {
pthread_cond_signal(&cond_);
}
bool retval = count_ == 0;
pthread_mutex_unlock(&mutex_);
return retval;
}
// Waits for the N other threads (N having been set by Reset())
// to hit the BlockingCounter.
void Wait() {
ScopedProfilingLabel label("BlockingCounter::Wait");
pthread_mutex_lock(&mutex_);
while (count_) {
pthread_cond_wait(&cond_, &mutex_);
}
pthread_mutex_unlock(&mutex_);
}
private:
pthread_cond_t cond_;
pthread_mutex_t mutex_;
int count_;
};
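// Illustrative usage sketch for BlockingCounter (not part of the library;
// the worker-side code is assumed to eventually call DecrementCount()):
//
//   BlockingCounter counter;
//   counter.Reset(3);            // the master expects 3 completion events
//   // ... hand work to 3 workers, each calling counter.DecrementCount()
//   // when it is done ...
//   counter.Wait();              // the master blocks until the count hits 0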
// A workload for a worker.
struct Task {
Task() : local_allocator(nullptr) {}
virtual ~Task() {}
virtual void Run() const = 0;
Allocator* local_allocator;
};
// A worker thread.
class Worker {
public:
enum class State {
ThreadStartup, // The initial state before the thread main loop runs.
Ready, // Is not working, has not yet received new work to do.
HasWork, // Has work to do.
ExitAsSoonAsPossible // Should exit at earliest convenience.
};
explicit Worker(BlockingCounter* counter_to_decrement_when_ready)
: task_(nullptr),
state_cond_(PTHREAD_COND_INITIALIZER),
state_mutex_(PTHREAD_MUTEX_INITIALIZER),
state_(State::ThreadStartup),
counter_to_decrement_when_ready_(counter_to_decrement_when_ready) {
pthread_create(&thread_, nullptr, ThreadFunc, this);
}
~Worker() {
ChangeState(State::ExitAsSoonAsPossible);
pthread_join(thread_, nullptr);
}
// Changes state; may be called from either the worker thread
// or the master thread. Not all state transitions are legal;
// the legal transitions are enforced by the assertions below.
void ChangeState(State new_state) {
ScopedProfilingLabel label("Worker::ChangeState");
pthread_mutex_lock(&state_mutex_);
assert(new_state != state_);
switch (state_) {
case State::ThreadStartup:
assert(new_state == State::Ready);
break;
case State::Ready:
assert(new_state == State::HasWork ||
new_state == State::ExitAsSoonAsPossible);
break;
case State::HasWork:
assert(new_state == State::Ready ||
new_state == State::ExitAsSoonAsPossible);
break;
default:
abort();
}
state_ = new_state;
pthread_cond_signal(&state_cond_);
if (state_ == State::Ready) {
counter_to_decrement_when_ready_->DecrementCount();
}
pthread_mutex_unlock(&state_mutex_);
}
// Thread entry point.
void ThreadFunc() {
ScopedProfilingLabel label("Worker::ThreadFunc");
RegisterCurrentThreadForProfiling();
ChangeState(State::Ready);
// Thread main loop
while (true) {
// Get a state to act on
pthread_mutex_lock(&state_mutex_);
switch (state_) {
case State::ExitAsSoonAsPossible:
case State::HasWork:
break;
case State::Ready:
// In the 'Ready' state, we have nothing to do but to wait until
// we switch to another state.
while (state_ == State::Ready) {
ScopedProfilingLabel label("Worker::ThreadFunc waiting");
pthread_cond_wait(&state_cond_, &state_mutex_);
}
break;
default:
abort();
}
State state_to_act_upon = state_;
pthread_mutex_unlock(&state_mutex_);
// We now have a state to act on, so act.
switch (state_to_act_upon) {
case State::HasWork:
// Got work to do! So do it, and then revert to 'Ready' state.
assert(task_);
task_->Run();
delete task_;
task_ = nullptr;
ChangeState(State::Ready);
break;
case State::ExitAsSoonAsPossible:
return;
default:
abort();
}
}
}
static void* ThreadFunc(void* arg) {
static_cast<Worker*>(arg)->ThreadFunc();
return nullptr;
}
// Called by the master thread to give this worker work to do.
// It is only legal to call this if the worker is currently in the
// 'Ready' state.
void StartWork(Task* task) {
assert(!task_);
task->local_allocator = &local_allocator_;
task_ = task;
assert(state_ == State::Ready);
ChangeState(State::HasWork);
}
private:
// The underlying thread.
pthread_t thread_;
// The task to be worked on.
const Task* task_;
// The condition variable and mutex guarding state changes.
pthread_cond_t state_cond_;
pthread_mutex_t state_mutex_;
// The state enum tells whether we're currently working, waiting for work, etc.
State state_;
// Each worker thread has a local allocator, so that it can allocate
// temporary buffers without blocking the other threads.
Allocator local_allocator_;
// Pointer to the master thread's BlockingCounter object, used to notify
// the master thread when this worker switches to the 'Ready' state.
BlockingCounter* const counter_to_decrement_when_ready_;
};
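// For reference, the legal Worker state transitions enforced by the
// assertions in ChangeState() above are:
//
//   ThreadStartup -> Ready
//   Ready         -> HasWork | ExitAsSoonAsPossible
//   HasWork       -> Ready   | ExitAsSoonAsPossible
//
// Every transition into 'Ready' also decrements the master's
// BlockingCounter, which is how the master learns that this worker is
// available again.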
// A very simple pool of workers that only allows the specific
// parallelization pattern used here: a fixed number of workers
// is given work, and the master then waits for all of them to finish.
class WorkersPool {
public:
WorkersPool() {}
~WorkersPool() {
for (auto w : workers_) {
delete w;
}
}
BlockingCounter& counter_to_decrement_when_ready() {
return counter_to_decrement_when_ready_;
}
// Give work to a specific worker.
void StartWorker(int index, Task* task_) {
assert(static_cast<std::size_t>(index) < workers_.size());
workers_[index]->StartWork(task_);
}
// Ensures that the pool has at least the given count of workers.
// If any new worker has to be created, this function waits for it to
// be ready.
void CreateWorkers(std::size_t workers_count) {
if (workers_.size() >= workers_count) {
return;
}
counter_to_decrement_when_ready_.Reset(workers_count - workers_.size());
while (workers_.size() < workers_count) {
workers_.push_back(new Worker(&counter_to_decrement_when_ready_));
}
counter_to_decrement_when_ready_.Wait();
}
private:
// copy construction disallowed
WorkersPool(const WorkersPool&) = delete;
// The workers in this pool. They are owned by the pool:
// the pool creates workers and destroys them in its destructor.
std::vector<Worker*> workers_;
// The BlockingCounter used to wait for the workers.
BlockingCounter counter_to_decrement_when_ready_;
};
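// Illustrative usage sketch for WorkersPool (MyTask is a hypothetical Task
// subclass; this is not part of the library):
//
//   WorkersPool pool;
//   pool.CreateWorkers(4);                          // ensure 4 ready workers
//   pool.counter_to_decrement_when_ready().Reset(4);
//   for (int i = 0; i < 4; i++) {
//     pool.StartWorker(i, new MyTask(/*...*/));     // the worker deletes the task
//   }
//   pool.counter_to_decrement_when_ready().Wait();  // wait for all 4 tasks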
// The task we use to implement a multi-threaded Gemm: a block of the
// RHS has been packed by the master thread; each worker thread
// then has to pack a block of the LHS and accumulate the Gemm of these
// packed LHS and RHS blocks.
template <typename KernelFormat, typename Scalar, BitDepthSetting BitDepth,
MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
struct GemmWithPackedRhsTask : Task {
typedef PackedSideBlock<typename KernelFormat::Lhs>
PackedLhs;
typedef PackedSideBlock<typename KernelFormat::Rhs>
PackedRhs;
GemmWithPackedRhsTask(const KernelBase& _kernel,
const MatrixMap<const Scalar, LhsOrder>& _lhs,
const PackedRhs& _packed_rhs,
MatrixMap<Scalar, ResultOrder>* _result,
int _lhs_offset, int _rhs_offset, int _result_offset,
int _result_mult_int, int _result_shift)
: kernel(_kernel),
lhs(_lhs),
packed_rhs(_packed_rhs),
result(*_result),
lhs_offset(_lhs_offset),
rhs_offset(_rhs_offset),
result_offset(_result_offset),
result_mult_int(_result_mult_int),
result_shift(_result_shift) {}
void Run() const override {
ScopedProfilingLabel label("GemmWithPackedRhsTask");
const int rows = result.rows();
const int cols = result.cols();
const int depth = lhs.cols();
BlockParams block_params;
block_params.Init<KernelFormat>(rows, cols, depth, 1);
PackedLhs packed_lhs(Side::Lhs, local_allocator, block_params, rhs_offset);
PackedResult packed_result(local_allocator, block_params);
local_allocator->Commit();
for (int c = 0; c < cols; c += block_params.l2_cols) {
int cs = std::min(block_params.l2_cols, cols - c);
for (int r = 0; r < rows; r += block_params.l2_rows) {
int rs = std::min(block_params.l2_rows, rows - r);
PackLhs<BitDepth>(&packed_lhs, lhs.block(r, 0, rs, depth));
Compute(kernel, block_params, &packed_result, packed_lhs, packed_rhs);
auto result_block = result.block(r, c, rs, cs);
UnpackResult<BitDepth>(
&result_block, packed_result, depth,
packed_lhs.rank_one_update(), packed_rhs.rank_one_update(),
lhs_offset, rhs_offset, result_offset, result_mult_int,
result_shift);
}
}
local_allocator->Decommit();
}
const KernelBase& kernel;
const MatrixMap<const Scalar, LhsOrder> lhs;
const PackedRhs packed_rhs;
MatrixMap<Scalar, ResultOrder> result;
int lhs_offset;
int rhs_offset;
int result_offset;
int result_mult_int;
int result_shift;
};
class MultiThreadGemmContext : public SingleThreadGemmContext {
public:
MultiThreadGemmContext() : max_num_threads_(0) {}
void set_max_num_threads(int n) { max_num_threads_ = n; }
int max_num_threads() const { return max_num_threads_; }
WorkersPool* workers_pool() { return &workers_pool_; }
protected:
// The workers pool used by MultiThreadGemm. Making
// this part of the context allows it to be persistent,
// avoiding recreating threads on every Gemm.
WorkersPool workers_pool_;
// The maximum number of worker threads to use (in addition
// to the master thread).
// The default value 0 means that the number of hardware threads
// is detected automatically; any nonzero value skips and overrides
// that detection.
int max_num_threads_;
};
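// Illustrative usage sketch for MultiThreadGemmContext (a minimal example,
// not part of the library):
//
//   MultiThreadGemmContext context;
//   context.set_max_num_threads(2);  // cap the worker count; 0 = auto-detect
//   // The context is then passed to MultiThreadGemm() below; reusing the
//   // same context across calls avoids recreating the worker threads.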
// Determines how many worker threads should be used for a given Gemm
// operation.
template <int KernelRows>
inline int HowManyWorkers(MultiThreadGemmContext* context, int rows, int cols,
int depth) {
// First check if the user set an explicit maximum number of threads.
int max_count = context->max_num_threads();
if (!max_count) {
// No user-set maximum number of threads, so we need to
// do some hardware detection.
// This is expensive to query, so we do it only once; the value therefore
// cannot adapt to runtime changes in the hardware configuration.
// Also, we don't use the C++11 standard getter
// (std::thread::hardware_concurrency()) because Google's coding style
// currently bans #include <thread>.
static const int hardware_threads_count =
static_cast<int>(sysconf(_SC_NPROCESSORS_CONF));
max_count = hardware_threads_count;
}
// Basic calculation: take into account max pool size, and
// how many rows we have to feed our kernel.
int workers_count = std::min(max_count, CeilQuotient(rows, KernelRows));
// At this point, for small products we already have workers_count == 1,
// so we can avoid doing more work; otherwise, we still want to check
// that the cubic size (rows*cols*depth) is big enough to keep all
// the workers busy.
if (workers_count > 1) {
// Empirically determined value.
static const int min_cubic_size_per_thread = 256 * 1024;
// We can only multiply two out of the three sizes without risking overflow.
int cols_times_depth = cols * depth;
if (cols_times_depth < min_cubic_size_per_thread) {
// in that case, we can multiply by rows without risking overflow
int cubic_size = rows * cols_times_depth;
workers_count = std::min(
workers_count, CeilQuotient(cubic_size, min_cubic_size_per_thread));
}
}
assert(workers_count > 0 && workers_count <= max_count);
return workers_count;
}
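// Worked example of the heuristic above, assuming KernelRows == 4 and
// 8 hardware threads: for rows = 100, cols = 50, depth = 50,
// CeilQuotient(100, 4) == 25 gives workers_count = min(8, 25) == 8, but
// cols * depth == 2500 < 256 * 1024 and the cubic size 100 * 2500 == 250000
// gives CeilQuotient(250000, 262144) == 1, so a single worker is chosen and
// MultiThreadGemm falls back to SingleThreadGemm.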
// The main multi-threaded Gemm function.
// To understand it, first read the code of SingleThreadGemm().
// The parallelization scheme used here is to have this master function
// pack a block of RHS and then start worker threads to pack a block of LHS
// each, and accumulate the corresponding products.
template <typename KernelFormat, typename Scalar, BitDepthSetting BitDepth,
MapOrder LhsOrder, MapOrder RhsOrder, MapOrder ResultOrder>
void MultiThreadGemm(MultiThreadGemmContext* context, const KernelBase& kernel,
const MatrixMap<const Scalar, LhsOrder>& lhs,
const MatrixMap<const Scalar, RhsOrder>& rhs,
MatrixMap<Scalar, ResultOrder>* result, int lhs_offset,
int rhs_offset, int result_offset, int result_mult_int,
int result_shift) {
ScopedProfilingLabel label("gemmlowp::MultiThreadGemm");
assert(lhs.cols() == rhs.rows());
int rows = result->rows();
int cols = result->cols();
int depth = lhs.cols();
const int workers_count =
HowManyWorkers<KernelFormat::kRows>(context, rows, cols, depth);
if (workers_count == 1) {
return SingleThreadGemm<KernelFormat, Scalar, BitDepth>(
context, kernel, lhs, rhs, result, lhs_offset, rhs_offset,
result_offset, result_mult_int, result_shift);
}
assert(workers_count > 1);
Allocator* allocator = context->allocator();
WorkersPool* workers_pool = context->workers_pool();
workers_pool->CreateWorkers(workers_count);
BlockParams block_params;
block_params.Init<KernelFormat>(rows, cols, depth, workers_count);
PackedSideBlock<typename KernelFormat::Rhs>
packed_rhs(Side::Rhs, allocator, block_params, lhs_offset);
allocator->Commit();
// We loop over large blocks of the RHS.
for (int c = 0; c < cols; c += block_params.l2_cols) {
int cs = std::min(block_params.l2_cols, cols - c);
// Pack a large block of the RHS.
PackRhs<BitDepth>(&packed_rhs, rhs.block(0, c, depth, cs));
// Give work to each worker.
int next_start_row = 0;
workers_pool->counter_to_decrement_when_ready().Reset(workers_count);
for (int w = 0; w < workers_count; w++) {
int start_row = next_start_row;
next_start_row = std::min(
rows, RoundUp<KernelFormat::kRows>(rows * (w + 1) / workers_count));
int block_rows = next_start_row - start_row;
auto lhs_block = lhs.block(start_row, 0, block_rows, depth);
auto result_block = result->block(start_row, c, block_rows, cs);
typedef GemmWithPackedRhsTask<KernelFormat, Scalar, BitDepth, LhsOrder,
RhsOrder, ResultOrder> TaskType;
auto task = new TaskType(kernel, lhs_block, packed_rhs, &result_block,
lhs_offset, rhs_offset, result_offset,
result_mult_int, result_shift);
workers_pool->StartWorker(w, task);
}
// Wait for the workers.
workers_pool->counter_to_decrement_when_ready().Wait();
}
allocator->Decommit();
}
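// Illustrative call sketch for MultiThreadGemm (SomeKernel and
// SomeBitDepthSetting are placeholders; real kernel types and the matrix
// maps come from other gemmlowp headers, so this is not a compilable
// example):
//
//   MultiThreadGemmContext context;
//   SomeKernel kernel;  // a hypothetical KernelBase subclass
//   MultiThreadGemm<SomeKernel::Format, std::uint8_t, SomeBitDepthSetting,
//                   MapOrder::ColMajor, MapOrder::ColMajor,
//                   MapOrder::ColMajor>(
//       &context, kernel, lhs, rhs, &result, lhs_offset, rhs_offset,
//       result_offset, result_mult_int, result_shift);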
} // namespace gemmlowp
#endif // GEMMLOWP_INTERNAL_MULTI_THREAD_GEMM_H_