Add thread pool class
Added a thread pool class loosely based on google3 code.
Modified the compiler to have a single thread pool instead of creating new threads in ForAll.
Moved barrier to the top-level src directory, as it is not GC-specific code.
Performance Timings:
Reference:
boot.oat: 14.306596s
time mm oat-target:
real 2m33.748s
user 10m23.190s
sys 5m54.140s
Thread pool:
boot.oat: 13.111049s
time mm oat-target:
real 2m29.372s
user 10m3.130s
sys 5m46.290s
The speed increase is probably just noise.
Change-Id: If3c1280cbaa4c7e4361127d064ac744ea12cdf49
diff --git a/build/Android.common.mk b/build/Android.common.mk
index 178af64..f4c7c98 100644
--- a/build/Android.common.mk
+++ b/build/Android.common.mk
@@ -164,6 +164,7 @@
LIBART_COMMON_SRC_FILES := \
src/atomic.cc.arm \
+ src/barrier.cc \
src/check_jni.cc \
src/class_linker.cc \
src/common_throws.cc \
@@ -181,7 +182,6 @@
src/dlmalloc.cc \
src/file.cc \
src/file_linux.cc \
- src/gc/barrier.cc \
src/gc/card_table.cc \
src/gc/heap_bitmap.cc \
src/gc/large_object_space.cc \
@@ -257,6 +257,7 @@
src/stringprintf.cc \
src/thread.cc \
src/thread_list.cc \
+ src/thread_pool.cc \
src/trace.cc \
src/utf.cc \
src/utils.cc \
@@ -394,6 +395,7 @@
test/ReferenceMap/stack_walk_refmap_jni.cc
TEST_COMMON_SRC_FILES := \
+ src/barrier_test.cc \
src/class_linker_test.cc \
src/compiler_test.cc \
src/dex_cache_test.cc \
@@ -418,6 +420,7 @@
src/reference_table_test.cc \
src/runtime_support_test.cc \
src/runtime_test.cc \
+ src/thread_pool_test.cc \
src/utils_test.cc \
src/zip_archive_test.cc \
src/verifier/method_verifier_test.cc \
diff --git a/build/Android.gtest.mk b/build/Android.gtest.mk
index de8c502..6d865dc 100644
--- a/build/Android.gtest.mk
+++ b/build/Android.gtest.mk
@@ -77,7 +77,7 @@
LOCAL_CFLAGS := $(ART_TEST_CFLAGS)
ifeq ($$(art_target_or_host),target)
LOCAL_CFLAGS += $(ART_TARGET_CFLAGS) $(ART_TARGET_DEBUG_CFLAGS)
- LOCAL_SHARED_LIBRARIES += libdl libicuuc libicui18n libnativehelper libstlport libz
+ LOCAL_SHARED_LIBRARIES += libdl libicuuc libicui18n libnativehelper libstlport libz libcutils
LOCAL_STATIC_LIBRARIES += libgtest
LOCAL_MODULE_PATH := $(ART_NATIVETEST_OUT)
include $(BUILD_EXECUTABLE)
@@ -86,6 +86,7 @@
else # host
LOCAL_CFLAGS += $(ART_HOST_CFLAGS) $(ART_HOST_DEBUG_CFLAGS)
LOCAL_SHARED_LIBRARIES += libicuuc-host libicui18n-host libnativehelper libz-host
+ LOCAL_STATIC_LIBRARIES += libcutils
ifeq ($(HOST_OS),darwin)
# Mac OS complains about unresolved symbols if you don't include this.
LOCAL_WHOLE_STATIC_LIBRARIES := libgtest_host
diff --git a/src/atomic_integer.h b/src/atomic_integer.h
index 54d5fd8..adf3e77 100644
--- a/src/atomic_integer.h
+++ b/src/atomic_integer.h
@@ -17,7 +17,8 @@
#ifndef ART_SRC_ATOMIC_INTEGER_H_
#define ART_SRC_ATOMIC_INTEGER_H_
-#include "atomic.h"
+#include "cutils/atomic.h"
+#include "cutils/atomic-inline.h"
namespace art {
@@ -62,6 +63,14 @@
int32_t operator -- (int32_t) {
return android_atomic_dec(&value_);
}
+
+ int32_t operator ++ () {
+ return android_atomic_inc(&value_) + 1;
+ }
+
+ int32_t operator -- () {
+ return android_atomic_dec(&value_) - 1;
+ }
private:
int32_t value_;
};
diff --git a/src/gc/barrier.cc b/src/barrier.cc
similarity index 96%
rename from src/gc/barrier.cc
rename to src/barrier.cc
index aa9433b..9651828 100644
--- a/src/gc/barrier.cc
+++ b/src/barrier.cc
@@ -1,5 +1,5 @@
#include "barrier.h"
-#include "../mutex.h"
+#include "../src/mutex.h"
#include "thread.h"
namespace art {
diff --git a/src/gc/barrier.h b/src/barrier.h
similarity index 94%
rename from src/gc/barrier.h
rename to src/barrier.h
index 207536a..342890b 100644
--- a/src/gc/barrier.h
+++ b/src/barrier.h
@@ -14,10 +14,10 @@
* limitations under the License.
*/
-#ifndef ART_SRC_GC_BARRIER_H_
-#define ART_SRC_GC_BARRIER_H_
+#ifndef ART_SRC_BARRIER_H_
+#define ART_SRC_BARRIER_H_
-#include "../mutex.h"
+#include "../src/mutex.h"
#include "locks.h"
#include "UniquePtr.h"
diff --git a/src/barrier_test.cc b/src/barrier_test.cc
new file mode 100644
index 0000000..43b279e
--- /dev/null
+++ b/src/barrier_test.cc
@@ -0,0 +1,136 @@
+/*
+ * Copyright (C) 2012 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "barrier.h"
+
+#include <string>
+
+#include "atomic_integer.h"
+#include "common_test.h"
+#include "thread_pool.h"
+#include "UniquePtr.h"
+
+namespace art {
+class CheckWaitClosure : public Closure {
+ public:
+ CheckWaitClosure(Barrier* barrier, AtomicInteger* count1, AtomicInteger* count2,
+ AtomicInteger* count3)
+ : barrier_(barrier),
+ count1_(count1),
+ count2_(count2),
+ count3_(count3) {
+
+ }
+
+ void Run(Thread* self) {
+ LOG(INFO) << "Before barrier 1 " << self;
+ ++*count1_;
+ barrier_->Wait(self);
+ ++*count2_;
+ LOG(INFO) << "Before barrier 2 " << self;
+ barrier_->Wait(self);
+ ++*count3_;
+ LOG(INFO) << "After barrier 2 " << self;
+ delete this;
+ }
+ private:
+ Barrier* const barrier_;
+ AtomicInteger* const count1_;
+ AtomicInteger* const count2_;
+ AtomicInteger* const count3_;
+};
+
+class BarrierTest : public CommonTest {
+ public:
+ static int32_t num_threads;
+};
+
+int32_t BarrierTest::num_threads = 4;
+
+// Check that barrier wait and barrier increment work.
+TEST_F(BarrierTest, CheckWait) {
+ Thread* self = Thread::Current();
+ ThreadPool thread_pool(num_threads);
+ Barrier barrier;
+ AtomicInteger count1 = 0;
+ AtomicInteger count2 = 0;
+ AtomicInteger count3 = 0;
+ for (int32_t i = 0; i < num_threads; ++i) {
+ thread_pool.AddTask(self, new CheckWaitClosure(&barrier, &count1, &count2, &count3));
+ }
+ thread_pool.StartWorkers(self);
+ barrier.Increment(self, num_threads);
+ // At this point each thread should have passed through the barrier. The first count should be
+ // equal to num_threads.
+ EXPECT_EQ(num_threads, count1);
+ // Count 3 should still be zero since no thread should have gone past the second barrier.
+ EXPECT_EQ(0, count3);
+ // Now let's tell the threads to pass again.
+ barrier.Increment(self, num_threads);
+ // Count 2 should be equal to num_threads since each thread must have passed the second barrier
+ // at this point.
+ EXPECT_EQ(num_threads, count2);
+ // Wait for all the threads to finish.
+ thread_pool.Wait(self);
+ // All three counts should be equal to num_threads now.
+ EXPECT_EQ(count1, count2);
+ EXPECT_EQ(count2, count3);
+ EXPECT_EQ(num_threads, count3);
+}
+
+class CheckPassClosure : public Closure {
+ public:
+ CheckPassClosure(Barrier* barrier, AtomicInteger* count, size_t subtasks)
+ : barrier_(barrier),
+ count_(count),
+ subtasks_(subtasks) {
+
+ }
+
+ void Run(Thread* self) {
+ for (size_t i = 0; i < subtasks_; ++i) {
+ ++*count_;
+ // Pass through to next subtask.
+ barrier_->Pass(self);
+ }
+ delete this;
+ }
+ private:
+ Barrier* const barrier_;
+ AtomicInteger* const count_;
+ const size_t subtasks_;
+};
+
+// Check that barrier pass through works.
+TEST_F(BarrierTest, CheckPass) {
+ Thread* self = Thread::Current();
+ ThreadPool thread_pool(num_threads);
+ Barrier barrier;
+ AtomicInteger count = 0;
+ const int32_t num_tasks = num_threads * 4;
+ const int32_t num_sub_tasks = 128;
+ for (int32_t i = 0; i < num_tasks; ++i) {
+ thread_pool.AddTask(self, new CheckPassClosure(&barrier, &count, num_sub_tasks));
+ }
+ thread_pool.StartWorkers(self);
+ const int32_t expected_total_tasks = num_sub_tasks * num_tasks;
+ // Wait for all the tasks to complete using the barrier.
+ barrier.Increment(self, expected_total_tasks);
+ // The total number of completed tasks should be equal to expected_total_tasks.
+ EXPECT_EQ(count, expected_total_tasks);
+}
+
+} // namespace art
diff --git a/src/closure.h b/src/closure.h
new file mode 100644
index 0000000..17f2b84
--- /dev/null
+++ b/src/closure.h
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2012 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ART_SRC_CLOSURE_H_
+#define ART_SRC_CLOSURE_H_
+
+namespace art {
+
+class Thread;
+
+class Closure {
+ public:
+ virtual ~Closure() { }
+ virtual void Run(Thread* self) = 0;
+};
+
+} // namespace art
+
+#endif // ART_SRC_CLOSURE_H_
diff --git a/src/compiler.cc b/src/compiler.cc
index 8d7f5b6..ddf9e87 100644
--- a/src/compiler.cc
+++ b/src/compiler.cc
@@ -35,6 +35,7 @@
#include "ScopedLocalRef.h"
#include "stl_util.h"
#include "thread.h"
+#include "thread_pool.h"
#include "timing_logger.h"
#include "verifier/method_verifier.h"
@@ -313,7 +314,8 @@
compiler_(NULL),
compiler_context_(NULL),
jni_compiler_(NULL),
- create_invoke_stub_(NULL)
+ create_invoke_stub_(NULL),
+ thread_pool_(new ThreadPool(thread_count))
{
std::string compiler_so_name(MakeCompilerSoName(instruction_set_));
compiler_library_ = dlopen(compiler_so_name.c_str(), RTLD_LAZY);
@@ -980,14 +982,18 @@
class CompilationContext {
public:
+ typedef void Callback(const CompilationContext* context, size_t index);
+
CompilationContext(ClassLinker* class_linker,
jobject class_loader,
Compiler* compiler,
- const DexFile* dex_file)
+ const DexFile* dex_file,
+ ThreadPool* thread_pool)
: class_linker_(class_linker),
class_loader_(class_loader),
compiler_(compiler),
- dex_file_(dex_file) {}
+ dex_file_(dex_file),
+ thread_pool_(thread_pool) {}
ClassLinker* GetClassLinker() const {
CHECK(class_linker_ != NULL);
@@ -1008,96 +1014,64 @@
return dex_file_;
}
+ void ForAll(size_t begin, size_t end, Callback callback, size_t work_units) {
+ Thread* self = Thread::Current();
+ self->AssertNoPendingException();
+ CHECK_GT(work_units, 0U);
+
+ std::vector<Closure*> closures(work_units);
+ for (size_t i = 0; i < work_units; ++i) {
+ closures[i] = new ForAllClosure(this, begin + i, end, callback, work_units);
+ thread_pool_->AddTask(self, closures[i]);
+ }
+ thread_pool_->StartWorkers(self);
+
+ // Ensure we're suspended while we're blocked waiting for the worker threads to finish (the
+ // thread pool Wait() call below blocks until they have drained the task queue).
+ CHECK_NE(self->GetState(), kRunnable);
+
+ // Wait for all the worker threads to finish.
+ thread_pool_->Wait(self);
+
+ STLDeleteElements(&closures);
+ }
+
private:
+
+  // Pool task that invokes callback_ on indices begin_, begin_ + stripe_, ... below end_.
+  class ForAllClosure : public Closure {
+   public:
+    ForAllClosure(CompilationContext* context, size_t begin, size_t end, Callback* callback,
+                  size_t stripe)
+        : context_(context),
+          begin_(begin),
+          end_(end),
+          callback_(callback),
+          stripe_(stripe) {
+    }
+
+    // Processes this closure's stripe of indices; self is the thread executing the task.
+    virtual void Run(Thread* self) {
+      for (size_t i = begin_; i < end_; i += stripe_) {
+        callback_(context_, i);
+        self->AssertNoPendingException();
+      }
+    }
+   private:
+    CompilationContext* const context_;
+    const size_t begin_;
+    const size_t end_;
+    Callback* const callback_;  // Was "const Callback*": const on a function type has no effect.
+    const size_t stripe_;
+  };
+
ClassLinker* const class_linker_;
const jobject class_loader_;
Compiler* const compiler_;
const DexFile* const dex_file_;
+ ThreadPool* thread_pool_;
};
-typedef void Callback(const CompilationContext* context, size_t index);
-
-static void ForAll(CompilationContext* context, size_t begin, size_t end, Callback callback,
- size_t thread_count);
-
-class WorkerThread {
- public:
- WorkerThread(CompilationContext* context, size_t begin, size_t end, Callback callback, size_t stripe, bool spawn)
- : spawn_(spawn), context_(context), begin_(begin), end_(end), callback_(callback), stripe_(stripe) {
- if (spawn_) {
- // Mac OS stacks are only 512KiB. Make sure we have the same stack size on all platforms.
- pthread_attr_t attr;
- CHECK_PTHREAD_CALL(pthread_attr_init, (&attr), "new compiler worker thread");
- CHECK_PTHREAD_CALL(pthread_attr_setstacksize, (&attr, 1*MB), "new compiler worker thread");
- CHECK_PTHREAD_CALL(pthread_create, (&pthread_, &attr, &Go, this), "new compiler worker thread");
- CHECK_PTHREAD_CALL(pthread_attr_destroy, (&attr), "new compiler worker thread");
- }
- }
-
- ~WorkerThread() {
- if (spawn_) {
- CHECK_PTHREAD_CALL(pthread_join, (pthread_, NULL), "compiler worker shutdown");
- }
- }
-
- private:
- static void* Go(void* arg) LOCKS_EXCLUDED(Locks::mutator_lock_) {
- WorkerThread* worker = reinterpret_cast<WorkerThread*>(arg);
- Runtime* runtime = Runtime::Current();
- if (worker->spawn_) {
- CHECK(runtime->AttachCurrentThread("Compiler Worker", true, NULL));
- }
- worker->Run();
- if (worker->spawn_) {
- runtime->DetachCurrentThread();
- }
- return NULL;
- }
-
- void Go() LOCKS_EXCLUDED(Locks::mutator_lock_) {
- Go(this);
- }
-
- void Run() LOCKS_EXCLUDED(Locks::mutator_lock_) {
- Thread* self = Thread::Current();
- for (size_t i = begin_; i < end_; i += stripe_) {
- callback_(context_, i);
- self->AssertNoPendingException();
- }
- }
-
- pthread_t pthread_;
- // Was this thread spawned or is it the main thread?
- const bool spawn_;
-
- const CompilationContext* const context_;
- const size_t begin_;
- const size_t end_;
- Callback* callback_;
- const size_t stripe_;
-
- friend void ForAll(CompilationContext*, size_t, size_t, Callback, size_t);
-};
-
-static void ForAll(CompilationContext* context, size_t begin, size_t end, Callback callback,
- size_t thread_count)
- LOCKS_EXCLUDED(Locks::mutator_lock_) {
- Thread* self = Thread::Current();
- self->AssertNoPendingException();
- CHECK_GT(thread_count, 0U);
-
- std::vector<WorkerThread*> threads;
- for (size_t i = 0; i < thread_count; ++i) {
- threads.push_back(new WorkerThread(context, begin + i, end, callback, thread_count, (i != 0)));
- }
- threads[0]->Go();
-
- // Ensure we're suspended while we're blocked waiting for the other threads to finish (worker
- // thread destructor's called below perform join).
- CHECK_NE(self->GetState(), kRunnable);
- STLDeleteElements(&threads);
-}
-
// Return true if the class should be skipped during compilation. We
// never skip classes in the boot class loader. However, if we have a
// non-boot class loader and we can resolve the class in the boot
@@ -1216,11 +1190,11 @@
// TODO: we could resolve strings here, although the string table is largely filled with class
// and method names.
- CompilationContext context(class_linker, class_loader, this, &dex_file);
- ForAll(&context, 0, dex_file.NumTypeIds(), ResolveType, thread_count_);
+ CompilationContext context(class_linker, class_loader, this, &dex_file, thread_pool_.get());
+ context.ForAll(0, dex_file.NumTypeIds(), ResolveType, thread_count_);
timings.AddSplit("Resolve " + dex_file.GetLocation() + " Types");
- ForAll(&context, 0, dex_file.NumClassDefs(), ResolveClassFieldsAndMethods, thread_count_);
+ context.ForAll(0, dex_file.NumClassDefs(), ResolveClassFieldsAndMethods, thread_count_);
timings.AddSplit("Resolve " + dex_file.GetLocation() + " MethodsAndFields");
}
@@ -1281,8 +1255,8 @@
void Compiler::VerifyDexFile(jobject class_loader, const DexFile& dex_file, TimingLogger& timings) {
ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
- CompilationContext context(class_linker, class_loader, this, &dex_file);
- ForAll(&context, 0, dex_file.NumClassDefs(), VerifyClass, thread_count_);
+ CompilationContext context(class_linker, class_loader, this, &dex_file, thread_pool_.get());
+ context.ForAll(0, dex_file.NumClassDefs(), VerifyClass, thread_count_);
timings.AddSplit("Verify " + dex_file.GetLocation());
}
@@ -1326,8 +1300,8 @@
void Compiler::InitializeClassesWithoutClinit(jobject jni_class_loader, const DexFile& dex_file,
TimingLogger& timings) {
ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
- CompilationContext context(class_linker, jni_class_loader, this, &dex_file);
- ForAll(&context, 0, dex_file.NumClassDefs(), InitializeClassWithoutClinit, thread_count_);
+ CompilationContext context(class_linker, jni_class_loader, this, &dex_file, thread_pool_.get());
+ context.ForAll(0, dex_file.NumClassDefs(), InitializeClassWithoutClinit, thread_count_);
timings.AddSplit("InitializeNoClinit " + dex_file.GetLocation());
}
@@ -1416,8 +1390,8 @@
void Compiler::CompileDexFile(jobject class_loader, const DexFile& dex_file,
TimingLogger& timings) {
- CompilationContext context(NULL, class_loader, this, &dex_file);
- ForAll(&context, 0, dex_file.NumClassDefs(), Compiler::CompileClass, thread_count_);
+ CompilationContext context(NULL, class_loader, this, &dex_file, thread_pool_.get());
+ context.ForAll(0, dex_file.NumClassDefs(), Compiler::CompileClass, thread_count_);
timings.AddSplit("Compile " + dex_file.GetLocation());
}
diff --git a/src/compiler.h b/src/compiler.h
index 5e9dbd7..20e608d 100644
--- a/src/compiler.h
+++ b/src/compiler.h
@@ -32,6 +32,7 @@
#include "object.h"
#include "runtime.h"
#include "safe_map.h"
+#include "thread_pool.h"
namespace art {
@@ -365,6 +366,8 @@
const char* shorty, uint32_t shorty_len);
CreateInvokeStubFn create_invoke_stub_;
+ UniquePtr<ThreadPool> thread_pool_;
+
pthread_key_t tls_key_;
#if defined(ART_USE_LLVM_COMPILER)
diff --git a/src/gc/mark_sweep.cc b/src/gc/mark_sweep.cc
index 0869e26..e93eb1a 100644
--- a/src/gc/mark_sweep.cc
+++ b/src/gc/mark_sweep.cc
@@ -527,7 +527,7 @@
Thread* self;
};
-class CheckpointMarkThreadRoots : public Thread::CheckpointFunction {
+class CheckpointMarkThreadRoots : public Closure {
public:
CheckpointMarkThreadRoots(MarkSweep* mark_sweep) : mark_sweep_(mark_sweep) {
@@ -536,7 +536,8 @@
virtual void Run(Thread* thread) NO_THREAD_SAFETY_ANALYSIS {
// Note: self is not necessarily equal to thread since thread may be suspended.
Thread* self = Thread::Current();
- DCHECK(thread == self || thread->IsSuspended() || thread->GetState() == kWaitingPerformingGc);
+ DCHECK(thread == self || thread->IsSuspended() || thread->GetState() == kWaitingPerformingGc)
+ << thread->GetState();
WriterMutexLock mu(self, *Locks::heap_bitmap_lock_);
thread->VisitRoots(MarkSweep::MarkObjectVisitor, mark_sweep_);
mark_sweep_->GetBarrier().Pass(self);
diff --git a/src/image_test.cc b/src/image_test.cc
index afccb4a..e2abbac 100644
--- a/src/image_test.cc
+++ b/src/image_test.cc
@@ -72,6 +72,9 @@
ASSERT_GE(sizeof(image_header) + space->Size(), static_cast<size_t>(file->Length()));
}
+ // Need to delete the compiler since it has worker threads which are attached to runtime.
+ delete compiler_.release();
+
// tear down old runtime before making a new one, clearing out misc state
delete runtime_.release();
java_lang_dex_file_ = NULL;
diff --git a/src/thread.cc b/src/thread.cc
index 67773a5..0551844 100644
--- a/src/thread.cc
+++ b/src/thread.cc
@@ -555,7 +555,7 @@
}
}
-bool Thread::RequestCheckpoint(CheckpointFunction* function) {
+bool Thread::RequestCheckpoint(Closure* function) {
CHECK(!ReadFlag(kCheckpointRequest)) << "Already have a pending checkpoint request";
checkpoint_function_ = function;
union StateAndFlags old_state_and_flags = state_and_flags_;
diff --git a/src/thread.h b/src/thread.h
index abfd719..8dbfb55 100644
--- a/src/thread.h
+++ b/src/thread.h
@@ -25,6 +25,7 @@
#include <string>
#include <vector>
+#include "closure.h"
#include "globals.h"
#include "macros.h"
#include "oat/runtime/oat_support_entrypoints.h"
@@ -106,12 +107,6 @@
class PACKED Thread {
public:
- class CheckpointFunction {
- public:
- virtual ~CheckpointFunction() { }
- virtual void Run(Thread* self) = 0;
- };
-
// Space to throw a StackOverflowError in.
#if !defined(ART_USE_LLVM_COMPILER)
static const size_t kStackOverflowReservedBytes = 4 * KB;
@@ -183,7 +178,7 @@
void ModifySuspendCount(Thread* self, int delta, bool for_debugger)
EXCLUSIVE_LOCKS_REQUIRED(Locks::thread_suspend_count_lock_);
- bool RequestCheckpoint(CheckpointFunction* function);
+ bool RequestCheckpoint(Closure* function);
// Called when thread detected that the thread_suspend_count_ was non-zero. Gives up share of
// mutator_lock_ and waits until it is resumed and thread_suspend_count_ is zero.
@@ -776,7 +771,7 @@
const char* last_no_thread_suspension_cause_;
// Pending checkpoint functions.
- CheckpointFunction* checkpoint_function_;
+ Closure* checkpoint_function_;
public:
// Runtime support function pointers
diff --git a/src/thread_list.cc b/src/thread_list.cc
index 4ad25ae..4b2e17f 100644
--- a/src/thread_list.cc
+++ b/src/thread_list.cc
@@ -151,7 +151,7 @@
}
#endif
-size_t ThreadList::RunCheckpoint(Thread::CheckpointFunction* checkpoint_function) {
+size_t ThreadList::RunCheckpoint(Closure* checkpoint_function) {
Thread* self = Thread::Current();
if (kIsDebugBuild) {
Locks::mutator_lock_->AssertNotHeld(self);
diff --git a/src/thread_list.h b/src/thread_list.h
index a41fa57..d64183b 100644
--- a/src/thread_list.h
+++ b/src/thread_list.h
@@ -57,7 +57,7 @@
// Run a checkpoint on threads, running threads are not suspended but run the checkpoint inside
// of the suspend check. Returns how many checkpoints we should expect to run.
- size_t RunCheckpoint(Thread::CheckpointFunction* checkpoint_function);
+ size_t RunCheckpoint(Closure* checkpoint_function);
LOCKS_EXCLUDED(Locks::thread_list_lock_,
Locks::thread_suspend_count_lock_);
diff --git a/src/thread_pool.cc b/src/thread_pool.cc
new file mode 100644
index 0000000..fa0cf79
--- /dev/null
+++ b/src/thread_pool.cc
@@ -0,0 +1,124 @@
+#include "runtime.h"
+#include "stl_util.h"
+#include "thread.h"
+#include "thread_pool.h"
+
+namespace art {
+
+ThreadPoolWorker::ThreadPoolWorker(ThreadPool* thread_pool, const std::string& name,
+ size_t stack_size)
+ : thread_pool_(thread_pool),
+ name_(name),
+ stack_size_(stack_size) {
+ const char* reason = "new thread pool worker thread";
+ CHECK_PTHREAD_CALL(pthread_attr_init, (&attr), reason);
+ CHECK_PTHREAD_CALL(pthread_attr_setstacksize, (&attr, stack_size), reason);
+ CHECK_PTHREAD_CALL(pthread_create, (&pthread_, &attr, &Callback, this), reason);
+ CHECK_PTHREAD_CALL(pthread_attr_destroy, (&attr), reason);
+}
+
+ThreadPoolWorker::~ThreadPoolWorker() {
+ CHECK_PTHREAD_CALL(pthread_join, (pthread_, NULL), "thread pool worker shutdown");
+}
+
+void ThreadPoolWorker::Run() {
+ Thread* self = Thread::Current();
+ Closure* task = NULL;
+ while ((task = thread_pool_->GetTask(self)) != NULL) {
+ task->Run(self);
+ }
+}
+
+void* ThreadPoolWorker::Callback(void* arg) {
+ ThreadPoolWorker* worker = reinterpret_cast<ThreadPoolWorker*>(arg);
+ Runtime* runtime = Runtime::Current();
+ CHECK(runtime->AttachCurrentThread(worker->name_.c_str(), true, NULL));
+ // Do work until it's time to shut down.
+ worker->Run();
+ runtime->DetachCurrentThread();
+ return NULL;
+}
+
+void ThreadPool::AddTask(Thread* self, Closure* task){
+ MutexLock mu(self, task_queue_lock_);
+ tasks_.push_back(task);
+ // If we have any waiters, signal one.
+ if (waiting_count_ != 0) {
+ task_queue_condition_.Signal(self);
+ }
+}
+
+void ThreadPool::AddThread(size_t stack_size) {
+ threads_.push_back(
+ new ThreadPoolWorker(
+ this,
+ StringPrintf("Thread pool worker %d", static_cast<int>(GetThreadCount())),
+ stack_size));
+}
+
+ThreadPool::ThreadPool(size_t num_threads)
+ : task_queue_lock_("task queue lock"),
+ task_queue_condition_("task queue condition", task_queue_lock_),
+ completion_condition_("task completion condition", task_queue_lock_),
+ started_(false),
+ shutting_down_(false),
+ waiting_count_(0) {
+ while (GetThreadCount() < num_threads) {
+ AddThread(ThreadPoolWorker::kDefaultStackSize);
+ }
+}
+
+ThreadPool::~ThreadPool() {
+  {
+    // Hold the lock so a worker in GetTask() can't check shutting_down_ and then miss the wakeup.
+    MutexLock mu(Thread::Current(), task_queue_lock_);
+    shutting_down_ = true;
+    task_queue_condition_.Broadcast(Thread::Current());
+  }
+  STLDeleteElements(&threads_);  // Deleting each worker joins its thread.
+}
+
+void ThreadPool::StartWorkers(Thread* self) {
+ MutexLock mu(self, task_queue_lock_);
+ started_ = true;
+ android_memory_barrier();
+ task_queue_condition_.Broadcast(self);
+}
+
+void ThreadPool::StopWorkers(Thread* self) {
+ MutexLock mu(self, task_queue_lock_);
+ started_ = false;
+ android_memory_barrier();
+}
+
+Closure* ThreadPool::GetTask(Thread* self) {
+ MutexLock mu(self, task_queue_lock_);
+ while (!shutting_down_) {
+ if (started_ && !tasks_.empty()) {
+ Closure* task = tasks_.front();
+ tasks_.pop_front();
+ return task;
+ }
+
+ waiting_count_++;
+ if (waiting_count_ == GetThreadCount() && tasks_.empty()) {
+ // We may be done, let's broadcast to the completion condition.
+ completion_condition_.Broadcast(self);
+ }
+ task_queue_condition_.Wait(self);
+ waiting_count_--;
+ }
+
+ // We are shutting down, return NULL to tell the worker thread to stop looping.
+ return NULL;
+}
+
+void ThreadPool::Wait(Thread* self) {
+ MutexLock mu(self, task_queue_lock_);
+ // Wait until each thread is waiting and the task list is empty.
+ while (waiting_count_ != GetThreadCount() || !tasks_.empty()) {
+ completion_condition_.Wait(self);
+ }
+}
+
+} // namespace art
diff --git a/src/thread_pool.h b/src/thread_pool.h
new file mode 100644
index 0000000..22e30b7
--- /dev/null
+++ b/src/thread_pool.h
@@ -0,0 +1,103 @@
+/*
+ * Copyright (C) 2012 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ART_SRC_THREAD_POOL_H_
+#define ART_SRC_THREAD_POOL_H_
+
+#include <deque>
+#include <vector>
+
+#include "locks.h"
+#include "../src/mutex.h"
+
+namespace art {
+
+class Closure;
+class ThreadPool;
+
+class ThreadPoolWorker {
+ public:
+ static const size_t kDefaultStackSize = 1 * MB;
+
+ size_t GetStackSize() const {
+ return stack_size_;
+ }
+
+ virtual ~ThreadPoolWorker();
+
+ private:
+ ThreadPoolWorker(ThreadPool* thread_pool, const std::string& name, size_t stack_size);
+ static void* Callback(void* arg) LOCKS_EXCLUDED(Locks::mutator_lock_);
+ void Run();
+
+ ThreadPool* thread_pool_;
+ const std::string name_;
+ const size_t stack_size_;
+ pthread_t pthread_;
+ pthread_attr_t attr;
+
+ friend class ThreadPool;
+ DISALLOW_COPY_AND_ASSIGN(ThreadPoolWorker);
+};
+
+class ThreadPool {
+ public:
+ // Returns the number of threads in the thread pool.
+ size_t GetThreadCount() const {
+ return threads_.size();
+ }
+
+ // Broadcast to the workers and tell them to empty out the work queue.
+ void StartWorkers(Thread* self);
+
+ // Do not allow workers to grab any new tasks.
+ void StopWorkers(Thread* self);
+
+ // Add a new task; the first available started worker will process it. The pool does not delete
+ // the task after running it; that is the caller's responsibility.
+ void AddTask(Thread* self, Closure* task);
+
+ ThreadPool(size_t num_threads);
+ virtual ~ThreadPool();
+
+ // Wait for all tasks currently on queue to get completed.
+ void Wait(Thread* self);
+
+ private:
+ // Add a new worker thread with the given stack size.
+ void AddThread(size_t stack_size);
+
+ // Get a task to run, blocks if there are no tasks left
+ Closure* GetTask(Thread* self);
+
+ Mutex task_queue_lock_;
+ ConditionVariable task_queue_condition_ GUARDED_BY(task_queue_lock_);
+ ConditionVariable completion_condition_ GUARDED_BY(task_queue_lock_);
+ volatile bool started_ GUARDED_BY(task_queue_lock_);
+ volatile bool shutting_down_ GUARDED_BY(task_queue_lock_);
+ // How many worker threads are waiting on the condition.
+ volatile size_t waiting_count_ GUARDED_BY(task_queue_lock_);
+ std::deque<Closure*> tasks_ GUARDED_BY(task_queue_lock_);
+ // TODO: make this immutable/const?
+ std::vector<ThreadPoolWorker*> threads_;
+
+ friend class ThreadPoolWorker;
+ DISALLOW_COPY_AND_ASSIGN(ThreadPool);
+};
+
+} // namespace art
+
+#endif // ART_SRC_THREAD_POOL_H_
diff --git a/src/thread_pool_test.cc b/src/thread_pool_test.cc
new file mode 100644
index 0000000..783f786
--- /dev/null
+++ b/src/thread_pool_test.cc
@@ -0,0 +1,126 @@
+/*
+ * Copyright (C) 2012 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#include <string>
+
+#include "atomic_integer.h"
+#include "common_test.h"
+#include "thread_pool.h"
+
+namespace art {
+
+class CountClosure : public Closure {
+ public:
+ CountClosure(AtomicInteger* count) : count_(count) {
+
+ }
+
+ void Run(Thread* /* self */) {
+ // Simulate doing some work.
+ usleep(100);
+ // Increment the counter which keeps track of work completed.
+ ++*count_;
+ delete this;
+ }
+
+ private:
+ AtomicInteger* const count_;
+};
+
+class ThreadPoolTest : public CommonTest {
+ public:
+ static int32_t num_threads;
+};
+
+int32_t ThreadPoolTest::num_threads = 4;
+
+// Check that the thread pool actually runs tasks that you assign it.
+TEST_F(ThreadPoolTest, CheckRun) {
+ Thread* self = Thread::Current();
+ ThreadPool thread_pool(num_threads);
+ AtomicInteger count = 0;
+ static const int32_t num_tasks = num_threads * 4;
+ for (int32_t i = 0; i < num_tasks; ++i) {
+ thread_pool.AddTask(self, new CountClosure(&count));
+ }
+ thread_pool.StartWorkers(self);
+ // Wait for tasks to complete.
+ thread_pool.Wait(self);
+ // Make sure that we finished all the work.
+ EXPECT_EQ(num_tasks, count);
+}
+
+TEST_F(ThreadPoolTest, StopStart) {
+ Thread* self = Thread::Current();
+ ThreadPool thread_pool(num_threads);
+ AtomicInteger count = 0;
+ static const int32_t num_tasks = num_threads * 4;
+ for (int32_t i = 0; i < num_tasks; ++i) {
+ thread_pool.AddTask(self, new CountClosure(&count));
+ }
+ usleep(200);
+ // Check that no threads started prematurely.
+ EXPECT_EQ(0, count);
+ // Signal the threads to start processing tasks.
+ thread_pool.StartWorkers(self);
+ usleep(200);
+ thread_pool.StopWorkers(self);
+ AtomicInteger bad_count = 0;
+ thread_pool.AddTask(self, new CountClosure(&bad_count));
+ usleep(200);
+ // Ensure that the task added after the workers were stopped doesn't get run.
+ EXPECT_EQ(0, bad_count);
+}
+
+class TreeClosure : public Closure {
+ public:
+ TreeClosure(ThreadPool* const thread_pool, AtomicInteger* count, int depth)
+ : thread_pool_(thread_pool),
+ count_(count),
+ depth_(depth) {
+
+ }
+
+ void Run(Thread* self) {
+ if (depth_ > 1) {
+ thread_pool_->AddTask(self, new TreeClosure(thread_pool_, count_, depth_ - 1));
+ thread_pool_->AddTask(self, new TreeClosure(thread_pool_, count_, depth_ - 1));
+ }
+ // Increment the counter which keeps track of work completed.
+ ++*count_;
+ delete this;
+ }
+
+ private:
+ ThreadPool* const thread_pool_;
+ AtomicInteger* const count_;
+ const int depth_;
+};
+
+// Test that adding new tasks from within a task works.
+TEST_F(ThreadPoolTest, RecursiveTest) {
+ Thread* self = Thread::Current();
+ ThreadPool thread_pool(num_threads);
+ AtomicInteger count = 0;
+ static const int depth = 8;
+ thread_pool.AddTask(self, new TreeClosure(&thread_pool, &count, depth));
+ thread_pool.StartWorkers(self);
+ thread_pool.Wait(self);
+ EXPECT_EQ((1 << depth) - 1, count);
+}
+
+} // namespace art