Merge from Chromium at DEPS revision r167172
This commit was generated by merge_to_master.py.
Change-Id: Ib8d56fd5ae39a2d7e8c91dcd76cc6d13f25f2aab
diff --git a/base/win/OWNERS b/base/win/OWNERS
new file mode 100644
index 0000000..3aae3d6
--- /dev/null
+++ b/base/win/OWNERS
@@ -0,0 +1 @@
+cpu@chromium.org
diff --git a/base/win/dllmain.cc b/base/win/dllmain.cc
new file mode 100644
index 0000000..15437e0
--- /dev/null
+++ b/base/win/dllmain.cc
@@ -0,0 +1,122 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Windows doesn't support pthread_key_create's destr_function, and in fact
+// it's a bit tricky to get code to run when a thread exits. This is
+// cargo-cult magic from http://www.codeproject.com/threads/tls.asp.
+// We are trying to be compatible with both a LoadLibrary style invocation, as
+// well as static linking. This code only needs to be included if we use
+// LoadLibrary, but it hooks into the "standard" set of TLS callbacks that are
+// provided for static linking.
+
+// This code is deliberately written to match the style of calls seen in
+// base/threading/thread_local_storage_win.cc. Please keep the two in sync if
+// coding conventions are changed.
+
+// WARNING: Do *NOT* try to include this in the construction of the base
+// library, even though it potentially drives code in
+// base/threading/thread_local_storage_win.cc. If you do, some users will end
+// up getting duplicate definition of DllMain() in some of their later links.
+
+// Force a reference to _tls_used to make the linker create the TLS directory
+// if it's not already there (that is, even if __declspec(thread) is not used).
+// Force a reference to p_thread_callback_dllmain_typical_entry to prevent whole
+// program optimization from discarding the variable.
+
+#include <windows.h>
+
+#include "base/compiler_specific.h"
+#include "base/win/win_util.h"
+
+// Set to true the first time our placeholder callback (on_callback) runs,
+// which shows that something else (e.g. the loader) is already walking the
+// TLS callback list. From then on DllMain() stops dispatching the callbacks
+// itself, so that none of them is called twice.
+static bool linker_notifications_are_active = false;
+
+// Placeholder TLS callback listed in .CRT$XLB below. This file never calls it
+// directly (DllMain() skips it), so if it runs at all, someone else is walking
+// the callback list. We expect that to show up on a THREAD_ATTACH, long before
+// we would have to forward any THREAD_DETACH notification ourselves.
+// Declared here because the callback table below refers to it.
+static void NTAPI on_callback(PVOID h, DWORD reason, PVOID reserved);
+
+#ifdef _WIN64
+// Force-link the TLS directory and our entry; x86 C symbols add a leading '_'.
+#pragma comment(linker, "/INCLUDE:_tls_used")
+#pragma comment(linker, "/INCLUDE:p_thread_callback_dllmain_typical_entry")
+
+#else // _WIN64
+
+#pragma comment(linker, "/INCLUDE:__tls_used")
+#pragma comment(linker, "/INCLUDE:_p_thread_callback_dllmain_typical_entry")
+
+#endif // _WIN64
+
+// CRT-provided sentinels bracketing the .CRT$XL? list of TLS callbacks.
+extern "C" PIMAGE_TLS_CALLBACK __xl_a;
+extern "C" PIMAGE_TLS_CALLBACK __xl_z;
+
+// extern "C" suppresses C++ name mangling so we know the symbol names for the
+// linker /INCLUDE:symbol pragmas above.
+extern "C" {
+#ifdef _WIN64
+// .CRT is merged with .rdata on x64, so the entry must be const data. Save
+// and restore const_seg itself so later const data stays out of .CRT$XLB.
+#pragma const_seg(push, old_seg)
+// Use a typical possible name in the .CRT$XL? list of segments.
+#pragma const_seg(".CRT$XLB")
+// When defining a const variable, it must have external linkage to be sure the
+// linker doesn't discard it.
+extern const PIMAGE_TLS_CALLBACK p_thread_callback_dllmain_typical_entry;
+const PIMAGE_TLS_CALLBACK p_thread_callback_dllmain_typical_entry = on_callback;
+#pragma const_seg(pop, old_seg)
+
+#else // _WIN64
+
+#pragma data_seg(push, old_seg)
+// Use a typical possible name in the .CRT$XL? list of segments.
+#pragma data_seg(".CRT$XLB")
+PIMAGE_TLS_CALLBACK p_thread_callback_dllmain_typical_entry = on_callback;
+#pragma data_seg(pop, old_seg)
+
+#endif // _WIN64
+} // extern "C"
+
+NOINLINE static void CrashOnProcessDetach() {
+  *static_cast<volatile int*>(NULL) = 0x356;  // volatile: must not be elided.
+}
+
+// Runs every TLS callback in the .CRT$XL? list on thread/process detach,
+// unless on_callback() has shown that something else already dispatches them.
+BOOL WINAPI DllMain(PVOID h, DWORD reason, PVOID reserved) {
+  const bool is_process_detach = (reason == DLL_PROCESS_DETACH);
+  if (is_process_detach && base::win::ShouldCrashOnProcessDetach())
+    CrashOnProcessDetach();
+
+  // THREAD_ATTACH/PROCESS_ATTACH are not serviced here.
+  const bool is_detach = is_process_detach || reason == DLL_THREAD_DETACH;
+  if (!is_detach || linker_notifications_are_active)
+    return TRUE;
+  // Forward to every non-empty slot between __xl_a and __xl_z except ours.
+  for (PIMAGE_TLS_CALLBACK* slot = &__xl_a; slot < &__xl_z; ++slot) {
+    const PIMAGE_TLS_CALLBACK callback = *slot;
+    if (callback != NULL && callback != on_callback)
+      callback(h, reason, reserved);
+  }
+  return TRUE;
+}
+
+static void NTAPI on_callback(PVOID h, DWORD reason, PVOID reserved) {
+  // Intentionally does no real work: this is only a marker in the callback
+  // list. Being invoked at all shows that something else is already
+  // dispatching the TLS callbacks, so DllMain() must not repeat that work
+  // (which would call every callback twice).
+  linker_notifications_are_active = true;
+  // Caveat: if another module played the same trick, each could defer to the
+  // other and neither would dispatch the callbacks. In practice this
+  // suppresses duplicate calls on Vista, where the runtime runs the callbacks
+  // itself, and lets DllMain() dispatch them on XP, where a DLL loaded via an
+  // explicit LoadLibrary call gets no TLS callbacks.
+}
diff --git a/base/win/enum_variant.cc b/base/win/enum_variant.cc
new file mode 100644
index 0000000..2975560
--- /dev/null
+++ b/base/win/enum_variant.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/enum_variant.h"
+
+#include <algorithm>
+
+#include "base/logging.h"
+
+namespace base {
+namespace win {
+
+EnumVariant::EnumVariant(unsigned long count)
+    : items_(new VARIANT[count]()),  // "()" zero-fills: items start VT_EMPTY.
+      count_(count),
+      current_index_(0) {
+}
+
+// |items_| frees the array; the VARIANTs themselves are not VariantClear'ed.
+EnumVariant::~EnumVariant() {}
+
+VARIANT* EnumVariant::ItemAt(unsigned long index) {
+  DCHECK(index < count_);
+  return items_.get() + index;
+}
+
+// Reference counting is forwarded to the IUnknownImpl base.
+ULONG STDMETHODCALLTYPE EnumVariant::AddRef() {
+  return IUnknownImpl::AddRef();
+}
+ULONG STDMETHODCALLTYPE EnumVariant::Release() {
+  return IUnknownImpl::Release();
+}
+
+STDMETHODIMP EnumVariant::QueryInterface(REFIID riid, void** ppv) {
+  // Everything except IEnumVARIANT (e.g. IUnknown) is answered by the
+  // IUnknownImpl base.
+  if (riid != IID_IEnumVARIANT)
+    return IUnknownImpl::QueryInterface(riid, ppv);
+  *ppv = static_cast<IEnumVARIANT*>(this);
+  AddRef();
+  return S_OK;
+}
+
+STDMETHODIMP EnumVariant::Next(ULONG requested_count,
+                               VARIANT* out_elements,
+                               ULONG* out_elements_received) {
+  unsigned long count = std::min(requested_count, count_ - current_index_);
+  for (unsigned long i = 0; i < count; ++i)
+    out_elements[i] = items_[current_index_ + i];  // Shallow copy.
+  current_index_ += count;
+  if (out_elements_received)  // May be NULL per the IEnumVARIANT contract.
+    *out_elements_received = count;
+  return (count == requested_count ? S_OK : S_FALSE);
+}
+
+STDMETHODIMP EnumVariant::Skip(ULONG skip_count) {
+  // Clamp to the number of items left. Comparing against the remaining count
+  // (rather than testing current_index_ + skip_count > count_) avoids unsigned
+  // wrap-around when |skip_count| is huge.
+  unsigned long count = std::min(skip_count, count_ - current_index_);
+  current_index_ += count;
+  return (count == skip_count ? S_OK : S_FALSE);
+}
+
+STDMETHODIMP EnumVariant::Reset() {
+  current_index_ = 0;  // Rewind to the first item; the items are untouched.
+  return S_OK;
+}
+
+STDMETHODIMP EnumVariant::Clone(IEnumVARIANT** out_cloned_object) {
+  // Bitwise-copy every item, then advance the clone to our current position.
+  EnumVariant* clone = new EnumVariant(count_);
+  std::copy(items_.get(), items_.get() + count_, clone->items_.get());
+  clone->Skip(current_index_);
+  clone->AddRef();
+  *out_cloned_object = clone;
+  return S_OK;
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/enum_variant.h b/base/win/enum_variant.h
new file mode 100644
index 0000000..7cee91d
--- /dev/null
+++ b/base/win/enum_variant.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_ENUM_VARIANT_H_
+#define BASE_WIN_ENUM_VARIANT_H_
+
+#include <unknwn.h>
+
+#include "base/memory/scoped_ptr.h"
+#include "base/win/iunknown_impl.h"
+
+namespace base {
+namespace win {
+
+// A simple implementation of IEnumVARIANT.
+class BASE_EXPORT EnumVariant
+    : public IEnumVARIANT,
+      public IUnknownImpl {
+ public:
+  // Creates an enumerator over |count| items. Fill in each item through
+  // ItemAt() before handing the enumerator out.
+  explicit EnumVariant(unsigned long count);
+
+  // Returns a mutable pointer to the item at |index|; |index| must be < count.
+  VARIANT* ItemAt(unsigned long index);
+
+  // IUnknown.
+  ULONG STDMETHODCALLTYPE AddRef() OVERRIDE;
+  ULONG STDMETHODCALLTYPE Release() OVERRIDE;
+  STDMETHODIMP QueryInterface(REFIID riid, void** ppv) OVERRIDE;
+
+  // IEnumVARIANT.
+  STDMETHODIMP Next(ULONG requested_count,
+                    VARIANT* out_elements,
+                    ULONG* out_elements_received) OVERRIDE;
+  STDMETHODIMP Skip(ULONG skip_count) OVERRIDE;
+  STDMETHODIMP Reset() OVERRIDE;
+  STDMETHODIMP Clone(IEnumVARIANT** out_cloned_object) OVERRIDE;
+
+ private:
+  ~EnumVariant();  // Lifetime is managed through AddRef()/Release().
+
+  scoped_array<VARIANT> items_;
+  unsigned long count_;
+  unsigned long current_index_;  // Position of the next item Next() returns.
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_ENUM_VARIANT_H_
diff --git a/base/win/enum_variant_unittest.cc b/base/win/enum_variant_unittest.cc
new file mode 100644
index 0000000..99645a2
--- /dev/null
+++ b/base/win/enum_variant_unittest.cc
@@ -0,0 +1,118 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/enum_variant.h"
+
+#include "base/win/scoped_com_initializer.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace base {
+namespace win {
+
+TEST(EnumVariantTest, EmptyEnumVariant) {
+  ScopedCOMInitializer com_initializer;
+
+  EnumVariant* ev = new EnumVariant(0);
+  ev->AddRef();
+
+  IUnknown* iunknown = NULL;
+  ASSERT_TRUE(SUCCEEDED(
+      ev->QueryInterface(IID_IUnknown, reinterpret_cast<void**>(&iunknown))));
+  iunknown->Release();
+
+  IEnumVARIANT* ienumvariant = NULL;
+  ASSERT_TRUE(SUCCEEDED(
+      ev->QueryInterface(IID_IEnumVARIANT,
+                         reinterpret_cast<void**>(&ienumvariant))));
+  EXPECT_EQ(ev, ienumvariant);
+  ienumvariant->Release();
+
+  VARIANT out_element;
+  ULONG out_received = 0;
+  EXPECT_EQ(S_FALSE, ev->Next(1, &out_element, &out_received));
+  EXPECT_EQ(0u, out_received);
+
+  EXPECT_EQ(S_FALSE, ev->Skip(1));
+
+  EXPECT_EQ(S_OK, ev->Reset());
+
+  IEnumVARIANT* ev2 = NULL;
+  ASSERT_EQ(S_OK, ev->Clone(&ev2));
+
+  ASSERT_NE(static_cast<IEnumVARIANT*>(NULL), ev2);
+  EXPECT_NE(ev, ev2);
+  EXPECT_EQ(S_FALSE, ev2->Skip(1));
+  EXPECT_EQ(S_OK, ev2->Reset());
+
+  ULONG ev2_finalrefcount = ev2->Release();
+  EXPECT_EQ(0u, ev2_finalrefcount);
+
+  ULONG ev_finalrefcount = ev->Release();
+  EXPECT_EQ(0u, ev_finalrefcount);
+}
+
+TEST(EnumVariantTest, SimpleEnumVariant) {
+  ScopedCOMInitializer com_initializer;
+
+  EnumVariant* ev = new EnumVariant(3);
+  ev->AddRef();
+  ev->ItemAt(0)->vt = VT_I4;
+  ev->ItemAt(0)->lVal = 10;
+  ev->ItemAt(1)->vt = VT_I4;
+  ev->ItemAt(1)->lVal = 20;
+  ev->ItemAt(2)->vt = VT_I4;
+  ev->ItemAt(2)->lVal = 30;
+
+  // Get elements one at a time.
+  VARIANT out_element;
+  ULONG out_received = 0;
+  EXPECT_EQ(S_OK, ev->Next(1, &out_element, &out_received));
+  EXPECT_EQ(1u, out_received);
+  EXPECT_EQ(VT_I4, out_element.vt);
+  EXPECT_EQ(10, out_element.lVal);
+  EXPECT_EQ(S_OK, ev->Skip(1));
+  EXPECT_EQ(S_OK, ev->Next(1, &out_element, &out_received));
+  EXPECT_EQ(1u, out_received);
+  EXPECT_EQ(VT_I4, out_element.vt);
+  EXPECT_EQ(30, out_element.lVal);
+  EXPECT_EQ(S_FALSE, ev->Next(1, &out_element, &out_received));
+
+  // Reset and get all elements at once.
+  VARIANT out_elements[3];
+  EXPECT_EQ(S_OK, ev->Reset());
+  EXPECT_EQ(S_OK, ev->Next(3, out_elements, &out_received));
+  EXPECT_EQ(3u, out_received);
+  EXPECT_EQ(VT_I4, out_elements[0].vt);
+  EXPECT_EQ(10, out_elements[0].lVal);
+  EXPECT_EQ(VT_I4, out_elements[1].vt);
+  EXPECT_EQ(20, out_elements[1].lVal);
+  EXPECT_EQ(VT_I4, out_elements[2].vt);
+  EXPECT_EQ(30, out_elements[2].lVal);
+  EXPECT_EQ(S_FALSE, ev->Next(1, &out_element, &out_received));
+
+  // Clone it; the clone must be usable independently of |ev|.
+  IEnumVARIANT* ev2 = NULL;
+  ASSERT_EQ(S_OK, ev->Clone(&ev2));
+  ASSERT_TRUE(ev2 != NULL);
+  EXPECT_EQ(S_FALSE, ev->Next(1, &out_element, &out_received));
+  EXPECT_EQ(S_OK, ev2->Reset());
+  EXPECT_EQ(S_OK, ev2->Next(3, out_elements, &out_received));
+  EXPECT_EQ(3u, out_received);
+  EXPECT_EQ(VT_I4, out_elements[0].vt);
+  EXPECT_EQ(10, out_elements[0].lVal);
+  EXPECT_EQ(VT_I4, out_elements[1].vt);
+  EXPECT_EQ(20, out_elements[1].lVal);
+  EXPECT_EQ(VT_I4, out_elements[2].vt);
+  EXPECT_EQ(30, out_elements[2].lVal);
+  EXPECT_EQ(S_FALSE, ev2->Next(1, &out_element, &out_received));
+
+  ULONG ev2_finalrefcount = ev2->Release();
+  EXPECT_EQ(0u, ev2_finalrefcount);
+
+  ULONG ev_finalrefcount = ev->Release();
+  EXPECT_EQ(0u, ev_finalrefcount);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/event_trace_consumer.h b/base/win/event_trace_consumer.h
new file mode 100644
index 0000000..c1b42b4
--- /dev/null
+++ b/base/win/event_trace_consumer.h
@@ -0,0 +1,148 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Declaration of a Windows event trace consumer base class.
+#ifndef BASE_WIN_EVENT_TRACE_CONSUMER_H_
+#define BASE_WIN_EVENT_TRACE_CONSUMER_H_
+
+#include <windows.h>
+#include <wmistr.h>
+#include <evntrace.h>
+#include <vector>
+#include "base/basictypes.h"
+
+namespace base {
+namespace win {
+
+// Base class that makes it easier to consume events from realtime or file
+// sessions. A concrete consumer derives from a specialization on itself
+// (CRTP) and hides the static ProcessEvent and/or ProcessBuffer functions
+// to implement its event consumption logic.
+// Usage might look like:
+// class MyConsumer: public EtwTraceConsumerBase<MyConsumer> {
+//  public:
+//   static void ProcessEvent(EVENT_TRACE* event);
+// };
+//
+// MyConsumer consumer;
+// consumer.OpenFileSession(file_path);
+// consumer.Consume();
+template <class ImplClass>
+class EtwTraceConsumerBase {
+ public:
+  // Constructs a closed consumer.
+  EtwTraceConsumerBase() {
+  }
+
+  ~EtwTraceConsumerBase() {
+    Close();
+  }
+
+  // Opens the named realtime session. Opening succeeds even if no such
+  // session is running yet, but Consume() will then fail.
+  // Note: You can use OpenRealtimeSession or OpenFileSession
+  // to open as many as MAXIMUM_WAIT_OBJECTS (64) sessions at any one
+  // time, though only one of them may be a realtime session.
+  HRESULT OpenRealtimeSession(const wchar_t* session_name);
+
+  // Opens the event trace log in "file_name", which must be a full or
+  // relative path to an existing event trace log file.
+  // Note: You can use OpenRealtimeSession or OpenFileSession
+  // to open as many as MAXIMUM_WAIT_OBJECTS sessions at any one time.
+  HRESULT OpenFileSession(const wchar_t* file_name);
+
+  // Consumes all open sessions from beginning to end.
+  HRESULT Consume();
+
+  // Closes all open sessions.
+  HRESULT Close();
+
+ protected:
+  // Default no-op; hide in ImplClass to handle events.
+  static void ProcessEvent(EVENT_TRACE* event) {
+  }
+  // Default keeps processing; hide in ImplClass to handle buffers.
+  static bool ProcessBuffer(EVENT_TRACE_LOGFILE* buffer) {
+    return true;  // keep going
+  }
+
+ protected:
+  // Currently open sessions.
+  std::vector<TRACEHANDLE> trace_handles_;
+
+ private:
+  // Static trampolines handed to OpenTrace(); they forward to ImplClass.
+  static void WINAPI ProcessEventCallback(EVENT_TRACE* event) {
+    ImplClass::ProcessEvent(event);
+  }
+  static ULONG WINAPI ProcessBufferCallback(PEVENT_TRACE_LOGFILE buffer) {
+    return ImplClass::ProcessBuffer(buffer);
+  }
+
+  DISALLOW_COPY_AND_ASSIGN(EtwTraceConsumerBase);
+};
+
+template <class ImplClass> inline
+HRESULT EtwTraceConsumerBase<ImplClass>::OpenRealtimeSession(
+    const wchar_t* session_name) {
+  EVENT_TRACE_LOGFILE logfile = {};
+  logfile.Context = this;
+  logfile.EventCallback = &ProcessEventCallback;
+  logfile.BufferCallback = &ProcessBufferCallback;
+  logfile.LogFileMode = EVENT_TRACE_REAL_TIME_MODE;
+  logfile.LoggerName = const_cast<wchar_t*>(session_name);
+  const TRACEHANDLE handle = ::OpenTrace(&logfile);
+  if (handle != reinterpret_cast<TRACEHANDLE>(INVALID_HANDLE_VALUE)) {
+    trace_handles_.push_back(handle);
+    return S_OK;
+  }
+  return HRESULT_FROM_WIN32(::GetLastError());
+}
+
+template <class ImplClass> inline
+HRESULT EtwTraceConsumerBase<ImplClass>::OpenFileSession(
+    const wchar_t* file_name) {
+  EVENT_TRACE_LOGFILE logfile = {};
+  logfile.Context = this;
+  logfile.EventCallback = &ProcessEventCallback;
+  logfile.BufferCallback = &ProcessBufferCallback;
+  logfile.LogFileName = const_cast<wchar_t*>(file_name);
+  const TRACEHANDLE handle = ::OpenTrace(&logfile);
+  if (handle != reinterpret_cast<TRACEHANDLE>(INVALID_HANDLE_VALUE)) {
+    trace_handles_.push_back(handle);
+    return S_OK;
+  }
+  return HRESULT_FROM_WIN32(::GetLastError());
+}
+
+template <class ImplClass> inline
+HRESULT EtwTraceConsumerBase<ImplClass>::Consume() {
+  if (trace_handles_.empty())  // Nothing open; &trace_handles_[0] would be UB.
+    return HRESULT_FROM_WIN32(ERROR_INVALID_PARAMETER);
+  const ULONG count = static_cast<ULONG>(trace_handles_.size());
+  ULONG err = ::ProcessTrace(&trace_handles_[0], count, NULL, NULL);
+  return HRESULT_FROM_WIN32(err);
+}
+
+template <class ImplClass> inline
+HRESULT EtwTraceConsumerBase<ImplClass>::Close() {
+  // Closes every handle even if some fail, and reports the error from the
+  // last one that failed (S_OK if none did). The handle list ends up empty.
+  HRESULT result = S_OK;
+  for (size_t index = 0; index != trace_handles_.size(); ++index) {
+    if (trace_handles_[index] == NULL)
+      continue;
+    const ULONG error = ::CloseTrace(trace_handles_[index]);
+    trace_handles_[index] = NULL;
+    if (error != ERROR_SUCCESS)
+      result = HRESULT_FROM_WIN32(error);
+  }
+  trace_handles_.clear();
+  return result;
+}
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_EVENT_TRACE_CONSUMER_H_
diff --git a/base/win/event_trace_consumer_unittest.cc b/base/win/event_trace_consumer_unittest.cc
new file mode 100644
index 0000000..327e1c7
--- /dev/null
+++ b/base/win/event_trace_consumer_unittest.cc
@@ -0,0 +1,383 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Unit tests for event trace consumer base class.
+#include "base/win/event_trace_consumer.h"
+
+#include <list>
+
+#include <objbase.h>
+
+#include "base/basictypes.h"
+#include "base/file_path.h"
+#include "base/file_util.h"
+#include "base/logging.h"
+#include "base/process.h"
+#include "base/scoped_temp_dir.h"
+#include "base/stringprintf.h"
+#include "base/win/event_trace_controller.h"
+#include "base/win/event_trace_provider.h"
+#include "base/win/scoped_handle.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+#include <initguid.h> // NOLINT - has to be last
+
+namespace {
+
+using base::win::EtwMofEvent;
+using base::win::EtwTraceController;
+using base::win::EtwTraceConsumerBase;
+using base::win::EtwTraceProperties;
+using base::win::EtwTraceProvider;
+
+typedef std::list<EVENT_TRACE> EventQueue;
+// Records consumed events in a static queue owning copies of their MOF data.
+class TestConsumer: public EtwTraceConsumerBase<TestConsumer> {
+ public:
+  TestConsumer() {
+    sank_event_.Set(::CreateEvent(NULL, TRUE, FALSE, NULL));
+    ClearQueue();
+  }
+
+  ~TestConsumer() {
+    ClearQueue();
+    sank_event_.Close();
+  }
+
+  void ClearQueue() {
+    EventQueue::const_iterator it(events_.begin()), end(events_.end());
+    // MofData is either NULL or a char[] allocated by EnqueueEvent().
+    for (; it != end; ++it) {
+      delete [] static_cast<char*>(it->MofData);
+    }
+
+    events_.clear();
+  }
+
+  static void EnqueueEvent(EVENT_TRACE* event) {
+    events_.push_back(*event);
+    EVENT_TRACE& back = events_.back();
+    back.MofData = NULL;  // Never keep (and later delete) ETW's own buffer.
+    if (NULL != event->MofData && 0 != event->MofLength) {
+      back.MofData = new char[event->MofLength];
+      memcpy(back.MofData, event->MofData, event->MofLength);
+    }
+  }
+
+  static void ProcessEvent(EVENT_TRACE* event) {
+    EnqueueEvent(event);
+    ::SetEvent(sank_event_.Get());
+  }
+
+  static base::win::ScopedHandle sank_event_;
+  static EventQueue events_;
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(TestConsumer);
+};
+
+base::win::ScopedHandle TestConsumer::sank_event_;
+EventQueue TestConsumer::events_;
+
+class EtwTraceConsumerBaseTest: public testing::Test {
+ public:
+  EtwTraceConsumerBaseTest()
+      : session_name_(base::StringPrintf(L"TestSession-%d",
+                                         base::Process::Current().pid())) {
+  }
+
+  virtual void SetUp() {
+    // Start from a clean slate, then give each test its own provider GUID.
+    StopDanglingSession();
+    ASSERT_HRESULT_SUCCEEDED(::CoCreateGuid(&test_provider_));
+  }
+
+  virtual void TearDown() {
+    StopDanglingSession();
+  }
+
+ protected:
+  // Best-effort stop of any session left running under |session_name_|.
+  void StopDanglingSession() {
+    EtwTraceProperties ignore;
+    EtwTraceController::Stop(session_name_.c_str(), &ignore);
+  }
+  GUID test_provider_;
+  std::wstring session_name_;
+};
+
+} // namespace
+
+TEST_F(EtwTraceConsumerBaseTest, Initialize) {
+  TestConsumer consumer;  // Construction and destruction alone must work.
+}
+
+TEST_F(EtwTraceConsumerBaseTest, OpenRealtimeSucceedsWhenNoSession) {
+  // Opening by name succeeds even when no such session is running.
+  TestConsumer consumer;
+  ASSERT_HRESULT_SUCCEEDED(consumer.OpenRealtimeSession(session_name_.c_str()));
+}
+
+TEST_F(EtwTraceConsumerBaseTest, ConsumerImmediateFailureWhenNoSession) {
+  // With no session running, consuming an opened realtime session fails.
+  TestConsumer consumer;
+
+  ASSERT_HRESULT_SUCCEEDED(
+      consumer.OpenRealtimeSession(session_name_.c_str()));
+  ASSERT_HRESULT_FAILED(consumer.Consume());
+}
+
+namespace {
+
+class EtwTraceConsumerRealtimeTest: public EtwTraceConsumerBaseTest {
+ public:
+  virtual void SetUp() {
+    EtwTraceConsumerBaseTest::SetUp();
+
+    ASSERT_HRESULT_SUCCEEDED(
+        consumer_.OpenRealtimeSession(session_name_.c_str()));
+  }
+
+  virtual void TearDown() {
+    consumer_.Close();
+
+    EtwTraceConsumerBaseTest::TearDown();
+  }
+
+  // Consumer thread body: signals |consumer_ready_|, then consumes until the
+  // session ends. Consume()'s HRESULT becomes the thread's exit code.
+  DWORD ConsumerThread() {
+    ::SetEvent(consumer_ready_.Get());
+    return consumer_.Consume();
+  }
+
+  // CreateThread() entry point; |arg| is the fixture that started the thread.
+  static DWORD WINAPI ConsumerThreadMainProc(void* arg) {
+    return static_cast<EtwTraceConsumerRealtimeTest*>(arg)->ConsumerThread();
+  }
+
+  // Starts the consumer thread and waits until it is about to consume. If the
+  // thread exits before signalling, its exit code (Consume()'s HRESULT) is
+  // returned; failing to create the event or the thread is reported too.
+  HRESULT StartConsumerThread() {
+    consumer_ready_.Set(::CreateEvent(NULL, TRUE, FALSE, NULL));
+    if (NULL == consumer_ready_.Get())
+      return HRESULT_FROM_WIN32(::GetLastError());
+    consumer_thread_.Set(::CreateThread(NULL, 0, ConsumerThreadMainProc,
+                                        this, 0, NULL));
+    if (NULL == consumer_thread_.Get())
+      return HRESULT_FROM_WIN32(::GetLastError());
+
+    // Wake on whichever happens first: readiness or thread exit.
+    HANDLE events[] = { consumer_ready_, consumer_thread_ };
+    DWORD result = ::WaitForMultipleObjects(arraysize(events), events,
+                                            FALSE, INFINITE);
+    switch (result) {
+      case WAIT_OBJECT_0:
+        // The event was set, the consumer_ is ready.
+        return S_OK;
+      case WAIT_OBJECT_0 + 1: {
+        // The thread finished. This may race with the event, so check
+        // explicitly for the event here, before concluding there's trouble.
+        if (WAIT_OBJECT_0 == ::WaitForSingleObject(consumer_ready_, 0))
+          return S_OK;
+        DWORD exit_code = 0;
+        if (::GetExitCodeThread(consumer_thread_, &exit_code))
+          return static_cast<HRESULT>(exit_code);
+        return HRESULT_FROM_WIN32(::GetLastError());
+      }
+      default:
+        // WAIT_FAILED or any other unexpected wait result.
+        return E_UNEXPECTED;
+    }
+  }
+
+  // Waits for consumer_ thread to exit, and returns its exit code.
+  HRESULT JoinConsumerThread() {
+    if (WAIT_OBJECT_0 != ::WaitForSingleObject(consumer_thread_, INFINITE))
+      return HRESULT_FROM_WIN32(::GetLastError());
+
+    DWORD exit_code = 0;
+    if (::GetExitCodeThread(consumer_thread_, &exit_code))
+      return static_cast<HRESULT>(exit_code);
+
+    return HRESULT_FROM_WIN32(::GetLastError());
+  }
+
+  TestConsumer consumer_;
+  base::win::ScopedHandle consumer_ready_;
+  base::win::ScopedHandle consumer_thread_;
+};
+
+} // namespace
+
+TEST_F(EtwTraceConsumerRealtimeTest, ConsumerReturnsWhenSessionClosed) {
+  EtwTraceController controller;
+  HRESULT hr = controller.StartRealtimeSession(session_name_.c_str(),
+                                               100 * 1024);
+  if (hr == E_ACCESSDENIED) {
+    VLOG(1) << "You must be an administrator to run this test on Vista";
+    return;
+  }
+  ASSERT_HRESULT_SUCCEEDED(hr);  // Any other failure is a real error.
+
+  // Start the consumer_.
+  ASSERT_HRESULT_SUCCEEDED(StartConsumerThread());
+
+  // Wait around for the consumer_ thread a bit.
+  ASSERT_EQ(WAIT_TIMEOUT, ::WaitForSingleObject(consumer_thread_, 50));
+
+  ASSERT_HRESULT_SUCCEEDED(controller.Stop(NULL));
+
+  // The consumer_ returns success on session stop.
+  ASSERT_HRESULT_SUCCEEDED(JoinConsumerThread());
+}
+
+namespace {
+
+// Event type for the test events: {57E47923-A549-476f-86CA-503D57F59E62}.
+DEFINE_GUID(kTestEventType,
+    0x57e47923, 0xa549, 0x476f, 0x86, 0xca, 0x50, 0x3d, 0x57, 0xf5, 0x9e, 0x62);
+
+} // namespace
+
+TEST_F(EtwTraceConsumerRealtimeTest, ConsumeEvent) {
+  EtwTraceController controller;
+  HRESULT hr = controller.StartRealtimeSession(session_name_.c_str(),
+                                               100 * 1024);
+  if (hr == E_ACCESSDENIED) {
+    VLOG(1) << "You must be an administrator to run this test on Vista";
+    return;
+  }
+  ASSERT_HRESULT_SUCCEEDED(hr);  // Any other failure is a real error.
+  ASSERT_HRESULT_SUCCEEDED(controller.EnableProvider(test_provider_,
+      TRACE_LEVEL_VERBOSE, 0xFFFFFFFF));
+
+  EtwTraceProvider provider(test_provider_);
+  ASSERT_EQ(ERROR_SUCCESS, provider.Register());
+
+  // Start the consumer_.
+  ASSERT_HRESULT_SUCCEEDED(StartConsumerThread());
+
+  ASSERT_EQ(0u, TestConsumer::events_.size());
+
+  EtwMofEvent<1> event(kTestEventType, 1, TRACE_LEVEL_ERROR);
+  EXPECT_EQ(ERROR_SUCCESS, provider.Log(&event.header));
+
+  EXPECT_EQ(WAIT_OBJECT_0, ::WaitForSingleObject(TestConsumer::sank_event_,
+                                                 INFINITE));
+  ASSERT_HRESULT_SUCCEEDED(controller.Stop(NULL));
+  ASSERT_HRESULT_SUCCEEDED(JoinConsumerThread());
+  ASSERT_NE(0u, TestConsumer::events_.size());
+}
+
+namespace {
+
+// We run events through a file session to assert that
+// the content comes through.
+class EtwTraceConsumerDataTest: public EtwTraceConsumerBaseTest {
+ public:
+  EtwTraceConsumerDataTest() {
+  }
+
+  virtual void SetUp() {
+    EtwTraceConsumerBaseTest::SetUp();
+
+    // Make sure no earlier run left our session behind.
+    EtwTraceProperties ignored_properties;
+    EtwTraceController::Stop(session_name_.c_str(), &ignored_properties);
+
+    // Every test writes its trace to "test.etl" inside a fresh temp dir.
+    ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
+    temp_file_ = temp_dir_.path().Append(L"test.etl");
+  }
+
+  virtual void TearDown() {
+    EXPECT_TRUE(file_util::Delete(temp_file_, false));
+    EtwTraceConsumerBaseTest::TearDown();
+  }
+
+  // Logs |header| through a short-lived file session writing to temp_file_.
+  // Only a failure to start the session is returned; the remaining steps are
+  // checked with non-fatal expectations.
+  HRESULT LogEventToTempSession(PEVENT_TRACE_HEADER header) {
+    EtwTraceController controller;
+    const HRESULT start_hr = controller.StartFileSession(
+        session_name_.c_str(), temp_file_.value().c_str());
+    if (FAILED(start_hr))
+      return start_hr;
+
+    EXPECT_HRESULT_SUCCEEDED(controller.EnableProvider(test_provider_,
+        TRACE_LEVEL_VERBOSE, 0xFFFFFFFF));
+
+    // Registering after enabling hands the provider a session immediately.
+    EtwTraceProvider provider(test_provider_);
+    EXPECT_EQ(ERROR_SUCCESS, provider.Register());
+    EXPECT_EQ(ERROR_SUCCESS, provider.Log(header));
+
+    // Then disable and unregister, flush the file and stop the session.
+    EXPECT_HRESULT_SUCCEEDED(controller.DisableProvider(test_provider_));
+    EXPECT_HRESULT_SUCCEEDED(provider.Unregister());
+    EXPECT_HRESULT_SUCCEEDED(controller.Flush(NULL));
+    EXPECT_HRESULT_SUCCEEDED(controller.Stop(NULL));
+    return S_OK;
+  }
+
+  // Reads temp_file_ back and moves the captured events into events_.
+  HRESULT ConsumeEventFromTempSession() {
+    TestConsumer consumer;
+    HRESULT hr = consumer.OpenFileSession(temp_file_.value().c_str());
+    if (SUCCEEDED(hr))
+      hr = consumer.Consume();
+    consumer.Close();
+    // Take the events out of the static queue before |consumer| goes away.
+    events_.swap(TestConsumer::events_);
+    return hr;
+  }
+
+  // Logs |header|, reads it back, and points |*trace| at the last event read.
+  // Returns E_FAIL if nothing came back.
+  HRESULT RoundTripEvent(PEVENT_TRACE_HEADER header, PEVENT_TRACE* trace) {
+    file_util::Delete(temp_file_, false);
+
+    HRESULT hr = LogEventToTempSession(header);
+    if (FAILED(hr))
+      return hr;
+    hr = ConsumeEventFromTempSession();
+    if (FAILED(hr))
+      return hr;
+
+    if (events_.empty())
+      return E_FAIL;
+    *trace = &events_.back();
+    return S_OK;
+  }
+
+  // Events captured by the last ConsumeEventFromTempSession() call.
+  EventQueue events_;
+  ScopedTempDir temp_dir_;
+  FilePath temp_file_;
+};
+
+} // namespace
+
+
+TEST_F(EtwTraceConsumerDataTest, RoundTrip) {
+  // A single MOF field carrying a NUL-terminated string.
+  static const char kData[] = "This is but test data";
+  EtwMofEvent<1> event(kTestEventType, 1, TRACE_LEVEL_ERROR);
+  event.fields[0].DataPtr = reinterpret_cast<ULONG64>(kData);
+  event.fields[0].Length = sizeof(kData);
+
+  PEVENT_TRACE trace = NULL;
+  const HRESULT hr = RoundTripEvent(&event.header, &trace);
+  if (hr == E_ACCESSDENIED) {
+    VLOG(1) << "You must be an administrator to run this test on Vista";
+    return;
+  }
+  ASSERT_HRESULT_SUCCEEDED(hr) << "RoundTripEvent failed";
+  ASSERT_TRUE(trace != NULL);
+  ASSERT_EQ(sizeof(kData), trace->MofLength);
+  ASSERT_STREQ(kData, static_cast<const char*>(trace->MofData));
+}
diff --git a/base/win/event_trace_controller.cc b/base/win/event_trace_controller.cc
new file mode 100644
index 0000000..0391fbc
--- /dev/null
+++ b/base/win/event_trace_controller.cc
@@ -0,0 +1,173 @@
+// Copyright (c) 2009 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Implementation of a Windows event trace controller class.
+#include "base/win/event_trace_controller.h"
+#include "base/logging.h"
+
+namespace base {
+namespace win {
+
+EtwTraceProperties::EtwTraceProperties() {
+  // Layout: struct, then kMaxStringLen wchars each for logger and file name.
+  memset(buffer_, 0, sizeof(buffer_));
+  EVENT_TRACE_PROPERTIES* const props = get();
+  props->Wnode.BufferSize = sizeof(buffer_);
+  props->Wnode.Flags = WNODE_FLAG_TRACED_GUID;
+  props->LoggerNameOffset = sizeof(EVENT_TRACE_PROPERTIES);
+  props->LogFileNameOffset = props->LoggerNameOffset +
+                             sizeof(wchar_t) * kMaxStringLen;
+}
+
+HRESULT EtwTraceProperties::SetLoggerName(const wchar_t* logger_name) {
+  // The length check and the copy both include the terminating NUL.
+  const size_t chars = wcslen(logger_name) + 1;
+  if (chars > kMaxStringLen)
+    return E_INVALIDARG;
+
+  memcpy(buffer_ + get()->LoggerNameOffset, logger_name,
+         chars * sizeof(wchar_t));
+  return S_OK;
+}
+
+HRESULT EtwTraceProperties::SetLoggerFileName(const wchar_t* logger_file_name) {
+  // Same as SetLoggerName(), but fills the log file name slot.
+  const size_t chars = wcslen(logger_file_name) + 1;
+  if (chars > kMaxStringLen)
+    return E_INVALIDARG;
+
+  memcpy(buffer_ + get()->LogFileNameOffset, logger_file_name,
+         chars * sizeof(wchar_t));
+  return S_OK;
+}
+
+EtwTraceController::EtwTraceController() : session_(NULL) {
+}
+
+EtwTraceController::~EtwTraceController() {
+  Stop(NULL);  // Best effort; the result is ignored.
+}
+
+HRESULT EtwTraceController::Start(const wchar_t* session_name,
+    EtwTraceProperties* prop) {
+  DCHECK(NULL == session_ && session_name_.empty());
+  // The static Start() always needs a properties block to work with.
+  EtwTraceProperties scratch;
+  EtwTraceProperties* const props = (prop != NULL) ? prop : &scratch;
+  const HRESULT hr = Start(session_name, props, &session_);
+  if (FAILED(hr))
+    return hr;
+  // Remember the name only once the session has actually started.
+  session_name_ = session_name;
+  return hr;
+}
+
+HRESULT EtwTraceController::StartFileSession(const wchar_t* session_name,
+    const wchar_t* logfile_path, bool realtime) {
+  DCHECK(NULL == session_ && session_name_.empty());
+  EtwTraceProperties prop;
+  HRESULT hr = prop.SetLoggerFileName(logfile_path);
+  if (FAILED(hr))  // E.g. the path doesn't fit in kMaxStringLen characters.
+    return hr;
+  EVENT_TRACE_PROPERTIES& p = *prop.get();
+  p.Wnode.ClientContext = 1;  // QPC timer accuracy.
+  p.LogFileMode = EVENT_TRACE_FILE_MODE_SEQUENTIAL;  // Sequential log.
+  if (realtime)
+    p.LogFileMode |= EVENT_TRACE_REAL_TIME_MODE;
+  p.MaximumFileSize = 100;  // 100M file size.
+  p.FlushTimer = 30;  // 30 seconds flush lag.
+  return Start(session_name, &prop);
+}
+
+HRESULT EtwTraceController::StartRealtimeSession(const wchar_t* session_name,
+    size_t buffer_size) {  // NOTE: |buffer_size| is currently unused.
+  DCHECK(NULL == session_ && session_name_.empty());
+  EtwTraceProperties prop;
+  EVENT_TRACE_PROPERTIES* const p = prop.get();
+  p->LogFileMode = EVENT_TRACE_REAL_TIME_MODE | EVENT_TRACE_USE_PAGED_MEMORY;
+  p->BufferSize = 16;  // 16 KB buffers.
+  p->FlushTimer = 1;  // Flush once a second.
+  p->LogFileNameOffset = 0;  // Realtime only: no log file name.
+  return Start(session_name, &prop);
+}
+
+HRESULT EtwTraceController::EnableProvider(REFGUID provider, UCHAR level,
+    ULONG flags) {
+  const ULONG status = ::EnableTrace(TRUE, flags, level, &provider, session_);
+  return HRESULT_FROM_WIN32(status);  // Macro may evaluate its arg twice.
+}
+
+HRESULT EtwTraceController::DisableProvider(REFGUID provider) {
+  const ULONG status = ::EnableTrace(FALSE, 0, 0, &provider, session_);
+  return HRESULT_FROM_WIN32(status);
+}
+
+HRESULT EtwTraceController::Stop(EtwTraceProperties* properties) {
+  // ControlTrace is always given a properties block (scratch if NULL).
+  EtwTraceProperties scratch;
+  EtwTraceProperties* const props = properties ? properties : &scratch;
+  const ULONG status = ::ControlTrace(session_, NULL, props->get(),
+                                      EVENT_TRACE_CONTROL_STOP);
+  if (status != ERROR_SUCCESS)
+    return HRESULT_FROM_WIN32(status);
+
+  // The session is gone: forget its handle and name.
+  session_ = NULL;
+  session_name_.clear();
+  return S_OK;
+}
+
+HRESULT EtwTraceController::Flush(EtwTraceProperties* properties) {
+  // As in Stop(), substitute a scratch block when |properties| is NULL.
+  EtwTraceProperties scratch;
+  EtwTraceProperties* const props = properties ? properties : &scratch;
+  const ULONG status = ::ControlTrace(session_, NULL, props->get(),
+                                      EVENT_TRACE_CONTROL_FLUSH);
+  if (status != ERROR_SUCCESS)
+    return HRESULT_FROM_WIN32(status);
+
+  // Unlike Stop(), a flush leaves session_ and session_name_ as they are.
+  return S_OK;
+}
+
+HRESULT EtwTraceController::Start(const wchar_t* session_name,
+    EtwTraceProperties* properties, TRACEHANDLE* session_handle) {
+  DCHECK(properties != NULL);
+  ULONG err = ::StartTrace(session_handle, session_name, properties->get());
+  return HRESULT_FROM_WIN32(err);
+}
+
+HRESULT EtwTraceController::Query(const wchar_t* session_name,
+    EtwTraceProperties* properties) {
+  ULONG err = ::ControlTrace(NULL, session_name, properties->get(),
+                             EVENT_TRACE_CONTROL_QUERY);
+  return HRESULT_FROM_WIN32(err);
+}
+
+HRESULT EtwTraceController::Update(const wchar_t* session_name,
+    EtwTraceProperties* properties) {
+  DCHECK(properties != NULL);
+  const ULONG status = ::ControlTrace(NULL, session_name, properties->get(),
+                                      EVENT_TRACE_CONTROL_UPDATE);
+  return HRESULT_FROM_WIN32(status);
+}
+
+HRESULT EtwTraceController::Stop(const wchar_t* session_name,
+    EtwTraceProperties* properties) {
+  DCHECK(properties != NULL);
+  const ULONG status = ::ControlTrace(NULL, session_name, properties->get(),
+                                      EVENT_TRACE_CONTROL_STOP);
+  return HRESULT_FROM_WIN32(status);
+}
+
+HRESULT EtwTraceController::Flush(const wchar_t* session_name,
+    EtwTraceProperties* properties) {
+  DCHECK(properties != NULL);
+  const ULONG status = ::ControlTrace(NULL, session_name, properties->get(),
+                                      EVENT_TRACE_CONTROL_FLUSH);
+  return HRESULT_FROM_WIN32(status);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/event_trace_controller.h b/base/win/event_trace_controller.h
new file mode 100644
index 0000000..69e755b
--- /dev/null
+++ b/base/win/event_trace_controller.h
@@ -0,0 +1,151 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Declaration of a Windows event trace controller class.
+// The controller takes care of creating and manipulating event trace
+// sessions.
+//
+// Event tracing for Windows is a system-provided service that provides
+// logging control and high-performance transport for generic, binary trace
+// events. Event trace providers register with the system by their name,
+// which is a GUID, and can from that point forward receive callbacks that
+// start or end tracing and that change their trace level and enable mask.
+//
+// A trace controller can create an event tracing session, which either
+// sends events to a binary file, or to a realtime consumer, or both.
+//
+// A trace consumer consumes events from zero or one realtime session,
+// as well as potentially from multiple binary trace files.
+#ifndef BASE_WIN_EVENT_TRACE_CONTROLLER_H_
+#define BASE_WIN_EVENT_TRACE_CONTROLLER_H_
+
+#include <windows.h>
+#include <wmistr.h>
+#include <evntrace.h>
+#include <string>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+
+namespace base {
+namespace win {
+
+// Utility class to make it easier to work with EVENT_TRACE_PROPERTIES.
+// The EVENT_TRACE_PROPERTIES header shares one buffer with the logger name
+// and log file name strings, located via LoggerNameOffset/LogFileNameOffset.
+class BASE_EXPORT EtwTraceProperties {
+ public:
+  EtwTraceProperties();
+
+  EVENT_TRACE_PROPERTIES* get() {
+    return &properties_;
+  }
+
+  const EVENT_TRACE_PROPERTIES* get() const {
+    return &properties_;
+  }
+
+  const wchar_t* GetLoggerName() const {
+    return reinterpret_cast<const wchar_t*>(buffer_ + get()->LoggerNameOffset);
+  }
+
+  // Copies logger_name to the properties structure.
+  HRESULT SetLoggerName(const wchar_t* logger_name);
+
+  const wchar_t* GetLoggerFileName() const {
+    return reinterpret_cast<const wchar_t*>(buffer_ + get()->LogFileNameOffset);
+  }
+
+  // Copies logger_file_name to the properties structure.
+  HRESULT SetLoggerFileName(const wchar_t* logger_file_name);
+
+  // Capacity, in wchar_t, of each name string including its terminator.
+  static const size_t kMaxStringLen = 1024;
+  // Buffer size: the header plus room for both name strings.
+  static const size_t kBufSize = sizeof(EVENT_TRACE_PROPERTIES)
+      + 2 * sizeof(wchar_t) * (kMaxStringLen);
+
+ private:
+  // The EVENT_TRACE_PROPERTIES structure needs to be overlaid on a
+  // larger buffer to allow storing the logger name and logger file
+  // name contiguously with the structure.
+  union {
+   public:
+    // Our properties header.
+    EVENT_TRACE_PROPERTIES properties_;
+    // The actual size of the buffer is forced by this member.
+    char buffer_[kBufSize];
+  };
+
+  DISALLOW_COPY_AND_ASSIGN(EtwTraceProperties);
+};
+
+// This class implements an ETW controller, which knows how to start and
+// stop event tracing sessions, as well as controlling ETW provider
+// log levels and enable bit masks under the session.
+class BASE_EXPORT EtwTraceController {
+ public:
+  EtwTraceController();
+  ~EtwTraceController();
+
+  // Starts a session with the given name and properties.
+  HRESULT Start(const wchar_t* session_name, EtwTraceProperties* prop);
+
+  // Starts a session tracing to |logfile_path| with some default properties.
+  HRESULT StartFileSession(const wchar_t* session_name,
+                           const wchar_t* logfile_path,
+                           bool realtime = false);
+
+  // Starts a realtime session with some default properties.
+  // NOTE(review): the unit of |buffer_size| is not stated here; confirm.
+  HRESULT StartRealtimeSession(const wchar_t* session_name,
+                               size_t buffer_size);
+
+  // Enables "provider" at "level" for this session. This causes all providers
+  // registered with the GUID "provider" to trace at the new level, systemwide.
+  HRESULT EnableProvider(const GUID& provider, UCHAR level,
+                         ULONG flags = 0xFFFFFFFF);
+  // Disables "provider".
+  HRESULT DisableProvider(const GUID& provider);
+
+  // Stops our session and retrieves its final properties into |properties|,
+  // which may be NULL.
+  HRESULT Stop(EtwTraceProperties* properties);
+
+  // Flushes our session and retrieves its current properties into
+  // |properties|, which may be NULL.
+  HRESULT Flush(EtwTraceProperties* properties);
+
+  // Static utility functions for controlling sessions we don't necessarily
+  // own, identified by name. |properties| must not be NULL.
+  static HRESULT Start(const wchar_t* session_name,
+                       EtwTraceProperties* properties,
+                       TRACEHANDLE* session_handle);
+  // Retrieves the named session's current properties into |properties|.
+  static HRESULT Query(const wchar_t* session_name,
+                       EtwTraceProperties* properties);
+  // Updates the named session from the settings in |properties|.
+  static HRESULT Update(const wchar_t* session_name,
+                        EtwTraceProperties* properties);
+  // Stop or flush the named session, writing its properties back.
+  static HRESULT Stop(const wchar_t* session_name,
+                      EtwTraceProperties* properties);
+  static HRESULT Flush(const wchar_t* session_name,
+                       EtwTraceProperties* properties);
+
+  // Accessors.
+  TRACEHANDLE session() const { return session_; }
+  const wchar_t* session_name() const { return session_name_.c_str(); }
+
+ private:
+  std::wstring session_name_;
+  TRACEHANDLE session_;
+
+  DISALLOW_COPY_AND_ASSIGN(EtwTraceController);
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_EVENT_TRACE_CONTROLLER_H_
diff --git a/base/win/event_trace_controller_unittest.cc b/base/win/event_trace_controller_unittest.cc
new file mode 100644
index 0000000..2e3a403
--- /dev/null
+++ b/base/win/event_trace_controller_unittest.cc
@@ -0,0 +1,236 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Unit tests for event trace controller.
+
+#include <objbase.h>
+#include <initguid.h>
+
+#include "base/file_path.h"
+#include "base/file_util.h"
+#include "base/logging.h"
+#include "base/process.h"
+#include "base/scoped_temp_dir.h"
+#include "base/stringprintf.h"
+#include "base/sys_info.h"
+#include "base/win/event_trace_controller.h"
+#include "base/win/event_trace_provider.h"
+#include "base/win/scoped_handle.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace {
+
+using base::win::EtwTraceController;
+using base::win::EtwTraceProvider;
+using base::win::EtwTraceProperties;
+// All-zero GUID; freshly constructed properties should carry it in Wnode.Guid.
+DEFINE_GUID(kGuidNull,
+    0x0000000, 0x0000, 0x0000, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0);
+// Distinctive enable-flags value, checked to round-trip to the provider.
+const ULONG kTestProviderFlags = 0xCAFEBABE;
+
+// Signals an event from the enable/post-disable hooks so tests can wait.
+class TestingProvider: public EtwTraceProvider {
+ public:
+  explicit TestingProvider(const GUID& provider_name)
+      : EtwTraceProvider(provider_name) {
+    // Manual-reset and initially non-signaled; see WaitForCallback().
+    callback_event_.Set(::CreateEvent(NULL, TRUE, FALSE, NULL));
+  }
+
+  // Blocks until a hook has signaled, then resets the event for the next wait.
+  void WaitForCallback() {
+    HANDLE event = callback_event_.Get();
+    ::WaitForSingleObject(event, INFINITE);
+    ::ResetEvent(event);
+  }
+
+ private:
+  void Signal() { ::SetEvent(callback_event_.Get()); }
+  virtual void OnEventsEnabled() { Signal(); }
+  virtual void PostEventsDisabled() { Signal(); }
+
+  base::win::ScopedHandle callback_event_;
+  DISALLOW_COPY_AND_ASSIGN(TestingProvider);
+};
+
+} // namespace
+
+TEST(EtwTracePropertiesTest, Initialization) {
+  EtwTraceProperties prop;
+  // Besides the buffer size, flags and string offsets, everything starts at 0.
+  EVENT_TRACE_PROPERTIES* p = prop.get();
+  EXPECT_NE(0u, p->Wnode.BufferSize);
+  EXPECT_EQ(0u, p->Wnode.ProviderId);
+  EXPECT_EQ(0u, p->Wnode.HistoricalContext);
+
+  EXPECT_TRUE(kGuidNull == p->Wnode.Guid);
+  EXPECT_EQ(0, p->Wnode.ClientContext);
+  EXPECT_EQ(WNODE_FLAG_TRACED_GUID, p->Wnode.Flags);
+
+  EXPECT_EQ(0, p->BufferSize);
+  EXPECT_EQ(0, p->MinimumBuffers);
+  EXPECT_EQ(0, p->MaximumBuffers);
+  EXPECT_EQ(0, p->MaximumFileSize);
+  EXPECT_EQ(0, p->LogFileMode);
+  EXPECT_EQ(0, p->FlushTimer);
+  EXPECT_EQ(0, p->EnableFlags);
+  EXPECT_EQ(0, p->AgeLimit);
+  // Session statistics and the name offsets.
+  EXPECT_EQ(0, p->NumberOfBuffers);
+  EXPECT_EQ(0, p->FreeBuffers);
+  EXPECT_EQ(0, p->EventsLost);
+  EXPECT_EQ(0, p->BuffersWritten);
+  EXPECT_EQ(0, p->LogBuffersLost);
+  EXPECT_EQ(0, p->RealTimeBuffersLost);
+  EXPECT_EQ(0, p->LoggerThreadId);
+  EXPECT_NE(0u, p->LogFileNameOffset);
+  EXPECT_NE(0u, p->LoggerNameOffset);
+}
+
+TEST(EtwTracePropertiesTest, Strings) {
+  EtwTraceProperties prop;
+  // Both names start out empty.
+  ASSERT_STREQ(L"", prop.GetLoggerFileName());
+  ASSERT_STREQ(L"", prop.GetLoggerName());
+  // 1023 characters plus the terminator fill kMaxStringLen (1024) exactly.
+  std::wstring name(1023, L'A');
+  ASSERT_HRESULT_SUCCEEDED(prop.SetLoggerFileName(name.c_str()));
+  ASSERT_HRESULT_SUCCEEDED(prop.SetLoggerName(name.c_str()));
+  ASSERT_STREQ(name.c_str(), prop.GetLoggerFileName());
+  ASSERT_STREQ(name.c_str(), prop.GetLoggerName());
+  // One more character no longer fits, so both setters must fail.
+  std::wstring name2(1024, L'A');
+  ASSERT_HRESULT_FAILED(prop.SetLoggerFileName(name2.c_str()));
+  ASSERT_HRESULT_FAILED(prop.SetLoggerName(name2.c_str()));
+}
+
+namespace {
+
+class EtwTraceControllerTest : public testing::Test {
+ public:
+  EtwTraceControllerTest()
+      : session_name_(base::StringPrintf(L"TestSession-%d",
+                                         base::Process::Current().pid())) {}
+  // Stops any session left under our name, then picks a fresh provider GUID.
+  virtual void SetUp() {
+    StopSession();
+    ASSERT_HRESULT_SUCCEEDED(::CoCreateGuid(&test_provider_));
+  }
+
+  virtual void TearDown() { StopSession(); }
+
+ protected:
+  // Best-effort stop of the per-process session; failures are ignored.
+  void StopSession() {
+    EtwTraceProperties ignore;
+    EtwTraceController::Stop(session_name_.c_str(), &ignore);
+  }
+
+  GUID test_provider_;
+  std::wstring session_name_;
+};
+
+} // namespace
+
+TEST_F(EtwTraceControllerTest, Initialize) {
+  EtwTraceController controller;
+  // A fresh controller owns no session and has an empty session name.
+  EXPECT_EQ(NULL, controller.session());
+  EXPECT_STREQ(L"", controller.session_name());
+}
+
+
+TEST_F(EtwTraceControllerTest, StartRealTimeSession) {
+  EtwTraceController controller;
+  // Starting a session records its handle and name in the controller.
+  HRESULT hr = controller.StartRealtimeSession(session_name_.c_str(),
+                                               100 * 1024);
+  if (hr == E_ACCESSDENIED) {
+    VLOG(1) << "You must be an administrator to run this test on Vista";
+    return;
+  }
+  ASSERT_HRESULT_SUCCEEDED(hr);  // Report any other failure directly.
+  EXPECT_TRUE(NULL != controller.session());
+  EXPECT_STREQ(session_name_.c_str(), controller.session_name());
+  // Stopping clears both again.
+  EXPECT_HRESULT_SUCCEEDED(controller.Stop(NULL));
+  EXPECT_EQ(NULL, controller.session());
+  EXPECT_STREQ(L"", controller.session_name());
+}
+
+TEST_F(EtwTraceControllerTest, StartFileSession) {
+  // Trace into a fresh file inside a scoped temporary directory.
+  ScopedTempDir temp_dir;
+  ASSERT_TRUE(temp_dir.CreateUniqueTempDir());
+  FilePath temp;
+  ASSERT_TRUE(file_util::CreateTemporaryFileInDir(temp_dir.path(), &temp));
+  EtwTraceController controller;
+  HRESULT hr = controller.StartFileSession(session_name_.c_str(),
+                                           temp.value().c_str());
+  if (hr == E_ACCESSDENIED) {
+    VLOG(1) << "You must be an administrator to run this test on Vista";
+    file_util::Delete(temp, false);
+    return;
+  }
+  // Report any other failure here rather than via the checks below.
+  ASSERT_HRESULT_SUCCEEDED(hr);
+  EXPECT_TRUE(NULL != controller.session());
+  EXPECT_STREQ(session_name_.c_str(), controller.session_name());
+  EXPECT_HRESULT_SUCCEEDED(controller.Stop(NULL));
+  EXPECT_EQ(NULL, controller.session());
+  EXPECT_STREQ(L"", controller.session_name());
+  file_util::Delete(temp, false);
+}
+
+TEST_F(EtwTraceControllerTest, EnableDisable) {
+  TestingProvider provider(test_provider_);
+
+  EXPECT_EQ(ERROR_SUCCESS, provider.Register());
+  EXPECT_EQ(NULL, provider.session_handle());
+
+  EtwTraceController controller;
+  HRESULT hr = controller.StartRealtimeSession(session_name_.c_str(),
+                                               100 * 1024);
+  if (hr == E_ACCESSDENIED) {
+    VLOG(1) << "You must be an administrator to run this test on Vista";
+    return;
+  }
+  // Without a session no callback arrives; don't block forever below.
+  ASSERT_HRESULT_SUCCEEDED(hr);
+  EXPECT_HRESULT_SUCCEEDED(controller.EnableProvider(test_provider_,
+      TRACE_LEVEL_VERBOSE, kTestProviderFlags));
+
+  provider.WaitForCallback();
+
+  EXPECT_EQ(TRACE_LEVEL_VERBOSE, provider.enable_level());
+  EXPECT_EQ(kTestProviderFlags, provider.enable_flags());
+
+  EXPECT_HRESULT_SUCCEEDED(controller.DisableProvider(test_provider_));
+
+  provider.WaitForCallback();
+
+  EXPECT_EQ(0, provider.enable_level());
+  EXPECT_EQ(0, provider.enable_flags());
+
+  EXPECT_EQ(ERROR_SUCCESS, provider.Unregister());
+
+  // Enable the provider again, before registering.
+  EXPECT_HRESULT_SUCCEEDED(controller.EnableProvider(test_provider_,
+      TRACE_LEVEL_VERBOSE, kTestProviderFlags));
+
+  // Registering again should pick up the settings above immediately.
+  EXPECT_EQ(ERROR_SUCCESS, provider.Register());
+  EXPECT_EQ(TRACE_LEVEL_VERBOSE, provider.enable_level());
+  EXPECT_EQ(kTestProviderFlags, provider.enable_flags());
+  // Consume the event signaled by that enable callback; otherwise the wait
+  // below could return before the session-stop callback has run.
+  provider.WaitForCallback();
+  EXPECT_HRESULT_SUCCEEDED(controller.Stop(NULL));
+  provider.WaitForCallback();
+
+  // Session should have wound down.
+  EXPECT_EQ(0, provider.enable_level());
+  EXPECT_EQ(0, provider.enable_flags());
+}
diff --git a/base/win/event_trace_provider.cc b/base/win/event_trace_provider.cc
new file mode 100644
index 0000000..8fcf67d
--- /dev/null
+++ b/base/win/event_trace_provider.cc
@@ -0,0 +1,134 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+#include "base/win/event_trace_provider.h"
+#include <windows.h>
+#include <cguid.h>
+
+namespace base {
+namespace win {
+
+// Single placeholder registration handed to RegisterTraceGuids by Register().
+TRACE_GUID_REGISTRATION EtwTraceProvider::obligatory_guid_registration_ = {
+  &GUID_NULL, NULL
+};
+
+EtwTraceProvider::EtwTraceProvider(const GUID& provider_name)
+    : provider_name_(provider_name), registration_handle_(NULL),
+      session_handle_(NULL), enable_flags_(0), enable_level_(0) {}
+
+// Unnamed: Register() fails with ERROR_INVALID_NAME until a name is set.
+EtwTraceProvider::EtwTraceProvider()
+    : provider_name_(GUID_NULL), registration_handle_(NULL),
+      session_handle_(NULL), enable_flags_(0), enable_level_(0) {}
+
+// Virtual calls from a destructor bind to this class, so subclass overrides
+// of OnEventsDisabled()/PostEventsDisabled() are not reached from here;
+// subclasses wanting them should call Unregister() in their own destructor.
+EtwTraceProvider::~EtwTraceProvider() { Unregister(); }
+
+ULONG EtwTraceProvider::EnableEvents(void* buffer) {
+  // Failure is documented as INVALID_HANDLE_VALUE; treat NULL as failure too.
+  const TRACEHANDLE handle = ::GetTraceLoggerHandle(buffer);
+  if (handle == NULL ||
+      handle == reinterpret_cast<TRACEHANDLE>(INVALID_HANDLE_VALUE)) {
+    session_handle_ = NULL;  // Never keep an unusable handle as the session.
+    return ::GetLastError();
+  }
+  session_handle_ = handle;
+  enable_flags_ = ::GetTraceEnableFlags(session_handle_);
+  enable_level_ = ::GetTraceEnableLevel(session_handle_);
+  OnEventsEnabled();  // Give subclasses a chance to digest the state change.
+  return ERROR_SUCCESS;
+}
+
+ULONG EtwTraceProvider::DisableEvents() {
+  // Subclasses get a last chance to log while the session is still attached.
+  OnEventsDisabled();
+
+  // Drop the session and its settings, then notify subclasses once more.
+  session_handle_ = NULL;
+  enable_flags_ = 0;
+  enable_level_ = 0;
+  PostEventsDisabled();
+
+  return ERROR_SUCCESS;
+}
+
+ULONG EtwTraceProvider::Callback(WMIDPREQUESTCODE request, void* buffer) {
+  // Dispatches an ETW control request; only enable/disable are handled.
+  if (request == WMI_ENABLE_EVENTS)
+    return EnableEvents(buffer);
+
+  if (request == WMI_DISABLE_EVENTS)
+    return DisableEvents();
+
+  // Any other request code is rejected.
+  return ERROR_INVALID_PARAMETER;
+}
+
+ULONG WINAPI EtwTraceProvider::ControlCallback(WMIDPREQUESTCODE request,
+    void* context, ULONG* reserved, void* buffer) {
+  // |context| is the provider that Register() passed to RegisterTraceGuids.
+  EtwTraceProvider* provider = static_cast<EtwTraceProvider*>(context);
+  return provider->Callback(request, buffer);
+}
+
+ULONG EtwTraceProvider::Register() {
+  if (provider_name_ == GUID_NULL)
+    return ERROR_INVALID_NAME;  // A name must be set before registering.
+  // ETW may call ControlCallback (with |this| as context) during this call.
+  return ::RegisterTraceGuids(ControlCallback, this, &provider_name_,
+      1, &obligatory_guid_registration_, NULL, NULL, &registration_handle_);
+}
+
+ULONG EtwTraceProvider::Unregister() {
+  // If a session is attached, run the disable path so subclasses hear of it.
+  if (session_handle_ != NULL)
+    DisableEvents();
+
+  const ULONG result = ::UnregisterTraceGuids(registration_handle_);
+  // The handle is forgotten whether or not ETW accepted the request.
+  registration_handle_ = NULL;
+
+  return result;
+}
+
+// Emits an event whose single MOF field holds |bytes| bytes at |data|.
+static ULONG TraceBytes(TRACEHANDLE session, const EtwEventClass& event_class,
+                        EtwEventType type, EtwEventLevel level,
+                        const void* data, size_t bytes) {
+  EtwMofEvent<1> event(event_class, type, level);
+  event.fields[0].DataPtr = reinterpret_cast<ULONG64>(data);
+  event.fields[0].Length = static_cast<ULONG>(bytes);
+  return ::TraceEvent(session, &event.header);
+}
+
+ULONG EtwTraceProvider::Log(const EtwEventClass& event_class,
+    EtwEventType type, EtwEventLevel level, const char* message) {
+  if (NULL == session_handle_ || enable_level_ < level)
+    return ERROR_SUCCESS;  // No one listening.
+  // The terminating NUL is traced too; a NULL message yields an empty field.
+  size_t bytes = message ? sizeof(message[0]) * (1 + strlen(message)) : 0;
+  return TraceBytes(session_handle_, event_class, type, level, message, bytes);
+}
+
+ULONG EtwTraceProvider::Log(const EtwEventClass& event_class,
+    EtwEventType type, EtwEventLevel level, const wchar_t* message) {
+  if (NULL == session_handle_ || enable_level_ < level)
+    return ERROR_SUCCESS;  // No one listening.
+  // The terminating L'\0' is traced too; a NULL message yields an empty field.
+  size_t bytes = message ? sizeof(message[0]) * (1 + wcslen(message)) : 0;
+  return TraceBytes(session_handle_, event_class, type, level, message, bytes);
+}
+
+ULONG EtwTraceProvider::Log(EVENT_TRACE_HEADER* event) {
+  // Like the string overloads, succeed silently when no one is listening.
+  if (NULL == session_handle_ || enable_level_ < event->Class.Level)
+    return ERROR_SUCCESS;
+  return ::TraceEvent(session_handle_, event);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/event_trace_provider.h b/base/win/event_trace_provider.h
new file mode 100644
index 0000000..9f6e7c4
--- /dev/null
+++ b/base/win/event_trace_provider.h
@@ -0,0 +1,175 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Declaration of a Windows event trace provider class, to allow using
+// Windows Event Tracing for logging transport and control.
+#ifndef BASE_WIN_EVENT_TRACE_PROVIDER_H_
+#define BASE_WIN_EVENT_TRACE_PROVIDER_H_
+
+#include <windows.h>
+#include <wmistr.h>
+#include <evntrace.h>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+
+namespace base {
+namespace win {
+
+typedef GUID EtwEventClass;      // Stored in EVENT_TRACE_HEADER::Guid.
+typedef UCHAR EtwEventType;      // Stored in EVENT_TRACE_HEADER::Class.Type.
+typedef UCHAR EtwEventLevel;     // Stored in EVENT_TRACE_HEADER::Class.Level.
+typedef USHORT EtwEventVersion;  // Stored in EVENT_TRACE_HEADER::Class.Version.
+typedef ULONG EtwEventFlags;     // Provider enable-flags bitmask.
+
+// Base class is a POD so that EtwMofEvent can zero it with memset.
+template <size_t N> struct EtwMofEventBase {
+  EVENT_TRACE_HEADER header;
+  MOF_FIELD fields[N];  // Payload descriptors (WNODE_FLAG_USE_MOF_PTR).
+};
+
+// Utility class to auto-initialize event trace header structures. Base
+// members are named via this-> because the base class is dependent.
+template <size_t N> class EtwMofEvent: public EtwMofEventBase<N> {
+ public:
+  typedef EtwMofEventBase<N> Super;
+  EtwMofEvent() {
+    memset(static_cast<Super*>(this), 0, sizeof(Super));
+  }
+
+  EtwMofEvent(const EtwEventClass& event_class, EtwEventType type,
+              EtwEventLevel level) {
+    memset(static_cast<Super*>(this), 0, sizeof(Super));
+    this->header.Size = sizeof(Super);
+    this->header.Guid = event_class;
+    this->header.Class.Type = type;
+    this->header.Class.Level = level;
+    this->header.Flags = WNODE_FLAG_TRACED_GUID | WNODE_FLAG_USE_MOF_PTR;
+  }
+
+  EtwMofEvent(const EtwEventClass& event_class, EtwEventType type,
+              EtwEventVersion version, EtwEventLevel level) {
+    memset(static_cast<Super*>(this), 0, sizeof(Super));
+    this->header.Size = sizeof(Super);
+    this->header.Guid = event_class;
+    this->header.Class.Type = type;
+    this->header.Class.Version = version;
+    this->header.Class.Level = level;
+    this->header.Flags = WNODE_FLAG_TRACED_GUID | WNODE_FLAG_USE_MOF_PTR;
+  }
+
+  // Points field |field| at |size| bytes of |data|; bad index/size is ignored.
+  void SetField(int field, size_t size, const void* data) {
+    if (field >= 0 && static_cast<size_t>(field) < N && size <= kuint32max) {
+      this->fields[field].DataPtr = reinterpret_cast<ULONG64>(data);
+      this->fields[field].Length = static_cast<ULONG>(size);
+    }
+  }
+
+  EVENT_TRACE_HEADER* get() { return &this->header; }
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(EtwMofEvent);
+};
+
+// Trace provider with Event Tracing for Windows. The trace provider
+// registers with ETW by its name which is a GUID. ETW calls back to
+// the object whenever the trace level or enable flags for this provider
+// name changes.
+// Users of this class can test whether logging is currently enabled at
+// a particular trace level, and whether particular enable flags are set,
+// before other resources are consumed to generate and issue the log
+// messages themselves.
+class BASE_EXPORT EtwTraceProvider {
+ public:
+  // Creates an event trace provider identified by provider_name, which
+  // will be the name registered with Event Tracing for Windows (ETW).
+  explicit EtwTraceProvider(const GUID& provider_name);
+
+  // Creates an unnamed event trace provider, the provider must be given
+  // a name before registration.
+  EtwTraceProvider();
+  virtual ~EtwTraceProvider();
+
+  // Registers the trace provider with Event Tracing for Windows.
+  // Note: from this point forward ETW may call the provider's control
+  // callback. If the provider's name is enabled in some trace session
+  // already, the callback may occur recursively from this call, so
+  // call this only when you're ready to handle callbacks.
+  ULONG Register();
+  // Unregisters the trace provider with ETW.
+  ULONG Unregister();
+
+  // Accessors.
+  void set_provider_name(const GUID& provider_name) {
+    provider_name_ = provider_name;
+  }
+  const GUID& provider_name() const { return provider_name_; }
+  TRACEHANDLE registration_handle() const { return registration_handle_; }
+  TRACEHANDLE session_handle() const { return session_handle_; }
+  EtwEventFlags enable_flags() const { return enable_flags_; }
+  EtwEventLevel enable_level() const { return enable_level_; }
+
+  // Returns true iff logging should be performed for "level" and "flags".
+  // ETW levels grow with verbosity, so "level" must not exceed enable_level().
+  // "flags" is a bitmask, normally with one bit set for a sub "facility".
+  bool ShouldLog(EtwEventLevel level, EtwEventFlags flags) {
+    return NULL != session_handle_ && level <= enable_level_ &&
+        (0 != (flags & enable_flags_));
+  }
+
+  // Simple wrappers to log Unicode and ANSI strings. They do nothing when
+  // there is no session or "level" exceeds enable_level().
+  ULONG Log(const EtwEventClass& event_class, EtwEventType type,
+            EtwEventLevel level, const char *message);
+  ULONG Log(const EtwEventClass& event_class, EtwEventType type,
+            EtwEventLevel level, const wchar_t *message);
+
+  // Logs the provided event if its level is enabled.
+  ULONG Log(EVENT_TRACE_HEADER* event);
+
+ protected:
+  // Called after events have been enabled, override in subclasses
+  // to set up state or log at the start of a session.
+  // Note: This function may be called on ETW's thread and may be racy,
+  // bring your own locking if needed.
+  virtual void OnEventsEnabled() {}
+
+  // Called just before events are disabled, override in subclasses
+  // to tear down state or log at the end of a session.
+  // Note: This function may be called on ETW's thread and may be racy,
+  // bring your own locking if needed.
+  virtual void OnEventsDisabled() {}
+
+  // Called just after events have been disabled, override in subclasses
+  // to tear down state at the end of a session. At this point it's
+  // too late to log anything to the session.
+  // Note: This function may be called on ETW's thread and may be racy,
+  // bring your own locking if needed.
+  virtual void PostEventsDisabled() {}
+
+ private:
+  ULONG EnableEvents(PVOID buffer);
+  ULONG DisableEvents();
+  ULONG Callback(WMIDPREQUESTCODE request, PVOID buffer);
+  static ULONG WINAPI ControlCallback(WMIDPREQUESTCODE request, PVOID context,
+                                      ULONG *reserved, PVOID buffer);
+
+  GUID provider_name_;
+  TRACEHANDLE registration_handle_;
+  TRACEHANDLE session_handle_;
+  EtwEventFlags enable_flags_;
+  EtwEventLevel enable_level_;
+
+  // We don't use this, but on XP we're obliged to pass one in to
+  // RegisterTraceGuids. Non-const, because that's how the API needs it.
+  static TRACE_GUID_REGISTRATION obligatory_guid_registration_;
+
+  DISALLOW_COPY_AND_ASSIGN(EtwTraceProvider);
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_EVENT_TRACE_PROVIDER_H_
diff --git a/base/win/event_trace_provider_unittest.cc b/base/win/event_trace_provider_unittest.cc
new file mode 100644
index 0000000..55b5ae6
--- /dev/null
+++ b/base/win/event_trace_provider_unittest.cc
@@ -0,0 +1,110 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Unit tests for event trace provider.
+#include "base/win/event_trace_provider.h"
+#include <new>
+#include "testing/gtest/include/gtest/gtest.h"
+#include <initguid.h> // NOLINT - has to be last
+
+namespace {
+
+using base::win::EtwTraceProvider;
+using base::win::EtwMofEvent;
+
+// {7F0FD37F-FA3C-4cd6-9242-DF60967A2CB2}
+DEFINE_GUID(kTestProvider,
+    0x7f0fd37f, 0xfa3c, 0x4cd6, 0x92, 0x42, 0xdf, 0x60, 0x96, 0x7a, 0x2c, 0xb2);
+
+// Event class for the test events; it reuses kTestProvider's GUID value.
+DEFINE_GUID(kTestEventClass,
+    0x7f0fd37f, 0xfa3c, 0x4cd6, 0x92, 0x42, 0xdf, 0x60, 0x96, 0x7a, 0x2c, 0xb2);
+
+} // namespace
+
+TEST(EtwTraceProviderTest, ToleratesPreCreateInvocations) {
+  // Because the trace provider is used in logging, it's important that
+  // it be possible to use static provider instances without regard to
+  // whether they've been constructed or destructed.
+  // The interface of the class is designed to tolerate this usage.
+  char buf[sizeof(EtwTraceProvider)] = {0};
+  EtwTraceProvider& provider = reinterpret_cast<EtwTraceProvider&>(buf);
+  // NOTE(review): |buf| is not guaranteed to be aligned for EtwTraceProvider.
+  EXPECT_EQ(NULL, provider.registration_handle());
+  EXPECT_EQ(NULL, provider.session_handle());
+  EXPECT_EQ(0, provider.enable_flags());
+  EXPECT_EQ(0, provider.enable_level());
+  // With no session attached, ShouldLog() is false whatever the level.
+  EXPECT_FALSE(provider.ShouldLog(TRACE_LEVEL_FATAL, 0xfffffff));
+
+  // We expect these not to crash.
+  provider.Log(kTestEventClass, 0, TRACE_LEVEL_FATAL, "foo");
+  provider.Log(kTestEventClass, 0, TRACE_LEVEL_FATAL, L"foo");
+
+  EtwMofEvent<1> dummy(kTestEventClass, 0, TRACE_LEVEL_FATAL);
+  DWORD data = 0;
+  dummy.SetField(0, sizeof(data), &data);
+  provider.Log(dummy.get());
+
+  // Placement-new the provider into our buffer.
+  new (buf) EtwTraceProvider(kTestProvider);
+
+  // Registration is now safe.
+  EXPECT_EQ(ERROR_SUCCESS, provider.Register());
+
+  // Destruct the instance in place; ~EtwTraceProvider() calls Unregister().
+  provider.EtwTraceProvider::~EtwTraceProvider();
+
+  // And post-destruction, all of the above should still be safe.
+  EXPECT_EQ(NULL, provider.registration_handle());
+  EXPECT_EQ(NULL, provider.session_handle());
+  EXPECT_EQ(0, provider.enable_flags());
+  EXPECT_EQ(0, provider.enable_level());
+
+  EXPECT_FALSE(provider.ShouldLog(TRACE_LEVEL_FATAL, 0xfffffff));
+
+  // We expect these not to crash.
+  provider.Log(kTestEventClass, 0, TRACE_LEVEL_FATAL, "foo");
+  provider.Log(kTestEventClass, 0, TRACE_LEVEL_FATAL, L"foo");
+  provider.Log(dummy.get());
+}
+
+TEST(EtwTraceProviderTest, Initialize) {
+  EtwTraceProvider provider(kTestProvider);
+  // A freshly constructed provider is unregistered and has no session.
+  EXPECT_EQ(NULL, provider.registration_handle());
+  EXPECT_EQ(NULL, provider.session_handle());
+  EXPECT_EQ(0, provider.enable_flags());
+  EXPECT_EQ(0, provider.enable_level());
+}
+
+TEST(EtwTraceProviderTest, Register) {
+  EtwTraceProvider provider(kTestProvider);
+  // Registration yields a handle, and unregistering clears it again.
+  ASSERT_EQ(ERROR_SUCCESS, provider.Register());
+  EXPECT_TRUE(NULL != provider.registration_handle());
+  ASSERT_EQ(ERROR_SUCCESS, provider.Unregister());
+  EXPECT_EQ(NULL, provider.registration_handle());
+}
+
+TEST(EtwTraceProviderTest, RegisterWithNoNameFails) {
+  // A default-constructed provider has the null GUID and cannot register.
+  EtwTraceProvider provider;
+  EXPECT_TRUE(provider.Register() != ERROR_SUCCESS);
+}
+
+TEST(EtwTraceProviderTest, Enable) {
+  EtwTraceProvider provider(kTestProvider);
+
+  ASSERT_EQ(ERROR_SUCCESS, provider.Register());
+  EXPECT_TRUE(NULL != provider.registration_handle());
+
+  // No session so far, so nothing is enabled.
+  EXPECT_EQ(NULL, provider.session_handle());
+  EXPECT_EQ(0, provider.enable_flags());
+  EXPECT_EQ(0, provider.enable_level());
+
+  ASSERT_EQ(ERROR_SUCCESS, provider.Unregister());
+  EXPECT_EQ(NULL, provider.registration_handle());
+}
diff --git a/base/win/i18n.cc b/base/win/i18n.cc
new file mode 100644
index 0000000..9e523a1
--- /dev/null
+++ b/base/win/i18n.cc
@@ -0,0 +1,169 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/i18n.h"
+
+#include <windows.h>
+
+#include "base/logging.h"
+
+namespace {
+
+// Keep this enum in sync with kLanguageFunctionNames.
+enum LanguageFunction {
+  SYSTEM_LANGUAGES,
+  USER_LANGUAGES,
+  PROCESS_LANGUAGES,
+  THREAD_LANGUAGES,
+  NUM_FUNCTIONS
+};
+// Names of the kernel32 MUI exports, resolved at runtime with GetProcAddress.
+const char kSystemLanguagesFunctionName[] = "GetSystemPreferredUILanguages";
+const char kUserLanguagesFunctionName[] = "GetUserPreferredUILanguages";
+const char kProcessLanguagesFunctionName[] = "GetProcessPreferredUILanguages";
+const char kThreadLanguagesFunctionName[] = "GetThreadPreferredUILanguages";
+
+// Indexed by LanguageFunction; keep the two in sync.
+const char* const kLanguageFunctionNames[] = {
+  kSystemLanguagesFunctionName,
+  kUserLanguagesFunctionName,
+  kProcessLanguagesFunctionName,
+  kThreadLanguagesFunctionName
+};
+
+COMPILE_ASSERT(NUM_FUNCTIONS == arraysize(kLanguageFunctionNames),
+               language_function_enum_and_names_out_of_sync);
+
+// Calls the MUI Get*PreferredUILanguages function selected by |function| with
+// |flags| (which must not include MUI_LANGUAGE_ID or MUI_LANGUAGE_NAME) and
+// stores the returned language-name list in |languages|. Returns true if at
+// least one language was obtained.
+bool GetMUIPreferredUILanguageList(LanguageFunction function, ULONG flags,
+                                   std::vector<wchar_t>* languages) {
+  DCHECK(0 <= function && NUM_FUNCTIONS > function);
+  DCHECK_EQ(0U, (flags & (MUI_LANGUAGE_ID | MUI_LANGUAGE_NAME)));
+  DCHECK(languages);
+
+  HMODULE kernel32 = GetModuleHandle(L"kernel32.dll");
+  if (NULL == kernel32) {
+    NOTREACHED() << "kernel32.dll not found.";
+    return false;
+  }
+
+  // The MUI functions are resolved at runtime; older systems may lack them.
+  typedef BOOL (WINAPI* GetPreferredUILanguages_Fn)(
+      DWORD, PULONG, PZZWSTR, PULONG);
+  GetPreferredUILanguages_Fn get_languages =
+      reinterpret_cast<GetPreferredUILanguages_Fn>(
+          GetProcAddress(kernel32, kLanguageFunctionNames[function]));
+  if (NULL == get_languages) {
+    DVLOG(2) << "MUI not available.";
+    return false;
+  }
+
+  // A first call with no buffer reports the required length in wchar_ts.
+  const ULONG call_flags = flags | MUI_LANGUAGE_NAME;
+  ULONG language_count = 0;
+  ULONG buffer_length = 0;
+  if (!get_languages(call_flags, &language_count, NULL, &buffer_length) ||
+      0 == buffer_length) {
+    DPCHECK(0 == buffer_length)
+        << "Failed getting size of preferred UI languages.";
+    return false;
+  }
+
+  // The second call fills the double-NUL-terminated list of names.
+  languages->resize(buffer_length);
+  if (!get_languages(call_flags, &language_count, &(*languages)[0],
+                     &buffer_length) || 0 == language_count) {
+    DPCHECK(0 == language_count) << "Failed getting preferred UI languages.";
+    return false;
+  }
+  DCHECK(languages->size() == buffer_length);
+  return true;
+}
+
+// Gets the ISO 639 language and, unless the sublanguage is neutral, the
+// ISO 3166 region of the user's default UI language. Returns true on success.
+bool GetUserDefaultUILanguage(std::wstring* language, std::wstring* region) {
+  DCHECK(language && region);
+  LANGID lang_id = ::GetUserDefaultUILanguage();
+  if (LOCALE_CUSTOM_UI_DEFAULT == lang_id) {
+    // Unexpected pre-Vista, which is the only time this fallback should run.
+    NOTREACHED() << "Cannot determine language for a supplemental locale.";
+    return false;
+  }
+  const LCID locale_id = MAKELCID(lang_id, SORT_DEFAULT);
+  // max size for LOCALE_SISO639LANGNAME and LOCALE_SISO3166CTRYNAME is 9
+  wchar_t result_buffer[9];
+  int result_length =
+      GetLocaleInfoW(locale_id, LOCALE_SISO639LANGNAME, &result_buffer[0],
+                     arraysize(result_buffer));
+  DPCHECK(0 != result_length) << "Failed getting language id";
+  if (1 >= result_length)
+    return false;
+  // result_length includes the terminating NUL, which is not copied.
+  language->assign(&result_buffer[0], result_length - 1);
+  region->clear();
+  if (SUBLANG_NEUTRAL != SUBLANGID(lang_id)) {
+    result_length =
+        GetLocaleInfoW(locale_id, LOCALE_SISO3166CTRYNAME, &result_buffer[0],
+                       arraysize(result_buffer));
+    DPCHECK(0 != result_length) << "Failed getting region id";
+    if (1 < result_length)
+      region->assign(&result_buffer[0], result_length - 1);
+  }
+  return true;
+}
+
+// Appends the MUI languages, or failing that the user default UI language.
+bool GetPreferredUILanguageList(LanguageFunction function, ULONG flags,
+                                std::vector<std::wstring>* languages) {
+  std::vector<wchar_t> buffer;
+  std::wstring language;
+  std::wstring region;
+  if (GetMUIPreferredUILanguageList(function, flags, &buffer)) {
+    // NUL-separated names end with an empty one; never read past |end|.
+    const wchar_t* scan = &buffer[0];
+    const wchar_t* const end = scan + buffer.size();
+    while (scan != end && *scan != L'\0') {
+      const wchar_t* name_end = scan;
+      while (name_end != end && *name_end != L'\0')
+        ++name_end;
+      languages->push_back(std::wstring(scan, name_end));
+      scan = (name_end == end) ? end : name_end + 1;
+    }
+  } else if (GetUserDefaultUILanguage(&language, &region)) {
+    // Mimic the MUI behavior of putting the neutral version of the lang after
+    // the regional one (e.g., "fr-CA, fr").
+    if (!region.empty())
+      languages->push_back(language + L'-' + region);
+    languages->push_back(language);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+} // namespace
+
+namespace base {
+namespace win {
+namespace i18n {
+
+bool GetUserPreferredUILanguageList(std::vector<std::wstring>* languages) {
+  DCHECK(languages);
+  // The user's own list needs no extra MUI flags.
+  return GetPreferredUILanguageList(USER_LANGUAGES, 0, languages);
+}
+
+bool GetThreadPreferredUILanguageList(std::vector<std::wstring>* languages) {
+  DCHECK(languages);
+  const ULONG merge_flags = MUI_MERGE_SYSTEM_FALLBACK | MUI_MERGE_USER_FALLBACK;
+  return GetPreferredUILanguageList(THREAD_LANGUAGES, merge_flags, languages);
+}
+
+} // namespace i18n
+} // namespace win
+} // namespace base
diff --git a/base/win/i18n.h b/base/win/i18n.h
new file mode 100644
index 0000000..c0379c1
--- /dev/null
+++ b/base/win/i18n.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_I18N_H_
+#define BASE_WIN_I18N_H_
+
+#include <string>
+#include <vector>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+
+namespace base {
+namespace win {
+namespace i18n {
+
+// Adds to |languages| the list of user preferred UI languages from MUI, if
+// available, falling-back on the user default UI language otherwise. Returns
+// true if at least one language is added.
+// |languages| must not be NULL (DCHECKed). Existing entries are kept and new
+// ones are appended.
+// NOTE(review): the implementation returns true whenever the MUI query itself
+// succeeds, even if that query yields no entries.
+BASE_EXPORT bool GetUserPreferredUILanguageList(
+ std::vector<std::wstring>* languages);
+
+// Adds to |languages| the list of thread, process, user, and system preferred
+// UI languages from MUI, if available, falling-back on the user default UI
+// language otherwise. Returns true if at least one language is added.
+// |languages| must not be NULL (DCHECKed). Existing entries are kept and new
+// ones are appended.
+// NOTE(review): as above, true is also returned when the MUI query succeeds
+// but yields no entries.
+BASE_EXPORT bool GetThreadPreferredUILanguageList(
+ std::vector<std::wstring>* languages);
+
+} // namespace i18n
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_I18N_H_
diff --git a/base/win/i18n_unittest.cc b/base/win/i18n_unittest.cc
new file mode 100644
index 0000000..781fc39
--- /dev/null
+++ b/base/win/i18n_unittest.cc
@@ -0,0 +1,42 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This file contains unit tests for Windows internationalization funcs.
+
+#include "testing/gtest/include/gtest/gtest.h"
+
+#include "base/win/i18n.h"
+#include "base/win/windows_version.h"
+
+namespace base {
+namespace win {
+namespace i18n {
+
+namespace {
+
+// Checks shared by both list getters: at least one language was reported,
+// and none of the reported names is empty.
+void ExpectNonEmptyLanguageList(const std::vector<std::wstring>& languages) {
+  EXPECT_NE(static_cast<std::vector<std::wstring>::size_type>(0),
+            languages.size());
+  for (std::vector<std::wstring>::const_iterator scan = languages.begin(),
+       end = languages.end(); scan != end; ++scan) {
+    EXPECT_FALSE((*scan).empty());
+  }
+}
+
+}  // namespace
+
+// Tests that at least one user preferred UI language can be obtained.
+TEST(I18NTest, GetUserPreferredUILanguageList) {
+  std::vector<std::wstring> languages;
+  EXPECT_TRUE(GetUserPreferredUILanguageList(&languages));
+  ExpectNonEmptyLanguageList(languages);
+}
+
+// Tests that at least one thread preferred UI language can be obtained.
+TEST(I18NTest, GetThreadPreferredUILanguageList) {
+  std::vector<std::wstring> languages;
+  EXPECT_TRUE(GetThreadPreferredUILanguageList(&languages));
+  ExpectNonEmptyLanguageList(languages);
+}
+
+}  // namespace i18n
+}  // namespace win
+}  // namespace base
diff --git a/base/win/iat_patch_function.cc b/base/win/iat_patch_function.cc
new file mode 100644
index 0000000..a4a8902
--- /dev/null
+++ b/base/win/iat_patch_function.cc
@@ -0,0 +1,278 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/iat_patch_function.h"
+
+#include "base/logging.h"
+#include "base/win/pe_image.h"
+
+namespace base {
+namespace win {
+
+namespace {
+
+// Bookkeeping shared between InterceptImportedFunction() and
+// InterceptEnumCallback(), which receives it through its opaque |cookie|.
+// NOTE: this struct is aggregate-initialized positionally, so member order
+// is significant.
+struct InterceptFunctionInformation {
+ // Set once the target import has been found and a patch was attempted.
+ // This stops the enumeration and skips the delay-import pass.
+ bool finished_operation;
+ // DLL the function is imported from (compared case-insensitively).
+ const char* imported_from_module;
+ // Imported function to patch (compared case-insensitively).
+ const char* function_name;
+ // Pointer value written into the matching IAT entry.
+ void* new_function;
+ // Optional (may be NULL): receives the entry's previous target.
+ void** old_function;
+ // Optional (may be NULL): receives the address of the patched entry.
+ IMAGE_THUNK_DATA** iat_thunk;
+ // Result of the patch attempt. Keeps its initial value if nothing matches.
+ DWORD return_code;
+};
+
+// Returns the function pointer currently stored in an IAT entry, or NULL
+// (after NOTREACHED) if |iat_thunk| is NULL.
+void* GetIATFunction(IMAGE_THUNK_DATA* iat_thunk) {
+  if (iat_thunk != NULL) {
+    // IMAGE_THUNK_DATA maps to IMAGE_THUNK_DATA32 or IMAGE_THUNK_DATA64 to
+    // match the pointer size, and its Function member really holds the IAT
+    // function pointer. Reading it through a union works around the 64-bit
+    // portability warning.
+    union {
+      IMAGE_THUNK_DATA as_thunk;
+      void* as_pointer;
+    } entry;
+    entry.as_thunk = *iat_thunk;
+    return entry.as_pointer;
+  }
+
+  NOTREACHED();
+  return NULL;
+}
+// Temporarily makes the memory at |old_code| writable, copies |length| bytes
+// from |new_code| over it, then puts the previous page protection back.
+//
+// Arguments:
+//   old_code  Target location to copy
+//   new_code  Source
+//   length    Number of bytes to copy
+//
+// Returns: Windows error code (winerror.h). NO_ERROR if successful
+DWORD ModifyCode(void* old_code, void* new_code, int length) {
+  if (old_code == NULL || new_code == NULL || length == 0) {
+    NOTREACHED();
+    return ERROR_INVALID_PARAMETER;
+  }
+
+  DWORD old_page_protection = 0;
+  if (!VirtualProtect(old_code, length, PAGE_READWRITE,
+                      &old_page_protection)) {
+    const DWORD protect_error = GetLastError();
+    NOTREACHED();
+    return protect_error;
+  }
+
+  CopyMemory(old_code, new_code, length);
+
+  // Restore the previous protection. The result of this call is not checked.
+  VirtualProtect(old_code, length, old_page_protection, &old_page_protection);
+  return ERROR_SUCCESS;
+}
+
+// PEImage import-enumeration callback. |cookie| is the
+// InterceptFunctionInformation describing the import to patch. On a match
+// this records the old target and the entry, patches the entry, and returns
+// false to stop the enumeration. Otherwise it returns true to keep going.
+bool InterceptEnumCallback(const base::win::PEImage& image, const char* module,
+                           DWORD ordinal, const char* name, DWORD hint,
+                           IMAGE_THUNK_DATA* iat, void* cookie) {
+  InterceptFunctionInformation* info =
+      reinterpret_cast<InterceptFunctionInformation*>(cookie);
+  if (info == NULL) {
+    NOTREACHED();
+    return false;
+  }
+
+  DCHECK(module);
+
+  // |name| may be NULL, so it is checked before being compared.
+  const bool is_target =
+      lstrcmpiA(module, info->imported_from_module) == 0 &&
+      name != NULL &&
+      lstrcmpiA(name, info->function_name) == 0;
+  if (!is_target)
+    return true;
+
+  if (info->old_function != NULL)
+    *info->old_function = GetIATFunction(iat);
+  if (info->iat_thunk != NULL)
+    *info->iat_thunk = iat;
+
+  // The entry must be exactly pointer-sized for the copy below to be valid.
+  COMPILE_ASSERT(sizeof(iat->u1.Function) == sizeof(info->new_function),
+                 unknown_IAT_thunk_format);
+
+  info->return_code = ModifyCode(&iat->u1.Function, &info->new_function,
+                                 sizeof(info->new_function));
+  info->finished_operation = true;
+  return false;
+}
+
+// Intercepts a function in the import table of a specific module. The
+// regular import table is searched first. The delay-import table is searched
+// only if the function was not found there.
+//
+// Arguments:
+//   module_handle         Module to be intercepted
+//   imported_from_module  Module that exports the symbol
+//   function_name         Name of the API to be intercepted
+//   new_function          Interceptor function
+//   old_function          Receives the original function pointer (optional)
+//   iat_thunk             Receives pointer to IAT_THUNK_DATA for the API
+//                         from the import table (optional)
+//
+// Returns: NO_ERROR on success or a Windows error code (winerror.h).
+// ERROR_GEN_FAILURE means no matching import was found.
+DWORD InterceptImportedFunction(HMODULE module_handle,
+                                const char* imported_from_module,
+                                const char* function_name, void* new_function,
+                                void** old_function,
+                                IMAGE_THUNK_DATA** iat_thunk) {
+  const bool has_required_args =
+      module_handle != NULL && imported_from_module != NULL &&
+      function_name != NULL && new_function != NULL;
+  if (!has_required_args) {
+    NOTREACHED();
+    return ERROR_INVALID_PARAMETER;
+  }
+
+  base::win::PEImage target_image(module_handle);
+  if (!target_image.VerifyMagic()) {
+    NOTREACHED();
+    return ERROR_INVALID_PARAMETER;
+  }
+
+  InterceptFunctionInformation info = {
+    false,                 // finished_operation
+    imported_from_module,  // imported_from_module
+    function_name,         // function_name
+    new_function,          // new_function
+    old_function,          // old_function
+    iat_thunk,             // iat_thunk
+    ERROR_GEN_FAILURE      // return_code, kept if nothing matches
+  };
+
+  target_image.EnumAllImports(InterceptEnumCallback, &info);
+  if (info.finished_operation)
+    return info.return_code;
+
+  // Not in the regular import table, so try the delay-import table.
+  target_image.EnumAllDelayImports(InterceptEnumCallback, &info);
+  return info.return_code;
+}
+
+// Restores an intercepted IAT entry to its original target.
+//
+// Arguments:
+//   intercept_function  Interceptor expected to be installed in the entry
+//   original_function   Function pointer to write back into the entry
+//   iat_thunk           IAT entry that was patched
+//
+// Returns: NO_ERROR on success or a Windows error code (winerror.h).
+// ERROR_INVALID_FUNCTION means the entry no longer points at
+// |intercept_function| and was left untouched.
+DWORD RestoreImportedFunction(void* intercept_function,
+                              void* original_function,
+                              IMAGE_THUNK_DATA* iat_thunk) {
+  if (intercept_function == NULL || original_function == NULL ||
+      iat_thunk == NULL) {
+    NOTREACHED();
+    return ERROR_INVALID_PARAMETER;
+  }
+
+  const bool still_ours = GetIATFunction(iat_thunk) == intercept_function;
+  if (!still_ours) {
+    // Someone else has intercepted on top of us. We cannot unpatch in this
+    // case, so just raise a red flag.
+    NOTREACHED();
+    return ERROR_INVALID_FUNCTION;
+  }
+
+  return ModifyCode(&iat_thunk->u1.Function, &original_function,
+                    sizeof(original_function));
+}
+
+} // namespace
+
+// Creates an unpatched instance. The initializers are listed in declaration
+// order (module_handle_, intercept_function_, original_function_, iat_thunk_),
+// which is the order in which they actually run. A mismatched list triggers
+// -Wreorder.
+IATPatchFunction::IATPatchFunction()
+    : module_handle_(NULL),
+      intercept_function_(NULL),
+      original_function_(NULL),
+      iat_thunk_(NULL) {
+}
+
+// Restores the original import if a patch is still active. A failure to
+// restore is reported only through DCHECK.
+IATPatchFunction::~IATPatchFunction() {
+  if (!is_patched())
+    return;
+  const DWORD unpatch_error = Unpatch();
+  DCHECK_EQ(static_cast<DWORD>(NO_ERROR), unpatch_error);
+}
+
+// Patches |function_name| (imported by |module| from |imported_from_module|)
+// so that it points at |new_function|.
+// - On success, |module| stays loaded until Unpatch().
+// - On failure, the module reference is dropped and every field is reset to
+//   the unpatched state, so a later Patch() call starts clean.
+// Returns NO_ERROR or a Windows error code (winerror.h).
+DWORD IATPatchFunction::Patch(const wchar_t* module,
+                              const char* imported_from_module,
+                              const char* function_name,
+                              void* new_function) {
+  DCHECK_EQ(static_cast<void*>(NULL), original_function_);
+  DCHECK_EQ(static_cast<IMAGE_THUNK_DATA*>(NULL), iat_thunk_);
+  DCHECK_EQ(static_cast<void*>(NULL), intercept_function_);
+
+  HMODULE module_handle = LoadLibraryW(module);
+  if (module_handle == NULL) {
+    // Capture the error before NOTREACHED(), whose debug-build reporting may
+    // overwrite the thread's last-error value.
+    const DWORD load_error = GetLastError();
+    NOTREACHED();
+    return load_error;
+  }
+
+  DWORD error = InterceptImportedFunction(module_handle,
+                                          imported_from_module,
+                                          function_name,
+                                          new_function,
+                                          &original_function_,
+                                          &iat_thunk_);
+
+  if (NO_ERROR == error) {
+    DCHECK_NE(original_function_, intercept_function_);
+    module_handle_ = module_handle;
+    intercept_function_ = new_function;
+  } else {
+    // InterceptEnumCallback() stores the original pointer and the thunk
+    // before attempting the write, so a failed write can leave them set.
+    // Clear them so the DCHECKs above still hold on a retry.
+    original_function_ = NULL;
+    iat_thunk_ = NULL;
+    FreeLibrary(module_handle);
+  }
+
+  return error;
+}
+
+// Writes the saved original function back into the patched IAT entry and
+// releases the module reference taken by Patch().
+DWORD IATPatchFunction::Unpatch() {
+  const DWORD error = RestoreImportedFunction(intercept_function_,
+                                              original_function_,
+                                              iat_thunk_);
+  DCHECK_EQ(static_cast<DWORD>(NO_ERROR), error);
+
+  // Whether or not the restore succeeded, give up on this patch. A failure
+  // means the IAT cannot be unpatched safely, and trying again from the
+  // destructor would be no safer.
+  HMODULE module_to_release = module_handle_;
+  module_handle_ = NULL;
+  intercept_function_ = NULL;
+  original_function_ = NULL;
+  iat_thunk_ = NULL;
+  if (module_to_release != NULL)
+    FreeLibrary(module_to_release);
+
+  return error;
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/iat_patch_function.h b/base/win/iat_patch_function.h
new file mode 100644
index 0000000..3ae1f3c
--- /dev/null
+++ b/base/win/iat_patch_function.h
@@ -0,0 +1,72 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_IAT_PATCH_FUNCTION_H_
+#define BASE_WIN_IAT_PATCH_FUNCTION_H_
+
+#include <windows.h>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+
+namespace base {
+namespace win {
+
+// A class that encapsulates Import Address Table patching helpers and restores
+// the original function in the destructor.
+//
+// It will intercept functions for a specific DLL imported from another DLL.
+// This is the case when, for example, we want to intercept
+// CertDuplicateCertificateContext function (exported from crypt32.dll) called
+// by wininet.dll.
+class BASE_EXPORT IATPatchFunction {
+ public:
+ IATPatchFunction();
+ // Calls Unpatch() if a patch is still active.
+ ~IATPatchFunction();
+
+ // Intercept a function in an import table of a specific
+ // module. Save the original function and the import
+ // table address. These values will be used later
+ // during Unpatch
+ //
+ // Arguments:
+ // module Module to be intercepted
+ // imported_from_module Module that exports the 'function_name'
+ // function_name Name of the API to be intercepted
+ // new_function Interceptor written into the import entry
+ //
+ // Returns: Windows error code (winerror.h). NO_ERROR if successful
+ //
+ // Note: Patching a function will make the IAT patch take some "ownership" on
+ // |module|. It will LoadLibrary(module) to keep the DLL alive until a call
+ // to Unpatch(), which will call FreeLibrary() and allow the module to be
+ // unloaded. The idea is to help prevent the DLL from going away while a
+ // patch is still active.
+ //
+ // Must not be called while a patch is already active (DCHECKed).
+ DWORD Patch(const wchar_t* module,
+ const char* imported_from_module,
+ const char* function_name,
+ void* new_function);
+
+ // Unpatch the IAT entry using internally saved original
+ // function.
+ //
+ // Returns: Windows error code (winerror.h). NO_ERROR if successful
+ //
+ // Whatever the result, the module reference is released and the object
+ // returns to the unpatched state. A failed unpatch is not retried.
+ DWORD Unpatch();
+
+ // True between a successful Patch() and the next Unpatch().
+ bool is_patched() const {
+ return (NULL != intercept_function_);
+ }
+
+ private:
+ // Module handed to Patch(), kept loaded while the patch is active.
+ HMODULE module_handle_;
+ // Interceptor currently installed. NULL when not patched.
+ void* intercept_function_;
+ // Target of the import entry before patching. Written back by Unpatch().
+ void* original_function_;
+ // The patched import entry.
+ IMAGE_THUNK_DATA* iat_thunk_;
+
+ DISALLOW_COPY_AND_ASSIGN(IATPatchFunction);
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_IAT_PATCH_FUNCTION_H_
diff --git a/base/win/iunknown_impl.cc b/base/win/iunknown_impl.cc
new file mode 100644
index 0000000..9baa0f3
--- /dev/null
+++ b/base/win/iunknown_impl.cc
@@ -0,0 +1,42 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/iunknown_impl.h"
+
+namespace base {
+namespace win {
+
+// Starts with no references. The first AddRef() takes the count to one.
+IUnknownImpl::IUnknownImpl() : ref_count_(0) {}
+
+// Normally reached through the final Release(), which deletes the object.
+IUnknownImpl::~IUnknownImpl() {}
+
+// Takes a reference. The return value is a constant 1, not the new count;
+// IUnknown documents this value as intended for testing only.
+ULONG STDMETHODCALLTYPE IUnknownImpl::AddRef() {
+  base::AtomicRefCountInc(&ref_count_);
+  return 1;
+}
+
+// Drops a reference and deletes the object when it was the last one.
+// Returns 0 if the object was deleted and 1 otherwise.
+ULONG STDMETHODCALLTYPE IUnknownImpl::Release() {
+  const bool has_other_references = base::AtomicRefCountDec(&ref_count_);
+  if (has_other_references)
+    return 1;
+  delete this;
+  return 0;
+}
+
+// Hands out IUnknown (adding a reference). Subclasses should extend this to
+// return any interfaces they provide. For any other IID, sets |*ppv| to NULL
+// and returns E_NOINTERFACE. As the IUnknown::QueryInterface contract
+// requires, a NULL |ppv| is rejected with E_POINTER rather than dereferenced.
+STDMETHODIMP IUnknownImpl::QueryInterface(REFIID riid, void** ppv) {
+  if (ppv == NULL)
+    return E_POINTER;
+
+  if (riid == IID_IUnknown) {
+    *ppv = static_cast<IUnknown*>(this);
+    AddRef();
+    return S_OK;
+  }
+
+  *ppv = NULL;
+  return E_NOINTERFACE;
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/iunknown_impl.h b/base/win/iunknown_impl.h
new file mode 100644
index 0000000..ff7e870
--- /dev/null
+++ b/base/win/iunknown_impl.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_IUNKNOWN_IMPL_H_
+#define BASE_WIN_IUNKNOWN_IMPL_H_
+
+#include <unknwn.h>
+
+#include "base/atomic_ref_count.h"
+#include "base/base_export.h"
+#include "base/compiler_specific.h"
+
+namespace base {
+namespace win {
+
+// IUnknown implementation for other classes to derive from.
+// The reference count is updated atomically, and the object deletes itself
+// (`delete this`) on the last Release(), so instances must be heap-allocated.
+// NOTE: AddRef() always returns 1. Release() returns 0 only when it deleted
+// the object, and 1 otherwise. Do not treat these values as the real count.
+class BASE_EXPORT IUnknownImpl : public IUnknown {
+ public:
+ IUnknownImpl();
+
+ virtual ULONG STDMETHODCALLTYPE AddRef() OVERRIDE;
+ virtual ULONG STDMETHODCALLTYPE Release() OVERRIDE;
+
+ // Subclasses should extend this to return any interfaces they provide.
+ // The base implementation only answers IID_IUnknown.
+ virtual STDMETHODIMP QueryInterface(REFIID riid, void** ppv) OVERRIDE;
+
+ protected:
+ // Protected: objects are destroyed by the final Release(), not by callers.
+ virtual ~IUnknownImpl();
+
+ private:
+ AtomicRefCount ref_count_;
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_IUNKNOWN_IMPL_H_
diff --git a/base/win/iunknown_impl_unittest.cc b/base/win/iunknown_impl_unittest.cc
new file mode 100644
index 0000000..db86214
--- /dev/null
+++ b/base/win/iunknown_impl_unittest.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/iunknown_impl.h"
+
+#include "base/win/scoped_com_initializer.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace base {
+namespace win {
+
+// Subclass that counts live instances, so the test can verify that the
+// final Release() deletes the object.
+class TestIUnknownImplSubclass : public IUnknownImpl {
+ public:
+  TestIUnknownImplSubclass() { instance_count += 1; }
+  virtual ~TestIUnknownImplSubclass() { instance_count -= 1; }
+
+  static int instance_count;
+};
+
+// static
+int TestIUnknownImplSubclass::instance_count = 0;
+
+// Exercises reference counting, QueryInterface(), and self-deletion on the
+// final Release().
+TEST(IUnknownImplTest, IUnknownImpl) {
+  ScopedCOMInitializer com_initializer;
+
+  EXPECT_EQ(0, TestIUnknownImplSubclass::instance_count);
+  IUnknown* u = new TestIUnknownImplSubclass();
+
+  EXPECT_EQ(1, TestIUnknownImplSubclass::instance_count);
+
+  // AddRef() reports 1 regardless of the real count.
+  EXPECT_EQ(1, u->AddRef());
+  EXPECT_EQ(1, u->AddRef());
+
+  IUnknown* other = NULL;
+  EXPECT_EQ(E_NOINTERFACE, u->QueryInterface(
+      IID_IDispatch, reinterpret_cast<void**>(&other)));
+  // A failed QueryInterface() leaves |other| NULL. ASSERT stops the test here
+  // instead of crashing on the Release() below.
+  ASSERT_EQ(S_OK, u->QueryInterface(
+      IID_IUnknown, reinterpret_cast<void**>(&other)));
+  other->Release();
+
+  EXPECT_EQ(1, u->Release());
+  EXPECT_EQ(0, u->Release());
+  EXPECT_EQ(0, TestIUnknownImplSubclass::instance_count);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/metro.cc b/base/win/metro.cc
new file mode 100644
index 0000000..22bc5e8
--- /dev/null
+++ b/base/win/metro.cc
@@ -0,0 +1,189 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/metro.h"
+
+#include "base/message_loop.h"
+#include "base/string_util.h"
+#include "base/win/scoped_comptr.h"
+#include "base/win/windows_version.h"
+
+namespace base {
+namespace win {
+
+namespace {
+// Set by SetForceToUseTSF(). Makes IsTSFAwareRequired() return true even
+// outside a metro process.
+bool g_should_tsf_aware_required = false;
+}
+
+// Returns metro_driver.dll's handle, or NULL if it is not loaded. The lookup
+// happens once and its result (including NULL) is cached.
+HMODULE GetMetroModule() {
+  const HMODULE kUninitialized = reinterpret_cast<HMODULE>(1);
+  static HMODULE metro_module = kUninitialized;
+
+  if (metro_module != kUninitialized)
+    return metro_module;
+
+  // The lookup is idempotent as long as metro_driver is never unloaded, so
+  // racing threads that store the same value here are harmless.
+  metro_module = GetModuleHandleA("metro_driver.dll");
+  if (metro_module != NULL) {
+    // metro_driver is only expected to be loaded into a metro process.
+    DCHECK(IsMetroProcess());
+  }
+
+  DCHECK(metro_module != kUninitialized);
+  return metro_module;
+}
+
+bool IsMetroProcess() {
+  enum ImmersiveState {
+    kImmersiveUnknown,
+    kImmersiveTrue,
+    kImmersiveFalse
+  };
+  // A process's immersive state never changes, so it is queried once and
+  // cached here.
+  static ImmersiveState state = kImmersiveUnknown;
+
+  if (state == kImmersiveUnknown) {
+    state = IsProcessImmersive(::GetCurrentProcess()) ? kImmersiveTrue
+                                                      : kImmersiveFalse;
+  }
+  DCHECK_NE(kImmersiveUnknown, state);
+  return state == kImmersiveTrue;
+}
+
+// Asks user32's IsImmersiveProcess() about |process|. Returns false when
+// user32 does not export that function.
+bool IsProcessImmersive(HANDLE process) {
+  typedef BOOL (WINAPI* IsImmersiveProcessFunc)(HANDLE process);
+  HMODULE user32 = ::GetModuleHandleA("user32.dll");
+  DCHECK(user32 != NULL);
+
+  IsImmersiveProcessFunc is_immersive_process =
+      reinterpret_cast<IsImmersiveProcessFunc>(
+          ::GetProcAddress(user32, "IsImmersiveProcess"));
+  if (is_immersive_process == NULL)
+    return false;
+  return is_immersive_process(process) != FALSE;
+}
+
+// True when SetForceToUseTSF() was called or this is a metro process.
+// Chrome for Win7 and Vista may support TSF in the future.
+bool IsTSFAwareRequired() {
+  if (g_should_tsf_aware_required)
+    return true;
+  return IsMetroProcess();
+}
+
+// Forces IsTSFAwareRequired() to return true and, where the OS provides
+// ImmDisableLegacyIME, disables CUAS for this process.
+void SetForceToUseTSF() {
+  g_should_tsf_aware_required = true;
+
+  // Since Windows 8 Metro mode disables CUAS (Cicero Unaware Application
+  // Support) via ImmDisableLegacyIME API, Chrome must be fully TSF-aware on
+  // Metro mode. For debugging purposes, explicitly call ImmDisableLegacyIME so
+  // that one can test TSF functionality even on Windows 8 desktop mode. Note
+  // that CUAS cannot be disabled on Windows Vista/7 where ImmDisableLegacyIME
+  // is not available.
+  //
+  // ImmDisableLegacyIME is declared WINAPI, so the pointer type must carry
+  // the same calling convention, as IsImmersiveProcessFunc does.
+  typedef BOOL (WINAPI* ImmDisableLegacyIMEFunc)();
+  HMODULE imm32 = ::GetModuleHandleA("imm32.dll");
+  if (imm32 == NULL)
+    return;
+
+  ImmDisableLegacyIMEFunc imm_disable_legacy_ime =
+      reinterpret_cast<ImmDisableLegacyIMEFunc>(
+          ::GetProcAddress(imm32, "ImmDisableLegacyIME"));
+
+  if (imm_disable_legacy_ime == NULL) {
+    // Unsupported API, just do nothing.
+    return;
+  }
+
+  if (!imm_disable_legacy_ime()) {
+    DVLOG(1) << "Failed to disable legacy IME.";
+  }
+}
+
+// Returns a NUL-terminated copy of |src| allocated with LocalAlloc(). The
+// caller must release it with LocalFree(). Returns NULL if the allocation
+// fails.
+wchar_t* LocalAllocAndCopyString(const string16& src) {
+  const size_t dest_chars = src.length() + 1;
+  wchar_t* dest = reinterpret_cast<wchar_t*>(
+      LocalAlloc(LPTR, dest_chars * sizeof(wchar_t)));
+  if (dest == NULL)
+    return NULL;
+  // wcslcpy() takes the destination capacity in characters, not bytes.
+  base::wcslcpy(dest, src.c_str(), dest_chars);
+  return dest;
+}
+
+// True when the SM_DIGITIZER metric reports both NID_READY and
+// NID_INTEGRATED_TOUCH.
+bool IsTouchEnabled() {
+  const int kRequiredFlags = NID_READY | NID_INTEGRATED_TOUCH;
+  const int digitizer_status = GetSystemMetrics(SM_DIGITIZER);
+  return (digitizer_status & kRequiredFlags) == kRequiredFlags;
+}
+
+// Reports whether parental controls require activity logging for the current
+// user. Pre-Vista always returns false. On Vista and later the answer is
+// computed at most once per process and cached, and any failure while
+// querying is cached as false.
+bool IsParentalControlActivityLoggingOn() {
+  if (base::win::GetVersion() < base::win::VERSION_VISTA)
+    return false;
+
+  static bool parental_control_logging_required = false;
+  static bool parental_control_status_determined = false;
+  if (parental_control_status_determined)
+    return parental_control_logging_required;
+  parental_control_status_determined = true;
+
+  ScopedComPtr<IWindowsParentalControlsCore> parent_controls;
+  const HRESULT create_hr =
+      parent_controls.CreateInstance(__uuidof(WindowsParentalControls));
+  if (FAILED(create_hr))
+    return false;
+
+  ScopedComPtr<IWPCSettings> settings;
+  const HRESULT settings_hr =
+      parent_controls->GetUserSettings(NULL, settings.Receive());
+  if (FAILED(settings_hr))
+    return false;
+
+  unsigned long restrictions = 0;
+  settings->GetRestrictions(&restrictions);
+  parental_control_logging_required =
+      (restrictions & WPCFLAG_LOGGING_REQUIRED) == WPCFLAG_LOGGING_REQUIRED;
+  return parental_control_logging_required;
+}
+
+// Metro driver exports for getting the launch type, initial url, initial
+// search term, etc.
+// GetMetroLaunchParams() resolves these by name from metro_driver.dll with
+// GetProcAddress() and calls them through these pointer types.
+extern "C" {
+typedef const wchar_t* (*GetInitialUrl)();
+typedef const wchar_t* (*GetInitialSearchString)();
+typedef base::win::MetroLaunchType (*GetLaunchType)(
+ base::win::MetroPreviousExecutionState* previous_state);
+}
+
+// Returns the metro launch type reported by metro_driver.dll.
+// - For METRO_PROTOCOL/METRO_LAUNCH, stores the initial URL in |params|.
+// - For METRO_SEARCH, stores the initial search string in |params|.
+// Returns METRO_LAUNCH_ERROR when metro_driver.dll is not loaded or does not
+// export GetLaunchType. |params| is left untouched when the relevant export is
+// missing or returns NULL.
+MetroLaunchType GetMetroLaunchParams(string16* params) {
+  HMODULE metro = base::win::GetMetroModule();
+  if (!metro)
+    return base::win::METRO_LAUNCH_ERROR;
+
+  // The DCHECKs below are not enforced in release builds, so each export is
+  // also NULL-checked before it is called.
+  GetLaunchType get_launch_type = reinterpret_cast<GetLaunchType>(
+      ::GetProcAddress(metro, "GetLaunchType"));
+  DCHECK(get_launch_type);
+  if (!get_launch_type)
+    return base::win::METRO_LAUNCH_ERROR;
+
+  base::win::MetroLaunchType launch_type = get_launch_type(NULL);
+
+  const wchar_t* launch_param = NULL;
+  if ((launch_type == base::win::METRO_PROTOCOL) ||
+      (launch_type == base::win::METRO_LAUNCH)) {
+    GetInitialUrl initial_metro_url = reinterpret_cast<GetInitialUrl>(
+        ::GetProcAddress(metro, "GetInitialUrl"));
+    DCHECK(initial_metro_url);
+    if (initial_metro_url)
+      launch_param = initial_metro_url();
+  } else if (launch_type == base::win::METRO_SEARCH) {
+    GetInitialSearchString initial_search_string =
+        reinterpret_cast<GetInitialSearchString>(
+            ::GetProcAddress(metro, "GetInitialSearchString"));
+    DCHECK(initial_search_string);
+    if (initial_search_string)
+      launch_param = initial_search_string();
+  }
+
+  // Assigning a NULL wchar_t* to a string16 is undefined behavior.
+  if (launch_param)
+    *params = launch_param;
+  return launch_type;
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/metro.h b/base/win/metro.h
new file mode 100644
index 0000000..a43bbc3
--- /dev/null
+++ b/base/win/metro.h
@@ -0,0 +1,103 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_METRO_H_
+#define BASE_WIN_METRO_H_
+
+#include <windows.h>
+#include <wpcapi.h>
+
+#include "base/base_export.h"
+#include "base/string16.h"
+
+namespace base {
+namespace win {
+
+// Identifies the type of the metro launch.
+// NOTE: values of this type are returned by metro_driver.dll's GetLaunchType
+// export, so existing enumerators must not be reordered.
+enum MetroLaunchType {
+ METRO_LAUNCH,
+ METRO_SEARCH,
+ METRO_SHARE,
+ METRO_FILE,
+ METRO_PROTOCOL,
+ // Returned by GetMetroLaunchParams() when metro_driver.dll is not loaded.
+ METRO_LAUNCH_ERROR,
+ // NOTE(review): presumably a count sentinel rather than a real launch type.
+ METRO_LASTLAUNCHTYPE,
+};
+
+// In metro mode, this enum identifies the last execution state, i.e. whether
+// we crashed, terminated, etc.
+// NOTE(review): the names presumably mirror WinRT's ApplicationExecutionState,
+// and the values cross the metro_driver.dll boundary. Confirm before
+// reordering.
+enum MetroPreviousExecutionState {
+ NOTRUNNING,
+ RUNNING,
+ SUSPENDED,
+ TERMINATED,
+ CLOSEDBYUSER,
+ LASTEXECUTIONSTATE,
+};
+
+// Contains information about the currently displayed tab in metro mode.
+// NOTE(review): ownership and allocation of these strings are not defined
+// here. Confirm who frees them (possibly via LocalFree, if they are produced
+// by LocalAllocAndCopyString()).
+struct CurrentTabInfo {
+ wchar_t* title;
+ wchar_t* url;
+};
+
+// Returns the handle to the metro dll loaded in the process. A NULL return
+// indicates that the metro dll was not loaded in the process.
+// The result of the first call is cached.
+BASE_EXPORT HMODULE GetMetroModule();
+
+// Returns true if this process is running as an immersive program
+// in Windows Metro mode. The answer is computed once and cached.
+BASE_EXPORT bool IsMetroProcess();
+
+// Returns true if the process identified by the handle passed in is an
+// immersive (Metro) process. Returns false when user32.dll does not export
+// IsImmersiveProcess.
+BASE_EXPORT bool IsProcessImmersive(HANDLE process);
+
+// Returns true if this process is running under Text Services Framework (TSF)
+// and browser must be TSF-aware. This is the case in a metro process, or once
+// SetForceToUseTSF() has been called.
+BASE_EXPORT bool IsTSFAwareRequired();
+
+// Sets browser to use Text Services Framework (TSF) regardless of process
+// status. On Windows 8, this function also disables CUAS (Cicero Unaware
+// Application Support) to emulate Windows Metro mode in terms of IME
+// functionality. This should be beneficial in QA process because one can test
+// IME functionality in Windows 8 desktop mode.
+BASE_EXPORT void SetForceToUseTSF();
+
+// Allocates and returns the destination string via the LocalAlloc API after
+// copying the src to it. The caller owns the result and must release it with
+// LocalFree().
+BASE_EXPORT wchar_t* LocalAllocAndCopyString(const string16& src);
+
+// Returns true if the screen supports touch, i.e. SM_DIGITIZER reports both
+// NID_READY and NID_INTEGRATED_TOUCH.
+BASE_EXPORT bool IsTouchEnabled();
+
+// Returns true if Windows Parental control activity logging is enabled. This
+// feature is available on Windows Vista and beyond.
+// This function should ideally be called on the UI thread.
+// On Vista and later the answer is computed once and cached.
+BASE_EXPORT bool IsParentalControlActivityLoggingOn();
+
+// Returns the type of launch and the activation params. For example, if the
+// launch is for METRO_PROTOCOL then |params| receives a url.
+// Returns METRO_LAUNCH_ERROR if the metro dll is not loaded.
+BASE_EXPORT MetroLaunchType GetMetroLaunchParams(string16* params);
+
+// Handler function for the buttons on a metro dialog box
+typedef void (*MetroDialogButtonPressedHandler)();
+
+// Handler function invoked when a metro style notification is clicked.
+// NOTE(review): |context| is presumably the |handler_context| that was passed
+// to MetroNotification. Confirm against the metro driver.
+typedef void (*MetroNotificationClickedHandler)(const wchar_t* context);
+
+// Function to display metro style notifications.
+// NOTE(review): presumably implemented by the metro driver and resolved via
+// GetProcAddress() like the launch exports. Confirm.
+typedef void (*MetroNotification)(const char* origin_url,
+ const char* icon_url,
+ const wchar_t* title,
+ const wchar_t* body,
+ const wchar_t* display_source,
+ const char* notification_id,
+ MetroNotificationClickedHandler handler,
+ const wchar_t* handler_context);
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_METRO_H_
diff --git a/base/win/object_watcher.cc b/base/win/object_watcher.cc
new file mode 100644
index 0000000..ebe596f
--- /dev/null
+++ b/base/win/object_watcher.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/object_watcher.h"
+
+#include "base/bind.h"
+#include "base/logging.h"
+
+namespace base {
+namespace win {
+
+//-----------------------------------------------------------------------------
+
+// Starts idle, with no object or wait registered. |weak_factory_| supplies
+// the weak pointers that let StopWatching() cancel already-posted
+// notifications.
+ObjectWatcher::ObjectWatcher()
+ : weak_factory_(this),
+ object_(NULL),
+ wait_object_(NULL),
+ origin_loop_(NULL) {
+}
+
+// Cancels any outstanding wait. StopWatching() does nothing if no wait is
+// registered.
+ObjectWatcher::~ObjectWatcher() {
+ StopWatching();
+}
+
+// Registers a one-shot wait on |object|. |delegate| is notified on the
+// calling thread's message loop. Returns false if a wait is already active or
+// registration fails.
+bool ObjectWatcher::StartWatching(HANDLE object, Delegate* delegate) {
+  CHECK(delegate);
+  if (wait_object_ != NULL) {
+    NOTREACHED() << "Already watching an object";
+    return false;
+  }
+
+  // DoneWaiting() can be called synchronously from inside
+  // RegisterWaitForSingleObject(), so all the state it reads is set up first.
+  // Signal() is bound through a weak pointer so StopWatching() can cancel a
+  // notification that has already been posted.
+  callback_ = base::Bind(&ObjectWatcher::Signal, weak_factory_.GetWeakPtr(),
+                         delegate);
+  object_ = object;
+  origin_loop_ = MessageLoop::current();
+
+  // The wait only has to notice the signal and report it back to this
+  // thread, so it can run once, directly on a Windows wait thread.
+  const DWORD kWaitFlags = WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE;
+  const BOOL registered = RegisterWaitForSingleObject(
+      &wait_object_, object, DoneWaiting, this, INFINITE, kWaitFlags);
+  if (!registered) {
+    NOTREACHED() << "RegisterWaitForSingleObject failed: " << GetLastError();
+    object_ = NULL;
+    wait_object_ = NULL;
+    return false;
+  }
+
+  // Watch for the loop going away, so the wait thread never touches a dead
+  // message loop.
+  MessageLoop::current()->AddDestructionObserver(this);
+  return true;
+}
+
+// Cancels the active wait, blocking until any callback already running has
+// finished. Returns false if nothing was being watched or the unregister call
+// failed.
+bool ObjectWatcher::StopWatching() {
+  if (wait_object_ == NULL)
+    return false;
+
+  // ObjectWatcher is single-threaded and must be stopped on its origin loop.
+  MessageLoop* const current_loop = MessageLoop::current();
+  DCHECK(origin_loop_ == current_loop);
+
+  // INVALID_HANDLE_VALUE makes this a blocking cancel.
+  if (!UnregisterWaitEx(wait_object_, INVALID_HANDLE_VALUE)) {
+    NOTREACHED() << "UnregisterWaitEx failed: " << GetLastError();
+    return false;
+  }
+
+  // Invalidating the weak pointers drops any Signal() task DoneWaiting()
+  // already posted.
+  weak_factory_.InvalidateWeakPtrs();
+  object_ = NULL;
+  wait_object_ = NULL;
+
+  current_loop->RemoveDestructionObserver(this);
+  return true;
+}
+
+// Returns the handle being watched, or NULL when idle.
+HANDLE ObjectWatcher::GetWatchedObject() { return object_; }
+
+// Wait callback registered with WT_EXECUTEINWAITTHREAD, so it runs on a
+// Windows wait thread. It forwards the notification to the origin loop.
+// static
+void CALLBACK ObjectWatcher::DoneWaiting(void* param, BOOLEAN timed_out) {
+ DCHECK(!timed_out);
+
+ // The destructor blocks on any callbacks that are in flight, so we know that
+ // |that| is always a pointer to a valid ObjectWatcher.
+ ObjectWatcher* that = static_cast<ObjectWatcher*>(param);
+ // |callback_| is reset only after it has been handed to PostTask().
+ // NOTE(review): this relies on PostTask() keeping its own copy of the closure.
+ that->origin_loop_->PostTask(FROM_HERE, that->callback_);
+ that->callback_.Reset();
+}
+
+// Runs on the origin loop after the watched object was signaled.
+void ObjectWatcher::Signal(Delegate* delegate) {
+  // The delegate may destroy |this| or call StartWatching() again. So the
+  // handle is copied and the old watch is torn down before notifying.
+  HANDLE signaled_object = object_;
+  StopWatching();
+  delegate->OnObjectSignaled(signaled_object);
+}
+
+// The origin loop is being destroyed. Cancel the wait now so nothing tries
+// to use that MessageLoop afterwards.
+void ObjectWatcher::WillDestroyCurrentMessageLoop() {
+  StopWatching();
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/object_watcher.h b/base/win/object_watcher.h
new file mode 100644
index 0000000..742f2b0
--- /dev/null
+++ b/base/win/object_watcher.h
@@ -0,0 +1,103 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_OBJECT_WATCHER_H_
+#define BASE_WIN_OBJECT_WATCHER_H_
+
+#include <windows.h>
+
+#include "base/base_export.h"
+#include "base/callback.h"
+#include "base/memory/weak_ptr.h"
+#include "base/message_loop.h"
+
+namespace base {
+namespace win {
+
+// A class that provides a means to asynchronously wait for a Windows object to
+// become signaled. It is an abstraction around RegisterWaitForSingleObject
+// that provides a notification callback, OnObjectSignaled, that runs back on
+// the origin thread (i.e., the thread that called StartWatching).
+//
+// This class acts like a smart pointer such that when it goes out-of-scope,
+// UnregisterWaitEx is automatically called, and any in-flight notification is
+// suppressed.
+//
+// Typical usage:
+//
+//   class MyClass : public base::win::ObjectWatcher::Delegate {
+//    public:
+//     void DoStuffWhenSignaled(HANDLE object) {
+//       watcher_.StartWatching(object, this);
+//     }
+//     virtual void OnObjectSignaled(HANDLE object) {
+//       // OK, time to do stuff!
+//     }
+//    private:
+//     base::win::ObjectWatcher watcher_;
+//   };
+//
+// In the above example, MyClass wants to "do stuff" when object becomes
+// signaled. ObjectWatcher makes this task easy. When MyClass goes out of
+// scope, the watcher_ will be destroyed, and there is no need to worry about
+// OnObjectSignaled being called on a deleted MyClass pointer. Easy!
+// If the object is already signaled before being watched, OnObjectSignaled is
+// still called after (but not necessarily immediately after) watch is started.
+//
+class BASE_EXPORT ObjectWatcher : public MessageLoop::DestructionObserver {
+ public:
+  class BASE_EXPORT Delegate {
+   public:
+    virtual ~Delegate() {}
+    // Called from the MessageLoop when a signaled object is detected. To
+    // continue watching the object, StartWatching must be called again.
+    virtual void OnObjectSignaled(HANDLE object) = 0;
+  };
+
+  ObjectWatcher();
+  ~ObjectWatcher();
+
+  // When the object is signaled, the given delegate is notified on the thread
+  // where StartWatching is called. The ObjectWatcher is not responsible for
+  // deleting the delegate, which must stay valid until it has been notified
+  // or the watch has been stopped.
+  //
+  // Returns true if the watch was started. Otherwise, false is returned.
+  bool StartWatching(HANDLE object, Delegate* delegate);
+
+  // Stops watching. Does nothing if the watch has already completed. If the
+  // watch is still active, then it is canceled, and the associated delegate is
+  // not notified. Must be called on the thread that called StartWatching();
+  // blocks until any in-progress background wait callback has finished.
+  //
+  // Returns true if the watch was canceled. Otherwise, false is returned.
+  bool StopWatching();
+
+  // Returns the handle of the object being watched, or NULL if the object
+  // watcher is stopped.
+  HANDLE GetWatchedObject();
+
+ private:
+  // Called on a background thread when done waiting; posts to origin_loop_.
+  static void CALLBACK DoneWaiting(void* param, BOOLEAN timed_out);
+  // Runs on the origin thread: stops the watch, then notifies |delegate|.
+  void Signal(Delegate* delegate);
+
+  // MessageLoop::DestructionObserver implementation:
+  virtual void WillDestroyCurrentMessageLoop();
+
+  // Internal state.
+  WeakPtrFactory<ObjectWatcher> weak_factory_;  // Invalidated on stop.
+  Closure callback_;  // Posted to origin_loop_ by DoneWaiting().
+  HANDLE object_;             // The object being watched.
+  HANDLE wait_object_;        // Returned by RegisterWaitForSingleObject.
+  MessageLoop* origin_loop_;  // Used to get back to the origin thread.
+
+  DISALLOW_COPY_AND_ASSIGN(ObjectWatcher);
+};
+
+} // namespace win
+} // namespace base
+
+#endif  // BASE_WIN_OBJECT_WATCHER_H_
diff --git a/base/win/object_watcher_unittest.cc b/base/win/object_watcher_unittest.cc
new file mode 100644
index 0000000..e8484b9
--- /dev/null
+++ b/base/win/object_watcher_unittest.cc
@@ -0,0 +1,172 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <process.h>
+
+#include "base/message_loop.h"
+#include "base/win/object_watcher.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace base {
+namespace win {
+
+namespace {
+
+class QuitDelegate : public ObjectWatcher::Delegate {
+ public:
+  virtual void OnObjectSignaled(HANDLE /* object */) {
+    MessageLoop::current()->Quit();  // Unblocks the test's Run() call.
+  }
+};
+
+class DecrementCountDelegate : public ObjectWatcher::Delegate {
+ public:
+  // |counter| is not owned and must stay valid while this can be signaled.
+  explicit DecrementCountDelegate(int* counter) : counter_(counter) {}
+  virtual void OnObjectSignaled(HANDLE /* object */) {
+    *counter_ -= 1;
+  }
+ private:
+  int* counter_;
+};
+
+void RunTest_BasicSignal(MessageLoop::Type message_loop_type) {
+  MessageLoop message_loop(message_loop_type);
+  ObjectWatcher watcher;
+  // A fresh watcher is not watching anything.
+  EXPECT_EQ(NULL, watcher.GetWatchedObject());
+
+  // Manual-reset event, created unsignaled.
+  HANDLE event = CreateEvent(NULL, TRUE, FALSE, NULL);
+  QuitDelegate delegate;
+  EXPECT_TRUE(watcher.StartWatching(event, &delegate));
+  EXPECT_EQ(event, watcher.GetWatchedObject());
+
+  // Signal the event; QuitDelegate ends Run() once the watcher reports it.
+  SetEvent(event);
+  MessageLoop::current()->Run();
+
+  // Delivering the signal stops the watch.
+  EXPECT_EQ(NULL, watcher.GetWatchedObject());
+
+  CloseHandle(event);
+}
+
+void RunTest_BasicCancel(MessageLoop::Type message_loop_type) {
+  MessageLoop message_loop(message_loop_type);
+  ObjectWatcher watcher;
+  // A manual-reset event that is not yet signaled.
+  HANDLE event = CreateEvent(NULL, TRUE, FALSE, NULL);
+
+  QuitDelegate delegate;
+  EXPECT_TRUE(watcher.StartWatching(event, &delegate));
+
+  // Cancelling an active watch succeeds and clears the watched object; a
+  // second cancel is a no-op that reports nothing was stopped.
+  EXPECT_TRUE(watcher.StopWatching());
+  EXPECT_EQ(NULL, watcher.GetWatchedObject());
+  EXPECT_FALSE(watcher.StopWatching());
+  CloseHandle(event);
+}
+
+void RunTest_CancelAfterSet(MessageLoop::Type message_loop_type) {
+  MessageLoop message_loop(message_loop_type);
+
+  ObjectWatcher watcher;
+
+  // The delegate decrements |counter| if it is ever notified.
+  int counter = 1;
+  DecrementCountDelegate delegate(&counter);
+
+  // Manual-reset event, created unsignaled, then watched.
+  HANDLE event = CreateEvent(NULL, TRUE, FALSE, NULL);
+  EXPECT_TRUE(watcher.StartWatching(event, &delegate));
+
+  // Signal it and give the wait thread time to observe the signal and post
+  // the notification back to this loop.
+  SetEvent(event);
+  Sleep(30);
+
+  // Cancel before the loop runs; the already-posted notification must be
+  // suppressed.
+  watcher.StopWatching();
+
+  // Drain the loop. The delegate must not have fired.
+  MessageLoop::current()->RunUntilIdle();
+  EXPECT_EQ(1, counter);
+
+  CloseHandle(event);
+}
+
+void RunTest_SignalBeforeWatch(MessageLoop::Type message_loop_type) {
+  MessageLoop message_loop(message_loop_type);
+  ObjectWatcher watcher;
+
+  // Manual-reset event that starts out signaled, i.e. before any watch is
+  // registered on it.
+  HANDLE event = CreateEvent(NULL, TRUE, TRUE, NULL);
+
+  // Watching an already-signaled object must still notify the delegate,
+  // which quits the loop.
+  QuitDelegate delegate;
+  EXPECT_TRUE(watcher.StartWatching(event, &delegate));
+  MessageLoop::current()->Run();
+
+  EXPECT_EQ(NULL, watcher.GetWatchedObject());
+  CloseHandle(event);
+}
+
+void RunTest_OutlivesMessageLoop(MessageLoop::Type message_loop_type) {
+  // Simulate a MessageLoop that dies before an ObjectWatcher (possible with
+  // the Singleton pattern or atexit, though not with the Thread class).
+  HANDLE event = CreateEvent(NULL, TRUE, FALSE, NULL);  // not signaled
+  {
+    ObjectWatcher watcher;
+    {
+      MessageLoop message_loop(message_loop_type);
+      QuitDelegate delegate;
+      EXPECT_TRUE(watcher.StartWatching(event, &delegate));
+    }
+    // Destroying the loop must have cancelled the watch.
+    EXPECT_EQ(NULL, watcher.GetWatchedObject());
+  }
+  CloseHandle(event);
+}
+
+} // namespace
+
+//-----------------------------------------------------------------------------
+// Every scenario above is exercised once per MessageLoop flavor listed here.
+const MessageLoop::Type kLoopTypes[] = {
+  MessageLoop::TYPE_DEFAULT, MessageLoop::TYPE_IO, MessageLoop::TYPE_UI
+};
+const size_t kNumLoopTypes = sizeof(kLoopTypes) / sizeof(kLoopTypes[0]);
+
+TEST(ObjectWatcherTest, BasicSignal) {
+  for (size_t i = 0; i < kNumLoopTypes; ++i)
+    RunTest_BasicSignal(kLoopTypes[i]);
+}
+
+TEST(ObjectWatcherTest, BasicCancel) {
+  for (size_t i = 0; i < kNumLoopTypes; ++i)
+    RunTest_BasicCancel(kLoopTypes[i]);
+}
+
+TEST(ObjectWatcherTest, CancelAfterSet) {
+  for (size_t i = 0; i < kNumLoopTypes; ++i)
+    RunTest_CancelAfterSet(kLoopTypes[i]);
+}
+
+TEST(ObjectWatcherTest, SignalBeforeWatch) {
+  for (size_t i = 0; i < kNumLoopTypes; ++i)
+    RunTest_SignalBeforeWatch(kLoopTypes[i]);
+}
+
+TEST(ObjectWatcherTest, OutlivesMessageLoop) {
+  for (size_t i = 0; i < kNumLoopTypes; ++i)
+    RunTest_OutlivesMessageLoop(kLoopTypes[i]);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/pe_image.cc b/base/win/pe_image.cc
new file mode 100644
index 0000000..fcf03c1
--- /dev/null
+++ b/base/win/pe_image.cc
@@ -0,0 +1,570 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This file implements PEImage, a generic class to manipulate PE files.
+// This file was adapted from GreenBorder's Code.
+
+#include "base/win/pe_image.h"
+
+namespace base {
+namespace win {
+
+#if defined(_WIN64) && !defined(NACL_WIN64)
+// TODO(rvargas): Bug 27218. Make sure this is ok.
+#error This code is not tested on x64. Please make sure all the base unit tests\
+ pass before doing any real work. The current unit tests don't test the\
+ differences between 32- and 64-bits implementations. Bugs may slip through.\
+ You need to improve the coverage before continuing.
+#endif
+
+// Carries a user callback and cookie through the per-chunk import walkers.
+struct EnumAllImportsStorage {
+  PEImage::EnumImportsFunction callback;  // Called for each import entry.
+  PVOID cookie;                           // Passed to |callback| unchanged.
+};
+
+namespace {
+
+// Compares two NUL-terminated strings byte by byte, treating each byte as
+// unsigned. Returns 0 if they are equal, a negative value if |s1| sorts
+// first and a positive value if |s2| sorts first. Both pointers must be
+// valid. Hand-rolled because callers may run before the CRT is loaded.
+int StrCmpByByte(LPCSTR s1, LPCSTR s2) {
+  const unsigned char* lhs = reinterpret_cast<const unsigned char*>(s1);
+  const unsigned char* rhs = reinterpret_cast<const unsigned char*>(s2);
+
+  // Advance past the common prefix, stopping at the end of |lhs|.
+  while (*lhs != '\0' && *lhs == *rhs) {
+    ++lhs;
+    ++rhs;
+  }
+  return static_cast<int>(*lhs) - static_cast<int>(*rhs);
+}
+}  // namespace
+
+// EnumImportChunksFunction adapter: |cookie| is an EnumAllImportsStorage that
+// carries the per-import callback and its own cookie for this chunk.
+bool ProcessImportChunk(const PEImage &image, LPCSTR module,
+                        PIMAGE_THUNK_DATA name_table,
+                        PIMAGE_THUNK_DATA iat, PVOID cookie) {
+  const EnumAllImportsStorage* storage =
+      static_cast<const EnumAllImportsStorage*>(cookie);
+  return image.EnumOneImportChunk(storage->callback, module, name_table, iat,
+                                  storage->cookie);
+}
+
+// EnumDelayImportChunksFunction adapter: unpacks the EnumAllImportsStorage in
+// |cookie| and enumerates every import of this delay-load chunk.
+bool ProcessDelayImportChunk(const PEImage &image,
+                             PImgDelayDescr delay_descriptor,
+                             LPCSTR module, PIMAGE_THUNK_DATA name_table,
+                             PIMAGE_THUNK_DATA iat, PIMAGE_THUNK_DATA bound_iat,
+                             PIMAGE_THUNK_DATA unload_iat, PVOID cookie) {
+  const EnumAllImportsStorage* storage =
+      static_cast<const EnumAllImportsStorage*>(cookie);
+  return image.EnumOneDelayImportChunk(storage->callback, delay_descriptor,
+                                       module, name_table, iat, bound_iat,
+                                       unload_iat, storage->cookie);
+}
+
+void PEImage::set_module(HMODULE module) {
+  this->module_ = module;  // Points this object at another module base.
+}
+
+PIMAGE_DOS_HEADER PEImage::GetDosHeader() const {  // At image offset 0.
+  return static_cast<PIMAGE_DOS_HEADER>(static_cast<PVOID>(module_));
+}
+
+PIMAGE_NT_HEADERS PEImage::GetNTHeaders() const {
+  // e_lfanew holds the offset of the NT headers from the image base.
+  PIMAGE_DOS_HEADER dos = GetDosHeader();
+  char* image_base = reinterpret_cast<char*>(dos);
+  return reinterpret_cast<PIMAGE_NT_HEADERS>(image_base + dos->e_lfanew);
+}
+
+PIMAGE_SECTION_HEADER PEImage::GetSectionHeader(UINT section) const {
+  PIMAGE_NT_HEADERS nt_headers = GetNTHeaders();
+  // Out-of-range indices yield NULL rather than a pointer past the table.
+  if (section >= nt_headers->FileHeader.NumberOfSections)
+    return NULL;
+
+  // IMAGE_FIRST_SECTION locates the start of the section header array.
+  return IMAGE_FIRST_SECTION(nt_headers) + section;
+}
+
+WORD PEImage::GetNumSections() const {  // From the COFF file header.
+  return GetNTHeaders()->FileHeader.NumberOfSections;
+}
+
+DWORD PEImage::GetImageDirectoryEntrySize(UINT directory) const {
+  // |directory| is an IMAGE_DIRECTORY_ENTRY_* index; it is not range-checked.
+  PIMAGE_DATA_DIRECTORY entries = GetNTHeaders()->OptionalHeader.DataDirectory;
+  return entries[directory].Size;
+}
+
+PVOID PEImage::GetImageDirectoryEntryAddr(UINT directory) const {
+  // Uses the virtual RVAToAddr so PEImageAsData maps to file offsets.
+  PIMAGE_DATA_DIRECTORY entries = GetNTHeaders()->OptionalHeader.DataDirectory;
+  DWORD rva = entries[directory].VirtualAddress;
+  return RVAToAddr(rva);
+}
+
+PIMAGE_SECTION_HEADER PEImage::GetImageSectionFromAddr(PVOID address) const {
+  PBYTE target = reinterpret_cast<PBYTE>(address);
+
+  // Linear scan over the section table; GetSectionHeader() returns NULL once
+  // the index runs past the last section.
+  UINT index = 0;
+  PIMAGE_SECTION_HEADER section;
+  while ((section = GetSectionHeader(index++)) != NULL) {
+    // Deliberately non-virtual: compare against in-memory image addresses.
+    PBYTE begin = reinterpret_cast<PBYTE>(
+        PEImage::RVAToAddr(section->VirtualAddress));
+    PBYTE end = begin + section->Misc.VirtualSize;
+    if (begin <= target && target < end)
+      return section;
+  }
+  return NULL;
+}
+
+PIMAGE_SECTION_HEADER PEImage::GetImageSectionHeaderByName(
+    LPCSTR section_name) const {
+  if (NULL == section_name)
+    return NULL;
+
+  // Section names live in a fixed-size field that need not be NUL-terminated,
+  // so compare at most sizeof(Name) characters, ignoring case. The first
+  // matching section wins.
+  const int count = GetNumSections();
+  for (int index = 0; index < count; ++index) {
+    PIMAGE_SECTION_HEADER candidate = GetSectionHeader(index);
+    const LPCSTR candidate_name = reinterpret_cast<LPCSTR>(candidate->Name);
+    if (_strnicmp(candidate_name, section_name, sizeof(candidate->Name)) == 0)
+      return candidate;
+  }
+
+  // No section matched.
+  return NULL;
+}
+
+PDWORD PEImage::GetExportEntry(LPCSTR name) const {
+  PIMAGE_EXPORT_DIRECTORY exports = GetExportDirectory();
+  WORD ordinal = 0;
+  if (NULL == exports || !GetProcOrdinal(name, &ordinal))
+    return NULL;
+
+  // Ordinals outside [Base, Base + NumberOfFunctions) are not exported;
+  // indexing the function table with them would read out of bounds.
+  DWORD index = static_cast<DWORD>(ordinal) - exports->Base;
+  if (ordinal < exports->Base || index >= exports->NumberOfFunctions)
+    return NULL;
+  PDWORD functions = reinterpret_cast<PDWORD>(
+      RVAToAddr(exports->AddressOfFunctions));
+  return (NULL == functions) ? NULL : functions + index;
+}
+
+FARPROC PEImage::GetProcAddress(LPCSTR function_name) const {
+  PDWORD export_entry = GetExportEntry(function_name);
+  if (NULL == export_entry)
+    return NULL;
+
+  PBYTE function = reinterpret_cast<PBYTE>(RVAToAddr(*export_entry));
+  PBYTE export_dir_begin = reinterpret_cast<PBYTE>(
+      GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_EXPORT));
+  PBYTE export_dir_end = export_dir_begin +
+      GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_EXPORT);
+  // An entry pointing back into the export directory is a forwarder string,
+  // not a function; report it with the 0xFFFFFFFF sentinel.
+  if (export_dir_begin <= function && function < export_dir_end) {
+#pragma warning(push)
+#pragma warning(disable: 4312)
+    // This cast generates a warning because it is 32 bit specific.
+    return reinterpret_cast<FARPROC>(0xFFFFFFFF);
+#pragma warning(pop)
+  }
+  return reinterpret_cast<FARPROC>(function);
+}
+
+bool PEImage::GetProcOrdinal(LPCSTR function_name, WORD *ordinal) const {
+  if (NULL == ordinal)
+    return false;
+
+  PIMAGE_EXPORT_DIRECTORY exports = GetExportDirectory();
+  if (NULL == exports)
+    return false;
+
+  // Ordinal lookups need no table search: IsOrdinal() recognizes values that
+  // encode the ordinal directly in the pointer.
+  if (IsOrdinal(function_name)) {
+    *ordinal = ToOrdinal(function_name);
+    return true;
+  }
+
+  // The export name pointer table is searched as a sorted array, over the
+  // half-open index range [first, last).
+  PDWORD names = reinterpret_cast<PDWORD>(
+      RVAToAddr(exports->AddressOfNames));
+  DWORD first = 0;
+  DWORD last = exports->NumberOfNames;
+
+  while (first != last) {
+    // Computed this way so that first + last cannot overflow.
+    DWORD mid = first + (last - first) / 2;
+    LPCSTR candidate = reinterpret_cast<LPCSTR>(RVAToAddr(names[mid]));
+
+    // This may be called by sandbox before MSVCRT dll loads, so can't use
+    // CRT function here.
+    int cmp = StrCmpByByte(function_name, candidate);
+
+    // Narrow the range toward the half that can still contain the name.
+    if (cmp > 0) {
+      first = mid + 1;
+    } else if (cmp < 0) {
+      last = mid;
+    } else {
+      // Match. AddressOfNameOrdinals parallels the name table; adding Base
+      // turns the stored unbiased index into the exported ordinal.
+      PWORD ordinals = reinterpret_cast<PWORD>(
+          RVAToAddr(exports->AddressOfNameOrdinals));
+      *ordinal = ordinals[mid] + static_cast<WORD>(exports->Base);
+      return true;
+    }
+  }
+
+  // Not exported by name; |*ordinal| is left untouched.
+  return false;
+}
+
+bool PEImage::EnumSections(EnumSectionsFunction callback, PVOID cookie) const {
+  PIMAGE_NT_HEADERS nt_headers = GetNTHeaders();
+  const UINT count = nt_headers->FileHeader.NumberOfSections;
+  PIMAGE_SECTION_HEADER first = GetSectionHeader(0);
+
+  // Hand each section's header, start address (via RVAToAddr) and virtual
+  // size to |callback|; a false return aborts the walk and is propagated.
+  for (UINT index = 0; index < count; ++index) {
+    PIMAGE_SECTION_HEADER header = first + index;
+    if (!callback(*this, header, RVAToAddr(header->VirtualAddress),
+                  header->Misc.VirtualSize, cookie))
+      return false;
+  }
+  return true;
+}
+
+bool PEImage::EnumExports(EnumExportsFunction callback, PVOID cookie) const {
+  PVOID directory = GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_EXPORT);
+  DWORD size = GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_EXPORT);
+
+  // Check if there are any exports at all.
+  if (NULL == directory || 0 == size)
+    return true;
+
+  PIMAGE_EXPORT_DIRECTORY exports = reinterpret_cast<PIMAGE_EXPORT_DIRECTORY>(
+      directory);
+  UINT ordinal_base = exports->Base;
+  UINT num_funcs = exports->NumberOfFunctions;
+  UINT num_names = exports->NumberOfNames;
+  PDWORD functions = reinterpret_cast<PDWORD>(RVAToAddr(
+      exports->AddressOfFunctions));
+  PDWORD names = reinterpret_cast<PDWORD>(RVAToAddr(exports->AddressOfNames));
+  PWORD ordinals = reinterpret_cast<PWORD>(RVAToAddr(
+      exports->AddressOfNameOrdinals));
+  // Forwarder strings live inside the export directory: [dir_begin, dir_end).
+  const char* dir_begin = reinterpret_cast<const char*>(directory);
+  const char* dir_end = dir_begin + size;
+
+  for (UINT count = 0; functions != NULL && count < num_funcs; count++) {
+    PVOID func = RVAToAddr(functions[count]);
+    if (NULL == func)
+      continue;  // Slot does not resolve to an address.
+    // Look up a name for this entry; |hint| becomes its index in the name
+    // table, or 0 when no name is found.
+    LPCSTR name = NULL;
+    UINT hint;
+    for (hint = 0; hint < num_names; hint++) {
+      if (ordinals[hint] == count) {
+        name = reinterpret_cast<LPCSTR>(RVAToAddr(names[hint]));
+        break;
+      }
+    }
+    if (name == NULL)
+      hint = 0;
+    // A forwarded export's "address" is a string inside the directory. The
+    // end bound is exclusive, matching the check in GetProcAddress().
+    LPCSTR forward = NULL;
+    const char* target = reinterpret_cast<const char*>(func);
+    if (target >= dir_begin && target < dir_end) {
+      forward = target;
+      func = NULL;
+    }
+
+    if (!callback(*this, ordinal_base + count, hint, name, func, forward,
+                  cookie))
+      return false;
+  }
+  return true;
+}
+
+bool PEImage::EnumRelocs(EnumRelocsFunction callback, PVOID cookie) const {
+  PVOID directory = GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_BASERELOC);
+  DWORD size = GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_BASERELOC);
+  PIMAGE_BASE_RELOCATION base = reinterpret_cast<PIMAGE_BASE_RELOCATION>(
+      directory);
+  if (directory == NULL || size < sizeof(IMAGE_BASE_RELOCATION))
+    return true;
+
+  // Stay inside the directory: stop at a zero/undersized block (which would
+  // underflow |num_relocs|) or one that extends past the remaining bytes.
+  while (size >= sizeof(IMAGE_BASE_RELOCATION) &&
+         base->SizeOfBlock >= sizeof(IMAGE_BASE_RELOCATION) &&
+         base->SizeOfBlock <= size) {
+    PWORD reloc = reinterpret_cast<PWORD>(base + 1);
+    UINT num_relocs = (base->SizeOfBlock - sizeof(IMAGE_BASE_RELOCATION)) /
+        sizeof(WORD);
+    for (UINT i = 0; i < num_relocs; i++, reloc++) {
+      WORD type = *reloc >> 12;  // Top 4 bits: type; low 12 bits: offset.
+      PVOID address = RVAToAddr(base->VirtualAddress + (*reloc & 0x0FFF));
+      if (!callback(*this, type, address, cookie))
+        return false;
+    }
+    size -= base->SizeOfBlock;
+    base = reinterpret_cast<PIMAGE_BASE_RELOCATION>(
+        reinterpret_cast<char*>(base) + base->SizeOfBlock);
+  }
+  return true;
+}
+
+bool PEImage::EnumImportChunks(EnumImportChunksFunction callback,
+                               PVOID cookie) const {
+  DWORD size = GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_IMPORT);
+  PIMAGE_IMPORT_DESCRIPTOR import = GetFirstImportChunk();
+  // No import directory, or one too small for even a single descriptor.
+  if (import == NULL || size < sizeof(IMAGE_IMPORT_DESCRIPTOR))
+    return true;
+
+  // A descriptor whose FirstThunk is zero terminates the array.
+  while (import->FirstThunk != 0) {
+    LPCSTR module_name = reinterpret_cast<LPCSTR>(RVAToAddr(import->Name));
+    PIMAGE_THUNK_DATA name_table = reinterpret_cast<PIMAGE_THUNK_DATA>(
+        RVAToAddr(import->OriginalFirstThunk));
+    PIMAGE_THUNK_DATA iat = reinterpret_cast<PIMAGE_THUNK_DATA>(
+        RVAToAddr(import->FirstThunk));
+    if (!callback(*this, module_name, name_table, iat, cookie))
+      return false;
+    ++import;
+  }
+  return true;
+}
+
+bool PEImage::EnumOneImportChunk(EnumImportsFunction callback,
+                                 LPCSTR module_name,
+                                 PIMAGE_THUNK_DATA name_table,
+                                 PIMAGE_THUNK_DATA iat, PVOID cookie) const {
+  if (NULL == name_table)
+    return false;
+
+  // Walk the name table and the IAT in lockstep until a zero thunk.
+  while (name_table && name_table->u1.Ordinal) {
+    LPCSTR name = NULL;
+    WORD ordinal = 0;
+    WORD hint = 0;
+    if (!IMAGE_SNAP_BY_ORDINAL(name_table->u1.Ordinal)) {
+      // Imported by name: the thunk holds the RVA of an IMAGE_IMPORT_BY_NAME.
+      PIMAGE_IMPORT_BY_NAME by_name = reinterpret_cast<PIMAGE_IMPORT_BY_NAME>(
+          RVAToAddr(name_table->u1.ForwarderString));
+      hint = by_name->Hint;
+      name = reinterpret_cast<LPCSTR>(&by_name->Name);
+    } else {
+      ordinal = static_cast<WORD>(IMAGE_ORDINAL32(name_table->u1.Ordinal));
+    }
+    if (!callback(*this, module_name, ordinal, name, hint, iat, cookie))
+      return false;
+    ++name_table;
+    ++iat;
+  }
+  return true;
+}
+
+bool PEImage::EnumAllImports(EnumImportsFunction callback, PVOID cookie) const {
+  EnumAllImportsStorage storage = { callback, cookie };  // Unpacked per chunk.
+  return EnumImportChunks(ProcessImportChunk, &storage);
+}
+
+bool PEImage::EnumDelayImportChunks(EnumDelayImportChunksFunction callback,
+                                    PVOID cookie) const {
+  PVOID directory = GetImageDirectoryEntryAddr(
+      IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT);
+  DWORD size = GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT);
+
+  if (directory == NULL || size == 0)
+    return true;
+
+  // Walk the descriptor array; an entry with a zero rvaHmod terminates it.
+  // Each descriptor is decoded according to its dlattrRva attribute bit.
+  PImgDelayDescr descriptor = reinterpret_cast<PImgDelayDescr>(directory);
+  for (; descriptor->rvaHmod; ++descriptor) {
+    // Addresses resolved from this descriptor.
+    LPCSTR module_name;
+    PIMAGE_THUNK_DATA name_table;
+    PIMAGE_THUNK_DATA iat;
+    PIMAGE_THUNK_DATA bound_iat;   // Optional bound IAT.
+    PIMAGE_THUNK_DATA unload_iat;  // Optional copy of the original IAT.
+
+    if ((descriptor->grAttrs & dlattrRva) == 0) {
+      // VC6-style descriptors store absolute addresses, not RVAs.
+#pragma warning(push)
+#pragma warning(disable: 4312)
+      // These casts generate warnings because they are 32 bit specific.
+      module_name = reinterpret_cast<LPCSTR>(descriptor->rvaDLLName);
+      name_table = reinterpret_cast<PIMAGE_THUNK_DATA>(descriptor->rvaINT);
+      iat = reinterpret_cast<PIMAGE_THUNK_DATA>(descriptor->rvaIAT);
+      bound_iat = reinterpret_cast<PIMAGE_THUNK_DATA>(
+          descriptor->rvaBoundIAT);
+      unload_iat = reinterpret_cast<PIMAGE_THUNK_DATA>(
+          descriptor->rvaUnloadIAT);
+#pragma warning(pop)
+    } else {
+      // VC7-style descriptors (dlattrRva set) store RVAs.
+      module_name = reinterpret_cast<LPCSTR>(
+          RVAToAddr(descriptor->rvaDLLName));
+      name_table = reinterpret_cast<PIMAGE_THUNK_DATA>(
+          RVAToAddr(descriptor->rvaINT));
+      iat = reinterpret_cast<PIMAGE_THUNK_DATA>(
+          RVAToAddr(descriptor->rvaIAT));
+      bound_iat = reinterpret_cast<PIMAGE_THUNK_DATA>(
+          RVAToAddr(descriptor->rvaBoundIAT));
+      unload_iat = reinterpret_cast<PIMAGE_THUNK_DATA>(
+          RVAToAddr(descriptor->rvaUnloadIAT));
+    }
+
+    if (!callback(*this, descriptor, module_name, name_table, iat,
+                  bound_iat, unload_iat, cookie))
+      return false;
+  }
+
+  return true;
+}
+
+bool PEImage::EnumOneDelayImportChunk(EnumImportsFunction callback,
+                                      PImgDelayDescr delay_descriptor,
+                                      LPCSTR module_name,
+                                      PIMAGE_THUNK_DATA name_table,
+                                      PIMAGE_THUNK_DATA iat,
+                                      PIMAGE_THUNK_DATA bound_iat,
+                                      PIMAGE_THUNK_DATA unload_iat,
+                                      PVOID cookie) const {
+  UNREFERENCED_PARAMETER(bound_iat);
+  UNREFERENCED_PARAMETER(unload_iat);
+  // Like EnumOneImportChunk(), fail on a missing name table instead of
+  // dereferencing NULL (e.g. a descriptor whose rvaINT is 0).
+  if (NULL == name_table)
+    return false;
+
+  for (; name_table->u1.Ordinal; name_table++, iat++) {
+    LPCSTR name = NULL;
+    WORD ordinal = 0;
+    WORD hint = 0;
+    if (IMAGE_SNAP_BY_ORDINAL(name_table->u1.Ordinal)) {
+      ordinal = static_cast<WORD>(IMAGE_ORDINAL32(name_table->u1.Ordinal));
+    } else {
+      // The descriptor says whether thunks hold RVAs or absolute addresses.
+      PIMAGE_IMPORT_BY_NAME import;
+      bool rvas = (delay_descriptor->grAttrs & dlattrRva) != 0;
+      if (rvas) {
+        import = reinterpret_cast<PIMAGE_IMPORT_BY_NAME>(
+            RVAToAddr(name_table->u1.ForwarderString));
+      } else {
+#pragma warning(push)
+#pragma warning(disable: 4312)
+        // This cast generates a warning because it is 32 bit specific.
+        import = reinterpret_cast<PIMAGE_IMPORT_BY_NAME>(
+            name_table->u1.ForwarderString);
+#pragma warning(pop)
+      }
+      hint = import->Hint;
+      name = reinterpret_cast<LPCSTR>(&import->Name);
+    }
+    if (!callback(*this, module_name, ordinal, name, hint, iat, cookie))
+      return false;
+  }
+  return true;
+}
+
+bool PEImage::EnumAllDelayImports(EnumImportsFunction callback,
+                                  PVOID cookie) const {
+  EnumAllImportsStorage storage = { callback, cookie };  // Unpacked per chunk.
+  return EnumDelayImportChunks(ProcessDelayImportChunk, &storage);
+}
+
+bool PEImage::VerifyMagic() const {
+  // Checks the header signatures an image built for this target must carry.
+  PIMAGE_DOS_HEADER dos_header = GetDosHeader();
+
+  // The DOS signature comes first: GetNTHeaders() follows e_lfanew
+  // unchecked, so only trust it once the DOS header looks valid.
+  if (dos_header->e_magic != IMAGE_DOS_SIGNATURE)
+    return false;
+
+  PIMAGE_NT_HEADERS nt_headers = GetNTHeaders();
+
+  // Then require, in order (short-circuiting):
+  //  - the NT signature,
+  //  - an optional header of exactly sizeof(IMAGE_OPTIONAL_HEADER), and
+  //  - the optional-header magic this build was compiled for.
+  return nt_headers->Signature == IMAGE_NT_SIGNATURE &&
+         nt_headers->FileHeader.SizeOfOptionalHeader ==
+             sizeof(IMAGE_OPTIONAL_HEADER) &&
+         nt_headers->OptionalHeader.Magic == IMAGE_NT_OPTIONAL_HDR_MAGIC;
+}
+
+bool PEImage::ImageRVAToOnDiskOffset(DWORD rva, DWORD *on_disk_offset) const {
+  // Translate via the (virtual) RVAToAddr address, then to a file offset.
+  return ImageAddrToOnDiskOffset(RVAToAddr(rva), on_disk_offset);
+}
+
+bool PEImage::ImageAddrToOnDiskOffset(LPVOID address,
+                                      DWORD *on_disk_offset) const {
+  if (NULL == address)
+    return false;
+
+  // Get the section that this address belongs to.
+  PIMAGE_SECTION_HEADER section_header = GetImageSectionFromAddr(address);
+  if (NULL == section_header)
+    return false;
+
+  // Don't follow the virtual RVAToAddr, use the one on the base. Subtract
+  // pointers instead of casting each to DWORD, which truncates (and is
+  // ill-formed C++) where pointers are wider than 32 bits. The difference
+  // fits in a DWORD because |address| lies inside the section.
+  PBYTE section_start = reinterpret_cast<PBYTE>(
+      PEImage::RVAToAddr(section_header->VirtualAddress));
+  DWORD offset_within_section = static_cast<DWORD>(
+      reinterpret_cast<PBYTE>(address) - section_start);
+
+  *on_disk_offset = section_header->PointerToRawData + offset_within_section;
+  return true;
+}
+
+PVOID PEImage::RVAToAddr(DWORD rva) const {
+  // A zero RVA is treated as "no address".
+  if (0 == rva)
+    return NULL;
+  return reinterpret_cast<char*>(module_) + rva;  // Offset from the base.
+}
+
+PVOID PEImageAsData::RVAToAddr(DWORD rva) const {
+  if (0 == rva)
+    return NULL;
+
+  // The file is mapped flat (as data), so translate the RVA through the
+  // section table to its on-disk offset and add that to the base instead.
+  DWORD disk_offset;
+  bool found = ImageAddrToOnDiskOffset(PEImage::RVAToAddr(rva), &disk_offset);
+  if (!found)
+    return NULL;
+  return PEImage::RVAToAddr(disk_offset);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/pe_image.h b/base/win/pe_image.h
new file mode 100644
index 0000000..878ef52
--- /dev/null
+++ b/base/win/pe_image.h
@@ -0,0 +1,268 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This file was adapted from GreenBorder's Code.
+// To understand what this class is about (for other than well known functions
+// as GetProcAddress), a good starting point is "An In-Depth Look into the
+// Win32 Portable Executable File Format" by Matt Pietrek:
+// http://msdn.microsoft.com/msdnmag/issues/02/02/PE/default.aspx
+
+#ifndef BASE_WIN_PE_IMAGE_H_
+#define BASE_WIN_PE_IMAGE_H_
+
+#include <windows.h>
+
+#if defined(_WIN32_WINNT_WIN8)
+// The Windows 8 SDK defines FACILITY_VISUALCPP in winerror.h.
+#undef FACILITY_VISUALCPP
+#endif
+#include <DelayIMP.h>
+
+namespace base {
+namespace win {
+
+// This class is a wrapper for the Portable Executable File Format (PE).
+// Its main purpose is to provide an easy way to work with imports and exports
+// from a file, mapped in memory as image.
+class PEImage {
+ public:
+  // Callback to enumerate sections.
+  // cookie is the value passed to the enumerate method.
+  // Returns true to continue the enumeration.
+  typedef bool (*EnumSectionsFunction)(const PEImage &image,
+                                       PIMAGE_SECTION_HEADER header,
+                                       PVOID section_start, DWORD section_size,
+                                       PVOID cookie);
+
+  // Callback to enumerate exports.
+  // function is the actual address of the symbol. If forward is not null, it
+  // contains the dll and symbol to forward this export to. cookie is the value
+  // passed to the enumerate method.
+  // Returns true to continue the enumeration.
+  typedef bool (*EnumExportsFunction)(const PEImage &image, DWORD ordinal,
+                                      DWORD hint, LPCSTR name, PVOID function,
+                                      LPCSTR forward, PVOID cookie);
+
+  // Callback to enumerate import blocks.
+  // name_table and iat point to the imports name table and address table for
+  // this block. cookie is the value passed to the enumerate method.
+  // Returns true to continue the enumeration.
+  typedef bool (*EnumImportChunksFunction)(const PEImage &image, LPCSTR module,
+                                           PIMAGE_THUNK_DATA name_table,
+                                           PIMAGE_THUNK_DATA iat, PVOID cookie);
+
+  // Callback to enumerate imports.
+  // module is the dll that exports this symbol. cookie is the value passed to
+  // the enumerate method.
+  // Returns true to continue the enumeration.
+  typedef bool (*EnumImportsFunction)(const PEImage &image, LPCSTR module,
+                                      DWORD ordinal, LPCSTR name, DWORD hint,
+                                      PIMAGE_THUNK_DATA iat, PVOID cookie);
+
+  // Callback to enumerate delayed import blocks.
+  // module is the dll that exports this block of symbols. cookie is the value
+  // passed to the enumerate method.
+  // Returns true to continue the enumeration.
+  typedef bool (*EnumDelayImportChunksFunction)(const PEImage &image,
+                                                PImgDelayDescr delay_descriptor,
+                                                LPCSTR module,
+                                                PIMAGE_THUNK_DATA name_table,
+                                                PIMAGE_THUNK_DATA iat,
+                                                PIMAGE_THUNK_DATA bound_iat,
+                                                PIMAGE_THUNK_DATA unload_iat,
+                                                PVOID cookie);
+
+  // Callback to enumerate relocations.
+  // cookie is the value passed to the enumerate method.
+  // Returns true to continue the enumeration.
+  typedef bool (*EnumRelocsFunction)(const PEImage &image, WORD type,
+                                     PVOID address, PVOID cookie);
+
+  explicit PEImage(HMODULE module) : module_(module) {}
+  explicit PEImage(const void* module)
+      : module_(reinterpret_cast<HMODULE>(const_cast<void*>(module))) {}
+  virtual ~PEImage() {}  // Virtual: PEImageAsData overrides RVAToAddr.
+
+  // Gets the HMODULE for this object.
+  HMODULE module() const;
+
+  // Sets this object's HMODULE.
+  void set_module(HMODULE module);
+
+  // Checks if this symbol is actually an ordinal.
+  static bool IsOrdinal(LPCSTR name);
+
+  // Converts a named symbol to the corresponding ordinal.
+  static WORD ToOrdinal(LPCSTR name);
+
+  // Returns the DOS_HEADER for this PE.
+  PIMAGE_DOS_HEADER GetDosHeader() const;
+
+  // Returns the NT_HEADER for this PE.
+  PIMAGE_NT_HEADERS GetNTHeaders() const;
+
+  // Returns number of sections of this PE.
+  WORD GetNumSections() const;
+
+  // Returns the header for a given section.
+  // returns NULL if there is no such section.
+  PIMAGE_SECTION_HEADER GetSectionHeader(UINT section) const;
+
+  // Returns the size of a given directory entry.
+  DWORD GetImageDirectoryEntrySize(UINT directory) const;
+
+  // Returns the address of a given directory entry.
+  PVOID GetImageDirectoryEntryAddr(UINT directory) const;
+
+  // Returns the section header for a given address.
+  // Use: s = image.GetImageSectionFromAddr(a);
+  // Post: 's' is the section header of the section that contains 'a'
+  //       or NULL if there is no such section.
+  PIMAGE_SECTION_HEADER GetImageSectionFromAddr(PVOID address) const;
+
+  // Returns the header of the first section named |section_name| (compared
+  // case-insensitively), or NULL if there is none.
+  PIMAGE_SECTION_HEADER GetImageSectionHeaderByName(LPCSTR section_name) const;
+
+  // Returns the first block of imports.
+  PIMAGE_IMPORT_DESCRIPTOR GetFirstImportChunk() const;
+
+  // Returns the exports directory.
+  PIMAGE_EXPORT_DIRECTORY GetExportDirectory() const;
+
+  // Returns a given export entry.
+  // Use: e = image.GetExportEntry(f);
+  // Pre: 'f' is either a zero terminated string or ordinal
+  // Post: 'e' is a pointer to the export directory entry
+  //       that contains 'f's export RVA, or NULL if 'f'
+  //       is not exported from this image
+  PDWORD GetExportEntry(LPCSTR name) const;
+
+  // Returns the address for a given exported symbol.
+  // Use: p = image.GetProcAddress(f);
+  // Pre: 'f' is either a zero terminated string or ordinal.
+  // Post: if 'f' is a non-forwarded export from image, 'p' is
+  //       the exported function. If 'f' is a forwarded export
+  //       then p is the special value 0xFFFFFFFF. In this case
+  //       RVAToAddr(*GetExportEntry) can be used to resolve
+  //       the string that describes the forward.
+  FARPROC GetProcAddress(LPCSTR function_name) const;
+
+  // Retrieves the ordinal for a given exported symbol.
+  // Returns true if the symbol was found.
+  bool GetProcOrdinal(LPCSTR function_name, WORD *ordinal) const;
+
+  // Enumerates PE sections.
+  // cookie is a generic cookie to pass to the callback.
+  // Returns true on success.
+  bool EnumSections(EnumSectionsFunction callback, PVOID cookie) const;
+
+  // Enumerates PE exports.
+  // cookie is a generic cookie to pass to the callback.
+  // Returns true on success.
+  bool EnumExports(EnumExportsFunction callback, PVOID cookie) const;
+
+  // Enumerates PE imports.
+  // cookie is a generic cookie to pass to the callback.
+  // Returns true on success.
+  bool EnumAllImports(EnumImportsFunction callback, PVOID cookie) const;
+
+  // Enumerates PE import blocks.
+  // cookie is a generic cookie to pass to the callback.
+  // Returns true on success.
+  bool EnumImportChunks(EnumImportChunksFunction callback, PVOID cookie) const;
+
+  // Enumerates the imports from a single PE import block.
+  // cookie is a generic cookie to pass to the callback.
+  // Returns true on success.
+  bool EnumOneImportChunk(EnumImportsFunction callback, LPCSTR module_name,
+                          PIMAGE_THUNK_DATA name_table, PIMAGE_THUNK_DATA iat,
+                          PVOID cookie) const;
+
+
+  // Enumerates PE delay imports.
+  // cookie is a generic cookie to pass to the callback.
+  // Returns true on success.
+  bool EnumAllDelayImports(EnumImportsFunction callback, PVOID cookie) const;
+
+  // Enumerates PE delay import blocks.
+  // cookie is a generic cookie to pass to the callback.
+  // Returns true on success.
+  bool EnumDelayImportChunks(EnumDelayImportChunksFunction callback,
+                             PVOID cookie) const;
+
+  // Enumerates imports from a single PE delay import block.
+  // cookie is a generic cookie to pass to the callback.
+  // Returns true on success.
+  bool EnumOneDelayImportChunk(EnumImportsFunction callback,
+                               PImgDelayDescr delay_descriptor,
+                               LPCSTR module_name,
+                               PIMAGE_THUNK_DATA name_table,
+                               PIMAGE_THUNK_DATA iat,
+                               PIMAGE_THUNK_DATA bound_iat,
+                               PIMAGE_THUNK_DATA unload_iat,
+                               PVOID cookie) const;
+
+  // Enumerates PE relocation entries.
+  // cookie is a generic cookie to pass to the callback.
+  // Returns true on success.
+  bool EnumRelocs(EnumRelocsFunction callback, PVOID cookie) const;
+
+  // Verifies the magic values on the PE file.
+  // Returns true if all values are correct.
+  bool VerifyMagic() const;
+
+  // Converts an rva value to the appropriate address.
+  virtual PVOID RVAToAddr(DWORD rva) const;
+
+  // Converts an rva value to an offset on disk.
+  // Returns true on success.
+  bool ImageRVAToOnDiskOffset(DWORD rva, DWORD *on_disk_offset) const;
+
+  // Converts an address to an offset on disk.
+  // Returns true on success.
+  bool ImageAddrToOnDiskOffset(LPVOID address, DWORD *on_disk_offset) const;
+
+ private:
+  HMODULE module_;
+};
+
+// This class is an extension to the PEImage class that allows working with PE
+// files mapped as data instead of as image file.
+class PEImageAsData : public PEImage {
+ public:
+  explicit PEImageAsData(HMODULE module) : PEImage(module) {}
+  // Maps |rva| through the section table to the matching file offset.
+  virtual PVOID RVAToAddr(DWORD rva) const;
+};
+
+inline bool PEImage::IsOrdinal(LPCSTR name) {
+  // Ordinals are passed disguised as pointers whose value fits in 16 bits.
+  // Compare the full pointer-sized value: truncating it to a DWORD first
+  // would misclassify a 64-bit address whose low 32 bits happen to be
+  // <= 0xFFFF.
+  return reinterpret_cast<UINT_PTR>(name) <= 0xFFFF;
+}
+
+inline WORD PEImage::ToOrdinal(LPCSTR name) {
+  return static_cast<WORD>(reinterpret_cast<UINT_PTR>(name));  // Low 16 bits.
+}
+
+inline HMODULE PEImage::module() const {
+  return this->module_;  // Set by the constructors or set_module().
+}
+
+inline PIMAGE_IMPORT_DESCRIPTOR PEImage::GetFirstImportChunk() const {
+  PVOID imports = GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_IMPORT);
+  return reinterpret_cast<PIMAGE_IMPORT_DESCRIPTOR>(imports);  // May be NULL.
+}
+
+inline PIMAGE_EXPORT_DIRECTORY PEImage::GetExportDirectory() const {
+  PVOID exports = GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_EXPORT);
+  return reinterpret_cast<PIMAGE_EXPORT_DIRECTORY>(exports);  // May be NULL.
+}
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_PE_IMAGE_H_
diff --git a/base/win/pe_image_unittest.cc b/base/win/pe_image_unittest.cc
new file mode 100644
index 0000000..e308eae
--- /dev/null
+++ b/base/win/pe_image_unittest.cc
@@ -0,0 +1,256 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This file contains unit tests for PEImage.
+
+#include "testing/gtest/include/gtest/gtest.h"
+#include "base/win/pe_image.h"
+#include "base/win/windows_version.h"
+
+namespace base {
+namespace win {
+
+// Export enumeration callback that ignores the export details and only bumps
+// the int counter |cookie| points to. Always returns true.
+bool ExportsCallback(const PEImage& image,
+                     DWORD ordinal,
+                     DWORD hint,
+                     LPCSTR name,
+                     PVOID function,
+                     LPCSTR forward,
+                     PVOID cookie) {
+  ++*static_cast<int*>(cookie);
+  return true;
+}
+
+// Import enumeration callback that ignores the import details and only bumps
+// the int counter |cookie| points to. Always returns true.
+bool ImportsCallback(const PEImage& image,
+                     LPCSTR module,
+                     DWORD ordinal,
+                     LPCSTR name,
+                     DWORD hint,
+                     PIMAGE_THUNK_DATA iat,
+                     PVOID cookie) {
+  ++*static_cast<int*>(cookie);
+  return true;
+}
+
+// Section enumeration callback that ignores the section details and only
+// bumps the int counter |cookie| points to. Always returns true.
+bool SectionsCallback(const PEImage& image,
+                      PIMAGE_SECTION_HEADER header,
+                      PVOID section_start,
+                      DWORD section_size,
+                      PVOID cookie) {
+  ++*static_cast<int*>(cookie);
+  return true;
+}
+
+// Relocation enumeration callback that ignores the relocation details and
+// only bumps the int counter |cookie| points to. Always returns true.
+bool RelocsCallback(const PEImage& image,
+                    WORD type,
+                    PVOID address,
+                    PVOID cookie) {
+  ++*static_cast<int*>(cookie);
+  return true;
+}
+
+// Import-chunk enumeration callback that ignores the chunk details and only
+// bumps the int counter |cookie| points to. Always returns true.
+bool ImportChunksCallback(const PEImage& image,
+                          LPCSTR module,
+                          PIMAGE_THUNK_DATA name_table,
+                          PIMAGE_THUNK_DATA iat,
+                          PVOID cookie) {
+  ++*static_cast<int*>(cookie);
+  return true;
+}
+
+// Delay-import-chunk enumeration callback that ignores the descriptor details
+// and only bumps the int counter |cookie| points to. Always returns true.
+bool DelayImportChunksCallback(const PEImage& image,
+                               PImgDelayDescr delay_descriptor,
+                               LPCSTR module,
+                               PIMAGE_THUNK_DATA name_table,
+                               PIMAGE_THUNK_DATA iat,
+                               PIMAGE_THUNK_DATA bound_iat,
+                               PIMAGE_THUNK_DATA unload_iat,
+                               PVOID cookie) {
+  ++*static_cast<int*>(cookie);
+  return true;
+}
+
+// Selects which column of expected values in GetExpectedValue() applies.
+// UNSUPPORTED_SET must remain last: its value equals the number of real sets,
+// which GetExpectedValue() checks against the width of its table.
+enum ExpectationSet {
+  WIN_2K_SET = 0,
+  WIN_XP_SET,
+  WIN_VISTA_SET,
+  WIN_7_SET,
+  WIN_8_SET,
+  UNSUPPORTED_SET
+};
+
+// The properties whose known counts the tests check; each one selects a row
+// of the table in GetExpectedValue(), in this order.
+enum Value {
+  sections = 0,
+  imports_dlls,
+  delay_dlls,
+  exports,
+  imports,
+  delay_imports,
+  relocs
+};
+
+// Maps an OS version encoded as (major * 10 + minor) to its expectation set.
+// Version 6.2 and every later version share the Windows 8 numbers; anything
+// not listed is UNSUPPORTED_SET.
+ExpectationSet GetExpectationSet(DWORD os) {
+  if (os >= 62)
+    return WIN_8_SET;
+
+  switch (os) {
+    case 50:
+      return WIN_2K_SET;
+    case 51:
+      return WIN_XP_SET;
+    case 60:
+      return WIN_VISTA_SET;
+    case 61:
+      return WIN_7_SET;
+    default:
+      return UNSUPPORTED_SET;
+  }
+}
+
+// Retrieves the expected value from advapi32.dll based on the OS.
+// |value| selects the property (row); |os| is the major * 10 + minor version
+// taken from the image's optional header and selects the column via
+// GetExpectationSet(). Returns 0 for an out-of-range |value|.
+int GetExpectedValue(Value value, DWORD os) {
+  const int xp_delay_dlls = 2;
+  const int xp_exports = 675;
+  const int xp_imports = 422;
+  const int xp_delay_imports = 8;
+  const int xp_relocs = 9180;
+  const int vista_delay_dlls = 4;
+  const int vista_exports = 799;
+  const int vista_imports = 476;
+  const int vista_delay_imports = 24;
+  const int vista_relocs = 10188;
+  const int w2k_delay_dlls = 0;
+  const int w2k_exports = 566;
+  const int w2k_imports = 357;
+  const int w2k_delay_imports = 0;
+  const int w2k_relocs = 7388;
+  const int win7_delay_dlls = 7;
+  const int win7_exports = 806;
+  const int win7_imports = 568;
+  const int win7_delay_imports = 71;
+  const int win7_relocs = 7812;
+  const int win8_delay_dlls = 9;
+  const int win8_exports = 806;
+  const int win8_imports = 568;
+  const int win8_delay_imports = 113;
+  const int win8_relocs = 9478;
+  int win8_sections = 4;
+  int win8_import_dlls = 17;
+
+  // The 32-bit Windows 8 advapi32.dll has a different layout.
+  base::win::OSInfo* os_info = base::win::OSInfo::GetInstance();
+  if (os_info->architecture() == base::win::OSInfo::X86_ARCHITECTURE) {
+    win8_sections = 5;
+    win8_import_dlls = 19;
+  }
+
+  // Contains the expected value, for each enumerated property (Value), and the
+  // OS version: [Value][os_version]
+  const int expected[][5] = {
+    {4, 4, 4, 4, win8_sections},
+    {3, 3, 3, 13, win8_import_dlls},
+    {w2k_delay_dlls, xp_delay_dlls, vista_delay_dlls, win7_delay_dlls,
+     win8_delay_dlls},
+    {w2k_exports, xp_exports, vista_exports, win7_exports, win8_exports},
+    {w2k_imports, xp_imports, vista_imports, win7_imports, win8_imports},
+    {w2k_delay_imports, xp_delay_imports,
+     vista_delay_imports, win7_delay_imports, win8_delay_imports},
+    {w2k_relocs, xp_relocs, vista_relocs, win7_relocs, win8_relocs}
+  };
+  COMPILE_ASSERT(arraysize(expected[0]) == UNSUPPORTED_SET,
+                 expected_value_set_mismatch);
+
+  if (value > relocs)
+    return 0;
+  ExpectationSet expected_set = GetExpectationSet(os);
+  // |expected_set| indexes the columns (one per OS), so it must be bounded by
+  // the row width, not by the number of rows. Comparing against
+  // arraysize(expected) (7) let UNSUPPORTED_SET (5) index past each row.
+  if (expected_set >= arraysize(expected[0])) {
+    // This should never happen. Log a failure if it does.
+    EXPECT_NE(UNSUPPORTED_SET, expected_set);
+    expected_set = WIN_2K_SET;
+  }
+
+  return expected[value][expected_set];
+}
+
+// Tests that we are able to enumerate stuff from a PE file, and that
+// the actual number of items found is within the expected range.
+TEST(PEImageTest, EnumeratesPE) {
+  HMODULE module = LoadLibrary(L"advapi32.dll");
+  ASSERT_TRUE(NULL != module);
+
+  PEImage pe(module);
+  int count = 0;
+  EXPECT_TRUE(pe.VerifyMagic());
+
+  // Encode the OS version the image targets as major * 10 + minor.
+  DWORD os = pe.GetNTHeaders()->OptionalHeader.MajorOperatingSystemVersion;
+  os = os * 10 + pe.GetNTHeaders()->OptionalHeader.MinorOperatingSystemVersion;
+
+  // Skip this test for unsupported OS versions, releasing the library
+  // reference taken above so the early exit does not leak it.
+  if (GetExpectationSet(os) == UNSUPPORTED_SET) {
+    FreeLibrary(module);
+    return;
+  }
+
+  pe.EnumSections(SectionsCallback, &count);
+  EXPECT_EQ(GetExpectedValue(sections, os), count);
+
+  count = 0;
+  pe.EnumImportChunks(ImportChunksCallback, &count);
+  EXPECT_EQ(GetExpectedValue(imports_dlls, os), count);
+
+  count = 0;
+  pe.EnumDelayImportChunks(DelayImportChunksCallback, &count);
+  EXPECT_EQ(GetExpectedValue(delay_dlls, os), count);
+
+  // The remaining counts vary between builds, so only a range is checked.
+  count = 0;
+  pe.EnumExports(ExportsCallback, &count);
+  EXPECT_GT(count, GetExpectedValue(exports, os) - 20);
+  EXPECT_LT(count, GetExpectedValue(exports, os) + 100);
+
+  count = 0;
+  pe.EnumAllImports(ImportsCallback, &count);
+  EXPECT_GT(count, GetExpectedValue(imports, os) - 20);
+  EXPECT_LT(count, GetExpectedValue(imports, os) + 100);
+
+  count = 0;
+  pe.EnumAllDelayImports(ImportsCallback, &count);
+  EXPECT_GT(count, GetExpectedValue(delay_imports, os) - 2);
+  EXPECT_LT(count, GetExpectedValue(delay_imports, os) + 8);
+
+  count = 0;
+  pe.EnumRelocs(RelocsCallback, &count);
+  EXPECT_GT(count, GetExpectedValue(relocs, os) - 150);
+  EXPECT_LT(count, GetExpectedValue(relocs, os) + 1500);
+
+  FreeLibrary(module);
+}
+
+// Tests that we can locate a specific exported symbol, by name and by ordinal.
+TEST(PEImageTest, RetrievesExports) {
+  HMODULE module = LoadLibrary(L"advapi32.dll");
+  ASSERT_TRUE(NULL != module);
+
+  PEImage pe(module);
+  // Initialized so that a failed lookup never leaves it indeterminate.
+  WORD ordinal = 0;
+  const bool found_ordinal = pe.GetProcOrdinal("RegEnumKeyExW", &ordinal);
+  EXPECT_TRUE(found_ordinal);
+
+  // Only look the symbol up by ordinal when the ordinal is actually known.
+  if (found_ordinal) {
+    FARPROC address1 = pe.GetProcAddress("RegEnumKeyExW");
+    FARPROC address2 = pe.GetProcAddress(reinterpret_cast<char*>(ordinal));
+    EXPECT_TRUE(address1 != NULL);
+    EXPECT_TRUE(address2 != NULL);
+    EXPECT_TRUE(address1 == address2);
+  }
+
+  FreeLibrary(module);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/registry.cc b/base/win/registry.cc
new file mode 100644
index 0000000..83eb590
--- /dev/null
+++ b/base/win/registry.cc
@@ -0,0 +1,464 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/registry.h"
+
+#include <shlwapi.h>
+#include <algorithm>
+
+#include "base/logging.h"
+#include "base/string_util.h"
+#include "base/threading/thread_restrictions.h"
+
+#pragma comment(lib, "shlwapi.lib") // for SHDeleteKey
+
+namespace base {
+namespace win {
+
+namespace {
+
+// Buffer size, in wchar_t's, large enough for any registry value name.
+// RegEnumValue() reports how many characters of the name it wrote rather
+// than the name's full length, so this upper bound is used when retrying.
+const DWORD MAX_REGISTRY_NAME_SIZE = 16384;
+
+// Converts a byte count reported for registry data into a wchar_t count,
+// rounding up so that a trailing partial wchar_t (odd byte count) still gets
+// a slot of its own.
+inline DWORD to_wchar_size(DWORD byte_size) {
+  const size_t kBytesPerChar = sizeof(wchar_t);
+  return static_cast<DWORD>((byte_size + kBytesPerChar - 1) / kBytesPerChar);
+}
+
+}  // namespace
+
+// RegKey ----------------------------------------------------------------------
+
+// Default-constructs an object that holds no key and no watch event.
+RegKey::RegKey() : key_(NULL), watch_event_(0) {}
+
+// Opens or creates |subkey| under |rootkey|. Write-type access rights select
+// Create(); anything else selects Open(). A NULL |rootkey| leaves the object
+// empty and must not be paired with a subkey.
+RegKey::RegKey(HKEY rootkey, const wchar_t* subkey, REGSAM access)
+    : key_(NULL),
+      watch_event_(0) {
+  if (!rootkey) {
+    DCHECK(!subkey);
+    return;
+  }
+
+  const REGSAM kWriteRights =
+      KEY_SET_VALUE | KEY_CREATE_SUB_KEY | KEY_CREATE_LINK;
+  if ((access & kWriteRights) != 0)
+    Create(rootkey, subkey, access);
+  else
+    Open(rootkey, subkey, access);
+}
+
+// Releases the key and any active watch.
+RegKey::~RegKey() { Close(); }
+
+// Same as CreateWithDisposition(), for callers that do not care whether the
+// key was newly created or already existed.
+LONG RegKey::Create(HKEY rootkey, const wchar_t* subkey, REGSAM access) {
+  DWORD ignored_disposition;
+  return CreateWithDisposition(rootkey, subkey, &ignored_disposition, access);
+}
+
+// Closes any key currently held, then creates (or opens) |subkey| under
+// |rootkey| as a non-volatile key. RegCreateKeyEx writes the disposition into
+// |*disposition|. Returns the RegCreateKeyEx result.
+LONG RegKey::CreateWithDisposition(HKEY rootkey, const wchar_t* subkey,
+                                   DWORD* disposition, REGSAM access) {
+  DCHECK(rootkey && subkey && access && disposition);
+  Close();
+  return RegCreateKeyEx(rootkey, subkey, 0, NULL, REG_OPTION_NON_VOLATILE,
+                        access, NULL, &key_, disposition);
+}
+
+// Creates (or opens) |name| relative to the key currently held and, on
+// success, makes the new subkey the key held by this object.
+// Returns the RegCreateKeyEx result. Previously the current key was closed
+// even when creation failed, leaving the object holding NULL; now a failure
+// leaves the current key (and any watch on it) untouched.
+LONG RegKey::CreateKey(const wchar_t* name, REGSAM access) {
+  DCHECK(name && access);
+  HKEY subkey = NULL;
+  LONG result = RegCreateKeyEx(key_, name, 0, NULL, REG_OPTION_NON_VOLATILE,
+                               access, NULL, &subkey, NULL);
+  if (result == ERROR_SUCCESS) {
+    Close();
+    key_ = subkey;
+  }
+  return result;
+}
+
+// Closes any key currently held, then opens the existing key |subkey| under
+// |rootkey|. Returns the RegOpenKeyEx result.
+LONG RegKey::Open(HKEY rootkey, const wchar_t* subkey, REGSAM access) {
+  DCHECK(rootkey && subkey && access);
+  Close();
+  return RegOpenKeyEx(rootkey, subkey, 0, access, &key_);
+}
+
+// Opens |relative_key_name| below the key currently held and, on success,
+// makes it the key held by this object. Returns the RegOpenKeyEx result.
+// The current key is only closed once the new one is open: previously a
+// failed open still closed it, leaving the object holding NULL.
+LONG RegKey::OpenKey(const wchar_t* relative_key_name, REGSAM access) {
+  DCHECK(relative_key_name && access);
+  HKEY subkey = NULL;
+  LONG result = RegOpenKeyEx(key_, relative_key_name, 0, access, &subkey);
+  if (result == ERROR_SUCCESS) {
+    // We have to close the current opened key before replacing it with the
+    // new one.
+    Close();
+    key_ = subkey;
+  }
+  return result;
+}
+
+// Stops any active watch, then releases the held key (if any).
+void RegKey::Close() {
+  StopWatching();
+  if (!key_)
+    return;
+  ::RegCloseKey(key_);
+  key_ = NULL;
+}
+
+// Probes |name| with RegQueryValueEx without reading any data.
+bool RegKey::HasValue(const wchar_t* name) const {
+  const LONG status = RegQueryValueEx(key_, name, 0, NULL, NULL, NULL);
+  return status == ERROR_SUCCESS;
+}
+
+// Asks RegQueryInfoKey for the value count; 0 when the query fails.
+DWORD RegKey::GetValueCount() const {
+  DWORD value_count = 0;
+  if (RegQueryInfoKey(key_, NULL, 0, NULL, NULL, NULL, NULL, &value_count,
+                      NULL, NULL, NULL, NULL) != ERROR_SUCCESS) {
+    return 0;
+  }
+  return value_count;
+}
+
+// Stores the name of the value at |index| in |*name|. Returns the
+// RegEnumValue result; |*name| is only written on success.
+// Value names may be up to 16,383 characters, so a name that does not fit
+// the small stack buffer is retried with a MAX_REGISTRY_NAME_SIZE buffer.
+// RegEnumValue does not report the required name length on ERROR_MORE_DATA.
+// Previously such names simply failed with ERROR_MORE_DATA.
+LONG RegKey::GetValueNameAt(int index, std::wstring* name) const {
+  wchar_t buf[256];
+  DWORD bufsize = arraysize(buf);
+  LONG r = ::RegEnumValue(key_, index, buf, &bufsize, NULL, NULL, NULL, NULL);
+  if (r == ERROR_SUCCESS) {
+    *name = buf;
+    return r;
+  }
+  if (r != ERROR_MORE_DATA)
+    return r;
+
+  // No data buffer was requested, so ERROR_MORE_DATA refers to the name.
+  std::vector<wchar_t> big_buf(MAX_REGISTRY_NAME_SIZE);
+  bufsize = MAX_REGISTRY_NAME_SIZE;
+  r = ::RegEnumValue(key_, index, &big_buf[0], &bufsize, NULL, NULL, NULL,
+                     NULL);
+  if (r == ERROR_SUCCESS)
+    *name = &big_buf[0];
+  return r;
+}
+
+// Deletes subkey |name| of the held key through SHDeleteKey and returns its
+// result.
+LONG RegKey::DeleteKey(const wchar_t* name) {
+  DCHECK(key_);
+  DCHECK(name);
+  return SHDeleteKey(key_, name);
+}
+
+// Removes the single value |value_name| from the held key.
+LONG RegKey::DeleteValue(const wchar_t* value_name) {
+  DCHECK(key_);
+  return RegDeleteValue(key_, value_name);
+}
+
+// Reads a 4-byte REG_DWORD or REG_BINARY value into |*out_value|.
+// Any other type or size yields ERROR_CANTREAD. |*out_value| is only written
+// on success.
+LONG RegKey::ReadValueDW(const wchar_t* name, DWORD* out_value) const {
+  DCHECK(out_value);
+  DWORD value_type = REG_DWORD;
+  DWORD value_size = sizeof(DWORD);
+  DWORD value = 0;
+  const LONG result = ReadValue(name, &value, &value_size, &value_type);
+  if (result != ERROR_SUCCESS)
+    return result;
+
+  const bool type_ok = value_type == REG_DWORD || value_type == REG_BINARY;
+  if (!type_ok || value_size != sizeof(DWORD))
+    return ERROR_CANTREAD;
+
+  *out_value = value;
+  return ERROR_SUCCESS;
+}
+
+// Reads an 8-byte REG_QWORD or REG_BINARY value into |*out_value|.
+// Any other type or size yields ERROR_CANTREAD. |*out_value| is only written
+// on success.
+LONG RegKey::ReadInt64(const wchar_t* name, int64* out_value) const {
+  DCHECK(out_value);
+  DWORD value_type = REG_QWORD;
+  int64 value = 0;
+  DWORD value_size = sizeof(value);
+  const LONG result = ReadValue(name, &value, &value_size, &value_type);
+  if (result != ERROR_SUCCESS)
+    return result;
+
+  const bool type_ok = value_type == REG_QWORD || value_type == REG_BINARY;
+  if (!type_ok || value_size != sizeof(value))
+    return ERROR_CANTREAD;
+
+  *out_value = value;
+  return ERROR_SUCCESS;
+}
+
+// Reads a REG_SZ or REG_EXPAND_SZ value into |*out_value|, expanding
+// environment variables for REG_EXPAND_SZ. Returns ERROR_CANTREAD for other
+// types and ERROR_MORE_DATA when the expanded string exceeds
+// kMaxStringLength. |*out_value| is only written on success.
+// RegQueryValueEx does not guarantee that string data is NUL-terminated, so
+// the string is bounded by the |size| bytes actually returned. Previously an
+// unterminated value could be read past the end of |raw_value|.
+LONG RegKey::ReadValue(const wchar_t* name, std::wstring* out_value) const {
+  DCHECK(out_value);
+  const size_t kMaxStringLength = 1024;  // This is after expansion.
+  // Use the one of the other forms of ReadValue if 1024 is too small for you.
+  wchar_t raw_value[kMaxStringLength];
+  DWORD type = REG_SZ, size = sizeof(raw_value);
+  LONG result = ReadValue(name, raw_value, &size, &type);
+  if (result != ERROR_SUCCESS)
+    return result;
+
+  if (type != REG_SZ && type != REG_EXPAND_SZ) {
+    // Not a string. Oops.
+    return ERROR_CANTREAD;
+  }
+
+  // Keep only the characters up to the first NUL within the returned data.
+  const wchar_t* data_begin = raw_value;
+  const wchar_t* data_end = data_begin + size / sizeof(wchar_t);
+  const std::wstring value(data_begin,
+                           std::find(data_begin, data_end, L'\0'));
+
+  if (type == REG_SZ) {
+    *out_value = value;
+    return ERROR_SUCCESS;
+  }
+
+  wchar_t expanded[kMaxStringLength];
+  DWORD expanded_size =
+      ExpandEnvironmentStrings(value.c_str(), expanded, kMaxStringLength);
+  // Success: returns the number of wchar_t's copied
+  // Fail: buffer too small, returns the size required
+  // Fail: other, returns 0
+  if (expanded_size == 0 || expanded_size > kMaxStringLength)
+    return ERROR_MORE_DATA;
+
+  *out_value = expanded;
+  return ERROR_SUCCESS;
+}
+
+// Raw read: forwards the caller's buffer, size and type pointers directly to
+// RegQueryValueEx and returns its result.
+LONG RegKey::ReadValue(const wchar_t* name,
+                       void* data,
+                       DWORD* dsize,
+                       DWORD* dtype) const {
+  LPBYTE buffer = reinterpret_cast<LPBYTE>(data);
+  return RegQueryValueEx(key_, name, 0, dtype, buffer, dsize);
+}
+
+// Reads a REG_MULTI_SZ value into |values| (cleared first), one entry per
+// string. Returns ERROR_CANTREAD if the value has another type, or the
+// registry error code if a query fails.
+// Registry APIs return Win32 error codes (positive LONGs), not HRESULTs, so
+// FAILED() never detected their errors. In particular an ERROR_MORE_DATA
+// from the second read (value grew in between) led to parsing past the end
+// of |buffer|.
+LONG RegKey::ReadValues(const wchar_t* name,
+                        std::vector<std::wstring>* values) {
+  values->clear();
+
+  // First query only the type and size of the data.
+  DWORD type = REG_MULTI_SZ;
+  DWORD size = 0;
+  LONG result = ReadValue(name, NULL, &size, &type);
+  if (result != ERROR_SUCCESS || size == 0)
+    return result;
+
+  if (type != REG_MULTI_SZ)
+    return ERROR_CANTREAD;
+
+  // Round up so an odd byte count still gets a non-empty, large enough buffer.
+  std::vector<wchar_t> buffer(to_wchar_size(size));
+  result = ReadValue(name, &buffer[0], &size, NULL);
+  if (result != ERROR_SUCCESS || size == 0)
+    return result;
+
+  // Parse the double-null-terminated list of strings.
+  // Note: This code is paranoid to not read outside of |buffer|, in the case
+  // where it may not be properly terminated.
+  const wchar_t* entry = &buffer[0];
+  const wchar_t* buffer_end = entry + (size / sizeof(wchar_t));
+  while (entry < buffer_end && entry[0] != L'\0') {
+    const wchar_t* entry_end = std::find(entry, buffer_end, L'\0');
+    values->push_back(std::wstring(entry, entry_end));
+    if (entry_end == buffer_end)
+      break;
+    entry = entry_end + 1;
+  }
+  return ERROR_SUCCESS;
+}
+
+// Stores |in_value| as a REG_DWORD.
+LONG RegKey::WriteValue(const wchar_t* name, DWORD in_value) {
+  const DWORD byte_count = static_cast<DWORD>(sizeof(in_value));
+  return WriteValue(name, &in_value, byte_count, REG_DWORD);
+}
+
+// Stores the NUL-terminated |in_value| as a REG_SZ. The terminator is
+// included in the stored byte count.
+LONG RegKey::WriteValue(const wchar_t* name, const wchar_t* in_value) {
+  const size_t char_count = wcslen(in_value) + 1;
+  return WriteValue(name, in_value,
+                    static_cast<DWORD>(char_count * sizeof(*in_value)),
+                    REG_SZ);
+}
+
+// Stores |dsize| raw bytes from |data| with registry type |dtype|.
+// |data| may only be NULL when |dsize| is 0.
+LONG RegKey::WriteValue(const wchar_t* name,
+                        const void* data,
+                        DWORD dsize,
+                        DWORD dtype) {
+  DCHECK(data || !dsize);
+
+  LPBYTE bytes = reinterpret_cast<LPBYTE>(const_cast<void*>(data));
+  return RegSetValueEx(key_, name, 0, dtype, bytes, dsize);
+}
+
+// Arms an asynchronous change notification on the held key and its subtree,
+// signalled through |watch_event_| (a manual-reset event created on first
+// use). Returns ERROR_SUCCESS or the failure code. On failure no watch event
+// is kept.
+LONG RegKey::StartWatching() {
+  DCHECK(key_);
+  if (!watch_event_) {
+    watch_event_ = CreateEvent(NULL, TRUE, FALSE, NULL);
+    // Report the creation failure instead of handing a NULL event to
+    // RegNotifyChangeKeyValue and then calling CloseHandle(NULL).
+    if (!watch_event_)
+      return static_cast<LONG>(GetLastError());
+  }
+
+  DWORD filter = REG_NOTIFY_CHANGE_NAME |
+                 REG_NOTIFY_CHANGE_ATTRIBUTES |
+                 REG_NOTIFY_CHANGE_LAST_SET |
+                 REG_NOTIFY_CHANGE_SECURITY;
+
+  // Watch the registry key for a change of value.
+  LONG result = RegNotifyChangeKeyValue(key_, TRUE, filter, watch_event_, TRUE);
+  if (result != ERROR_SUCCESS) {
+    CloseHandle(watch_event_);
+    watch_event_ = 0;
+  }
+
+  return result;
+}
+
+// Polls the watch event without blocking. When it is signalled, the watch is
+// re-armed via StartWatching() before reporting the change.
+bool RegKey::HasChanged() {
+  if (!watch_event_)
+    return false;
+  if (WaitForSingleObject(watch_event_, 0) != WAIT_OBJECT_0)
+    return false;
+
+  StartWatching();
+  return true;
+}
+
+// Closes the watch event. Returns ERROR_SUCCESS if one existed, otherwise
+// ERROR_INVALID_HANDLE.
+LONG RegKey::StopWatching() {
+  if (!watch_event_)
+    return ERROR_INVALID_HANDLE;
+
+  CloseHandle(watch_event_);
+  watch_event_ = 0;
+  return ERROR_SUCCESS;
+}
+
+// RegistryValueIterator ------------------------------------------------------
+
+// Opens |folder_key| under |root_key| for reading and positions the iterator
+// on the value with the highest index (iteration counts down to 0).
+// Every member is now initialized in the initializer list: previously
+// |index_|, |value_size_| and |type_| were left indeterminate when the key
+// could not be opened, so Index() and Type() returned garbage.
+RegistryValueIterator::RegistryValueIterator(HKEY root_key,
+                                             const wchar_t* folder_key)
+    : key_(NULL),
+      index_(-1),
+      name_(MAX_PATH, L'\0'),
+      value_(MAX_PATH, L'\0'),
+      value_size_(0),
+      type_(REG_NONE) {
+  LONG result = RegOpenKeyEx(root_key, folder_key, 0, KEY_READ, &key_);
+  if (result != ERROR_SUCCESS) {
+    key_ = NULL;
+  } else {
+    DWORD count = 0;
+    result = ::RegQueryInfoKey(key_, NULL, 0, NULL, NULL, NULL, NULL, &count,
+                               NULL, NULL, NULL, NULL);
+
+    if (result != ERROR_SUCCESS) {
+      ::RegCloseKey(key_);
+      key_ = NULL;
+    } else {
+      index_ = static_cast<int>(count) - 1;
+    }
+  }
+
+  Read();
+}
+
+// Releases the key opened by the constructor, if it was opened.
+RegistryValueIterator::~RegistryValueIterator() {
+  if (key_ != NULL)
+    ::RegCloseKey(key_);
+}
+
+// Number of values currently under the iterated key; 0 if it can't be read.
+DWORD RegistryValueIterator::ValueCount() const {
+  DWORD value_count = 0;
+  const LONG status = ::RegQueryInfoKey(key_, NULL, 0, NULL, NULL, NULL, NULL,
+                                        &value_count, NULL, NULL, NULL, NULL);
+  return status == ERROR_SUCCESS ? value_count : 0;
+}
+
+// Valid while a key is open and the index has not run below 0.
+bool RegistryValueIterator::Valid() const {
+  if (key_ == NULL)
+    return false;
+  return index_ >= 0;
+}
+
+// Steps to the next lower index and loads that value.
+void RegistryValueIterator::operator++() {
+  index_ -= 1;
+  Read();
+}
+
+// Loads the name, type and data of the value at |index_| into name_, type_,
+// value_ and value_size_. Returns true on success. On failure (or when the
+// iterator is not valid) the name and data are reset to empty and
+// value_size_ to 0.
+bool RegistryValueIterator::Read() {
+ if (Valid()) {
+ // Name buffer length (in wchar_t's) offered to RegEnumValue.
+ DWORD capacity = static_cast<DWORD>(name_.capacity());
+ DWORD name_size = capacity;
+ // |value_size_| is in bytes. Reserve the last character for a NUL.
+ value_size_ = static_cast<DWORD>((value_.size() - 1) * sizeof(wchar_t));
+ LONG result = ::RegEnumValue(
+ key_, index_, WriteInto(&name_, name_size), &name_size, NULL, &type_,
+ reinterpret_cast<BYTE*>(vector_as_array(&value_)), &value_size_);
+
+ if (result == ERROR_MORE_DATA) {
+ // Registry key names are limited to 255 characters and fit within
+ // MAX_PATH (which is 260) but registry value names can use up to 16,383
+ // characters and the value itself is not limited
+ // (from http://msdn.microsoft.com/en-us/library/windows/desktop/
+ // ms724872(v=vs.85).aspx).
+ // Resize the buffers and retry if their size caused the failure.
+ // |value_size_| now holds the byte size RegEnumValue reported for the data;
+ // grow |value_| to that many wchar_t's plus one for the terminating NUL.
+ DWORD value_size_in_wchars = to_wchar_size(value_size_);
+ if (value_size_in_wchars + 1 > value_.size())
+ value_.resize(value_size_in_wchars + 1, L'\0');
+ value_size_ = static_cast<DWORD>((value_.size() - 1) * sizeof(wchar_t));
+ // NOTE(review): presumably an unchanged |name_size| means the name was the
+ // part that did not fit (so retry with the maximum name size), otherwise
+ // the original capacity is enough -- confirm against RegEnumValue docs.
+ name_size = name_size == capacity ? MAX_REGISTRY_NAME_SIZE : capacity;
+ result = ::RegEnumValue(
+ key_, index_, WriteInto(&name_, name_size), &name_size, NULL, &type_,
+ reinterpret_cast<BYTE*>(vector_as_array(&value_)), &value_size_);
+ }
+
+ if (result == ERROR_SUCCESS) {
+ // Terminate the data right after the (possibly odd-sized) payload so
+ // Value() can be used as a string; the slot was reserved above.
+ DCHECK_LT(to_wchar_size(value_size_), value_.size());
+ value_[to_wchar_size(value_size_)] = L'\0';
+ return true;
+ }
+ }
+
+ // Failure or exhausted iterator: expose empty name and data.
+ name_[0] = L'\0';
+ value_[0] = L'\0';
+ value_size_ = 0;
+ return false;
+}
+
+// RegistryKeyIterator --------------------------------------------------------
+
+// Opens |folder_key| under |root_key| for reading and positions the iterator
+// on the subkey with the highest index (iteration counts down to 0).
+// |key_| and |index_| are now initialized up front: |index_| was previously
+// indeterminate when the key could not be opened, so Index() returned
+// garbage. The query result also no longer shadows the outer |result|.
+RegistryKeyIterator::RegistryKeyIterator(HKEY root_key,
+                                         const wchar_t* folder_key)
+    : key_(NULL),
+      index_(-1) {
+  LONG result = RegOpenKeyEx(root_key, folder_key, 0, KEY_READ, &key_);
+  if (result != ERROR_SUCCESS) {
+    key_ = NULL;
+  } else {
+    DWORD count = 0;
+    result = ::RegQueryInfoKey(key_, NULL, 0, NULL, &count, NULL, NULL, NULL,
+                               NULL, NULL, NULL, NULL);
+
+    if (result != ERROR_SUCCESS) {
+      ::RegCloseKey(key_);
+      key_ = NULL;
+    } else {
+      index_ = static_cast<int>(count) - 1;
+    }
+  }
+
+  Read();
+}
+
+// Releases the key opened by the constructor, if it was opened.
+RegistryKeyIterator::~RegistryKeyIterator() {
+  if (key_ != NULL)
+    ::RegCloseKey(key_);
+}
+
+// Number of subkeys under the iterated key; 0 if it can't be read.
+DWORD RegistryKeyIterator::SubkeyCount() const {
+  DWORD subkey_count = 0;
+  const LONG status = ::RegQueryInfoKey(key_, NULL, 0, NULL, &subkey_count,
+                                        NULL, NULL, NULL, NULL, NULL, NULL,
+                                        NULL);
+  return status == ERROR_SUCCESS ? subkey_count : 0;
+}
+
+// Valid while a key is open and the index has not run below 0.
+bool RegistryKeyIterator::Valid() const {
+  if (key_ == NULL)
+    return false;
+  return index_ >= 0;
+}
+
+// Steps to the next lower index and loads that subkey's name.
+void RegistryKeyIterator::operator++() {
+  index_ -= 1;
+  Read();
+}
+
+// Loads the name of the subkey at |index_| into name_. On failure (or when
+// the iterator is not valid) name_ is emptied and false is returned.
+bool RegistryKeyIterator::Read() {
+  if (Valid()) {
+    DWORD name_length = arraysize(name_);
+    FILETIME last_write_time;
+    if (::RegEnumKeyEx(key_, index_, name_, &name_length, NULL, NULL, NULL,
+                       &last_write_time) == ERROR_SUCCESS) {
+      return true;
+    }
+  }
+
+  name_[0] = L'\0';
+  return false;
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/registry.h b/base/win/registry.h
new file mode 100644
index 0000000..7a3d970
--- /dev/null
+++ b/base/win/registry.h
@@ -0,0 +1,212 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_REGISTRY_H_
+#define BASE_WIN_REGISTRY_H_
+
+#include <windows.h>
+#include <string>
+#include <vector>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+#include "base/stl_util.h"
+
+namespace base {
+namespace win {
+
+// Utility class to read, write and manipulate the Windows Registry.
+// Registry vocabulary primer: a "key" is like a folder, in which there
+// are "values", which are <name, data> pairs, with an associated data type.
+//
+// Note:
+// The ReadValue family of functions guarantees that the output arguments
+// are not touched in case of failure. (ReadValues() is the exception: it
+// clears its output vector before reading.)
+//
+// Most methods return the Win32 error code (ERROR_SUCCESS on success) of the
+// underlying registry API call.
+class BASE_EXPORT RegKey {
+ public:
+ RegKey();
+ // If |rootkey| is non-NULL, creates |subkey| when |access| includes
+ // KEY_SET_VALUE, KEY_CREATE_SUB_KEY or KEY_CREATE_LINK, and otherwise opens
+ // it. Check Valid() afterwards to see whether that succeeded.
+ RegKey(HKEY rootkey, const wchar_t* subkey, REGSAM access);
+ // Closes the key and stops any watch.
+ ~RegKey();
+
+ // Creates |subkey| under |rootkey|, or opens it if it already exists.
+ LONG Create(HKEY rootkey, const wchar_t* subkey, REGSAM access);
+
+ // Like Create(), additionally reporting through |disposition| whether the
+ // key was created or opened.
+ LONG CreateWithDisposition(HKEY rootkey, const wchar_t* subkey,
+ DWORD* disposition, REGSAM access);
+
+ // Creates a subkey of the current key, or opens it if it already exists.
+ LONG CreateKey(const wchar_t* name, REGSAM access);
+
+ // Opens an existing reg key.
+ LONG Open(HKEY rootkey, const wchar_t* subkey, REGSAM access);
+
+ // Opens an existing reg key, given the key name relative to the current key.
+ LONG OpenKey(const wchar_t* relative_key_name, REGSAM access);
+
+ // Closes this reg key (and stops watching it).
+ void Close();
+
+ // Returns false if this key does not have the specified value, or if an error
+ // occurs while attempting to access it.
+ bool HasValue(const wchar_t* value_name) const;
+
+ // Returns the number of values for this key, or 0 if the number cannot be
+ // determined.
+ DWORD GetValueCount() const;
+
+ // Determines the name of the value at position |index|.
+ LONG GetValueNameAt(int index, std::wstring* name) const;
+
+ // True while the key is valid.
+ bool Valid() const { return key_ != NULL; }
+
+ // Deletes subkey |name| and everything that lives below it; please be careful
+ // when using it.
+ LONG DeleteKey(const wchar_t* name);
+
+ // Deletes a single value within the key.
+ LONG DeleteValue(const wchar_t* name);
+
+ // Getters:
+
+ // Reads a DWORD value (REG_DWORD, or a 4-byte REG_BINARY). If |name| is NULL
+ // or empty, returns the default value, if any.
+ LONG ReadValueDW(const wchar_t* name, DWORD* out_value) const;
+
+ // Reads an int64 value (REG_QWORD, or an 8-byte REG_BINARY). If |name| is
+ // NULL or empty, returns the default value, if any.
+ LONG ReadInt64(const wchar_t* name, int64* out_value) const;
+
+ // Reads a REG_SZ or REG_EXPAND_SZ string value, expanding environment
+ // variables for the latter; limited to 1024 characters after expansion.
+ // If |name| is NULL or empty, returns the default value, if any.
+ LONG ReadValue(const wchar_t* name, std::wstring* out_value) const;
+
+ // Reads a REG_MULTI_SZ registry field into a vector of strings. Clears
+ // |values| initially and adds further strings to the list. Returns
+ // ERROR_CANTREAD if type is not REG_MULTI_SZ.
+ LONG ReadValues(const wchar_t* name, std::vector<std::wstring>* values);
+
+ // Returns raw data. If |name| is NULL or empty, returns the default
+ // value, if any.
+ LONG ReadValue(const wchar_t* name,
+ void* data,
+ DWORD* dsize,
+ DWORD* dtype) const;
+
+ // Setters:
+
+ // Sets a REG_DWORD value.
+ LONG WriteValue(const wchar_t* name, DWORD in_value);
+
+ // Sets a REG_SZ string value (stored with its terminating NUL).
+ LONG WriteValue(const wchar_t* name, const wchar_t* in_value);
+
+ // Sets raw data, including type.
+ LONG WriteValue(const wchar_t* name,
+ const void* data,
+ DWORD dsize,
+ DWORD dtype);
+
+ // Starts watching the key to see if any of its values have changed.
+ // The key must have been opened with the KEY_NOTIFY access privilege.
+ LONG StartWatching();
+
+ // If StartWatching hasn't been called, always returns false.
+ // Otherwise, returns true if anything under the key has changed.
+ // This can't be const because the |watch_event_| may be refreshed.
+ bool HasChanged();
+
+ // Will automatically be called by destructor if not manually called
+ // beforehand. Returns ERROR_SUCCESS if it was watching, and
+ // ERROR_INVALID_HANDLE otherwise.
+ LONG StopWatching();
+
+ inline bool IsWatching() const { return watch_event_ != 0; }
+ HANDLE watch_event() const { return watch_event_; }
+ HKEY Handle() const { return key_; }
+
+ private:
+ HKEY key_; // The registry key this object wraps.
+ HANDLE watch_event_;
+
+ DISALLOW_COPY_AND_ASSIGN(RegKey);
+};
+
+// Iterates over the values of the registry key |folder_key| under |root_key|
+// (opened with KEY_READ). Values are visited from the highest index down to
+// index 0.
+class BASE_EXPORT RegistryValueIterator {
+ public:
+ RegistryValueIterator(HKEY root_key, const wchar_t* folder_key);
+
+ ~RegistryValueIterator();
+
+ // Number of values under the key, or 0 if it cannot be queried.
+ DWORD ValueCount() const;
+
+ // True while the iterator is valid.
+ bool Valid() const;
+
+ // Advances to the next registry entry.
+ void operator++();
+
+ // Name of the current value; empty when the current entry could not be read.
+ const wchar_t* Name() const { return name_.c_str(); }
+ // Raw data of the current value, always followed by a NUL wchar_t so it can
+ // be used as a string even if the data has an odd byte count.
+ const wchar_t* Value() const { return vector_as_array(&value_); }
+ // ValueSize() is in bytes.
+ DWORD ValueSize() const { return value_size_; }
+ // Registry type (REG_SZ, REG_DWORD, ...) of the current value.
+ DWORD Type() const { return type_; }
+
+ int Index() const { return index_; }
+
+ private:
+ // Read in the current values.
+ bool Read();
+
+ // The registry key being iterated.
+ HKEY key_;
+
+ // Current index of the iteration.
+ int index_;
+
+ // Current values.
+ std::wstring name_;
+ std::vector<wchar_t> value_;
+ DWORD value_size_;
+ DWORD type_;
+
+ DISALLOW_COPY_AND_ASSIGN(RegistryValueIterator);
+};
+
+// Iterates over the subkeys of the registry key |folder_key| under |root_key|
+// (opened with KEY_READ). Subkeys are visited from the highest index down to
+// index 0.
+class BASE_EXPORT RegistryKeyIterator {
+ public:
+ RegistryKeyIterator(HKEY root_key, const wchar_t* folder_key);
+
+ ~RegistryKeyIterator();
+
+ // Number of subkeys under the key, or 0 if it cannot be queried.
+ DWORD SubkeyCount() const;
+
+ // True while the iterator is valid.
+ bool Valid() const;
+
+ // Advances to the next entry in the folder.
+ void operator++();
+
+ // Name of the current subkey. It is read into a MAX_PATH buffer; if the
+ // enumeration call fails for this entry, the name is empty.
+ const wchar_t* Name() const { return name_; }
+
+ int Index() const { return index_; }
+
+ private:
+ // Read in the current values.
+ bool Read();
+
+ // The registry key being iterated.
+ HKEY key_;
+
+ // Current index of the iteration.
+ int index_;
+
+ wchar_t name_[MAX_PATH];
+
+ DISALLOW_COPY_AND_ASSIGN(RegistryKeyIterator);
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_REGISTRY_H_
diff --git a/base/win/registry_unittest.cc b/base/win/registry_unittest.cc
new file mode 100644
index 0000000..155402a
--- /dev/null
+++ b/base/win/registry_unittest.cc
@@ -0,0 +1,164 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/registry.h"
+
+#include <cstring>
+#include <vector>
+
+#include "base/compiler_specific.h"
+#include "base/stl_util.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace base {
+namespace win {
+
+namespace {
+
+// Scratch key under HKEY_CURRENT_USER used by every test in this file.
+const wchar_t kRootKey[] = L"Base_Registry_Unittest";
+
+// Fixture that gives each test a freshly created, empty kRootKey and deletes
+// it again afterwards.
+class RegistryTest : public testing::Test {
+ public:
+  RegistryTest() {}
+
+ protected:
+  virtual void SetUp() OVERRIDE {
+    // Remove leftovers from a previous run, make sure they are really gone,
+    // then recreate the key.
+    RegKey hkcu(HKEY_CURRENT_USER, L"", KEY_ALL_ACCESS);
+    hkcu.DeleteKey(kRootKey);
+    ASSERT_NE(ERROR_SUCCESS,
+              hkcu.Open(HKEY_CURRENT_USER, kRootKey, KEY_READ));
+    ASSERT_EQ(ERROR_SUCCESS,
+              hkcu.Create(HKEY_CURRENT_USER, kRootKey, KEY_READ));
+  }
+
+  virtual void TearDown() OVERRIDE {
+    // Delete the scratch key together with anything a test put below it.
+    RegKey hkcu(HKEY_CURRENT_USER, L"", KEY_SET_VALUE);
+    ASSERT_EQ(ERROR_SUCCESS, hkcu.DeleteKey(kRootKey));
+  }
+
+ private:
+  DISALLOW_COPY_AND_ASSIGN(RegistryTest);
+};
+
+// Round-trips string, DWORD and int64 values through a subkey: write, read
+// back, check that failed reads leave outputs alone, then delete.
+TEST_F(RegistryTest, ValueTest) {
+  RegKey key;
+
+  const std::wstring foo_key = std::wstring(kRootKey) + L"\\Foo";
+  ASSERT_EQ(ERROR_SUCCESS,
+            key.Create(HKEY_CURRENT_USER, foo_key.c_str(), KEY_READ));
+  ASSERT_EQ(ERROR_SUCCESS,
+            key.Open(HKEY_CURRENT_USER, foo_key.c_str(),
+                     KEY_READ | KEY_SET_VALUE));
+  ASSERT_TRUE(key.Valid());
+
+  const wchar_t kStringValueName[] = L"StringValue";
+  const wchar_t kDWORDValueName[] = L"DWORDValue";
+  const wchar_t kInt64ValueName[] = L"Int64Value";
+  const wchar_t kStringData[] = L"string data";
+  const DWORD kDWORDData = 0xdeadbabe;
+  const int64 kInt64Data = 0xdeadbabedeadbabeLL;
+
+  // Creation.
+  ASSERT_EQ(ERROR_SUCCESS, key.WriteValue(kStringValueName, kStringData));
+  ASSERT_EQ(ERROR_SUCCESS, key.WriteValue(kDWORDValueName, kDWORDData));
+  ASSERT_EQ(ERROR_SUCCESS,
+            key.WriteValue(kInt64ValueName, &kInt64Data, sizeof(kInt64Data),
+                           REG_QWORD));
+  EXPECT_EQ(3U, key.GetValueCount());
+  EXPECT_TRUE(key.HasValue(kStringValueName));
+  EXPECT_TRUE(key.HasValue(kDWORDValueName));
+  EXPECT_TRUE(key.HasValue(kInt64ValueName));
+
+  // Reading back.
+  std::wstring read_string;
+  DWORD read_dword = 0;
+  int64 read_int64 = 0;
+  ASSERT_EQ(ERROR_SUCCESS, key.ReadValue(kStringValueName, &read_string));
+  ASSERT_EQ(ERROR_SUCCESS, key.ReadValueDW(kDWORDValueName, &read_dword));
+  ASSERT_EQ(ERROR_SUCCESS, key.ReadInt64(kInt64ValueName, &read_int64));
+  EXPECT_STREQ(kStringData, read_string.c_str());
+  EXPECT_EQ(kDWORDData, read_dword);
+  EXPECT_EQ(kInt64Data, read_int64);
+
+  // Failed reads must leave the output arguments untouched.
+  const wchar_t* kNonExistent = L"NonExistent";
+  ASSERT_NE(ERROR_SUCCESS, key.ReadValue(kNonExistent, &read_string));
+  ASSERT_NE(ERROR_SUCCESS, key.ReadValueDW(kNonExistent, &read_dword));
+  ASSERT_NE(ERROR_SUCCESS, key.ReadInt64(kNonExistent, &read_int64));
+  EXPECT_STREQ(kStringData, read_string.c_str());
+  EXPECT_EQ(kDWORDData, read_dword);
+  EXPECT_EQ(kInt64Data, read_int64);
+
+  // Deletion.
+  ASSERT_EQ(ERROR_SUCCESS, key.DeleteValue(kStringValueName));
+  ASSERT_EQ(ERROR_SUCCESS, key.DeleteValue(kDWORDValueName));
+  ASSERT_EQ(ERROR_SUCCESS, key.DeleteValue(kInt64ValueName));
+  EXPECT_EQ(0U, key.GetValueCount());
+  EXPECT_FALSE(key.HasValue(kStringValueName));
+  EXPECT_FALSE(key.HasValue(kDWORDValueName));
+  EXPECT_FALSE(key.HasValue(kInt64ValueName));
+}
+
+// A value whose name and data are both longer than MAX_PATH must still be
+// enumerated correctly by RegistryValueIterator.
+TEST_F(RegistryTest, BigValueIteratorTest) {
+  RegKey key;
+  const std::wstring foo_key = std::wstring(kRootKey) + L"\\Foo";
+  ASSERT_EQ(ERROR_SUCCESS,
+            key.Create(HKEY_CURRENT_USER, foo_key.c_str(), KEY_READ));
+  ASSERT_EQ(ERROR_SUCCESS,
+            key.Open(HKEY_CURRENT_USER, foo_key.c_str(),
+                     KEY_READ | KEY_SET_VALUE));
+  ASSERT_TRUE(key.Valid());
+
+  // Used both as the value name and as its string data.
+  const std::wstring big(MAX_PATH * 2, L'a');
+  ASSERT_EQ(ERROR_SUCCESS, key.WriteValue(big.c_str(), big.c_str()));
+
+  RegistryValueIterator it(HKEY_CURRENT_USER, foo_key.c_str());
+  ASSERT_TRUE(it.Valid());
+  EXPECT_STREQ(big.c_str(), it.Name());
+  EXPECT_STREQ(big.c_str(), it.Value());
+  // ValueSize() is in bytes, including NUL.
+  EXPECT_EQ((MAX_PATH * 2 + 1) * sizeof(wchar_t), it.ValueSize());
+  ++it;
+  EXPECT_FALSE(it.Valid());
+}
+
+// Binary data with an odd byte count must come back intact, with a NUL
+// wchar_t placed right after the partially filled last wchar_t.
+TEST_F(RegistryTest, TruncatedCharTest) {
+  RegKey key;
+  const std::wstring foo_key = std::wstring(kRootKey) + L"\\Foo";
+  ASSERT_EQ(ERROR_SUCCESS,
+            key.Create(HKEY_CURRENT_USER, foo_key.c_str(), KEY_READ));
+  ASSERT_EQ(ERROR_SUCCESS,
+            key.Open(HKEY_CURRENT_USER, foo_key.c_str(),
+                     KEY_READ | KEY_SET_VALUE));
+  ASSERT_TRUE(key.Valid());
+
+  const wchar_t kName[] = L"name";
+  // kData size is not a multiple of sizeof(wchar_t).
+  const uint8 kData[] = { 1, 2, 3, 4, 5 };
+  EXPECT_EQ(5u, arraysize(kData));
+  ASSERT_EQ(ERROR_SUCCESS,
+            key.WriteValue(kName, kData, arraysize(kData), REG_BINARY));
+
+  RegistryValueIterator it(HKEY_CURRENT_USER, foo_key.c_str());
+  ASSERT_TRUE(it.Valid());
+  EXPECT_STREQ(kName, it.Name());
+  // ValueSize() is in bytes.
+  ASSERT_EQ(arraysize(kData), it.ValueSize());
+  // Value() is NUL terminated just past the rounded-up wchar_t count.
+  const size_t end =
+      (it.ValueSize() + sizeof(wchar_t) - 1) / sizeof(wchar_t);
+  EXPECT_NE(L'\0', it.Value()[end - 1]);
+  EXPECT_EQ(L'\0', it.Value()[end]);
+  EXPECT_EQ(0, std::memcmp(kData, it.Value(), arraysize(kData)));
+  ++it;
+  EXPECT_FALSE(it.Valid());
+}
+
+} // namespace
+
+} // namespace win
+} // namespace base
diff --git a/base/win/resource_util.cc b/base/win/resource_util.cc
new file mode 100644
index 0000000..de9f583
--- /dev/null
+++ b/base/win/resource_util.cc
@@ -0,0 +1,45 @@
+// Copyright (c) 2006-2008 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/resource_util.h"
+
+#include "base/logging.h"
+
+namespace base {
+namespace win {
+
+// Looks up the BINDATA resource |resource_id| in |module|. On success,
+// stores the address returned by LockResource() in |*data| and the
+// resource size in bytes in |*length|; on any failure both are untouched.
+bool GetDataResourceFromModule(HMODULE module, int resource_id,
+                               void** data, size_t* length) {
+  if (module == NULL)
+    return false;
+
+  // Only integer resource IDs are supported.
+  if (!IS_INTRESOURCE(resource_id)) {
+    NOTREACHED();
+    return false;
+  }
+
+  HRSRC info = FindResource(module, MAKEINTRESOURCE(resource_id), L"BINDATA");
+  if (info == NULL)
+    return false;
+
+  const DWORD byte_count = SizeofResource(module, info);
+  HGLOBAL loaded = LoadResource(module, info);
+  if (loaded == NULL)
+    return false;
+
+  void* bytes = LockResource(loaded);
+  if (bytes == NULL)
+    return false;
+
+  *data = bytes;
+  *length = static_cast<size_t>(byte_count);
+  return true;
+}
+
+}  // namespace win
+}  // namespace base
diff --git a/base/win/resource_util.h b/base/win/resource_util.h
new file mode 100644
index 0000000..9955402
--- /dev/null
+++ b/base/win/resource_util.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// This file contains utility functions for accessing resources in external
+// files (DLLs) or embedded in the executable itself.
+
+#ifndef BASE_WIN_RESOURCE_UTIL_H_
+#define BASE_WIN_RESOURCE_UTIL_H_
+
+#include <windows.h>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+
+namespace base {
+namespace win {
+
+// Gets a data resource (BINDATA) from a module (DLL or EXE).
+// Some resources are optional, especially in unit tests, so this returns
+// false but doesn't raise an error if the resource can't be loaded.
+// |resource_id| must be an integer resource ID (see IS_INTRESOURCE).
+// On success, |*data| points at the resource bytes and |*length| holds
+// their size in bytes.
+BASE_EXPORT bool GetDataResourceFromModule(HMODULE module, int resource_id,
+                                           void** data, size_t* length);
+
+}  // namespace win
+}  // namespace base
+
+#endif  // BASE_WIN_RESOURCE_UTIL_H_
diff --git a/base/win/sampling_profiler.cc b/base/win/sampling_profiler.cc
new file mode 100644
index 0000000..150452c
--- /dev/null
+++ b/base/win/sampling_profiler.cc
@@ -0,0 +1,254 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/sampling_profiler.h"
+
+#include <winternl.h>  // for NTSTATUS.
+
+#include <climits>
+
+#include "base/lazy_instance.h"
+
+// Copied from wdm.h in the WDK as we don't want to take
+// a dependency on the WDK.
+typedef enum _KPROFILE_SOURCE {
+  ProfileTime,
+  ProfileAlignmentFixup,
+  ProfileTotalIssues,
+  ProfilePipelineDry,
+  ProfileLoadInstructions,
+  ProfilePipelineFrozen,
+  ProfileBranchInstructions,
+  ProfileTotalNonissues,
+  ProfileDcacheMisses,
+  ProfileIcacheMisses,
+  ProfileCacheMisses,
+  ProfileBranchMispredictions,
+  ProfileStoreInstructions,
+  ProfileFpInstructions,
+  ProfileIntegerInstructions,
+  Profile2Issue,
+  Profile3Issue,
+  Profile4Issue,
+  ProfileSpecialInstructions,
+  ProfileTotalCycles,
+  ProfileIcacheIssues,
+  ProfileDcacheAccesses,
+  ProfileMemoryBarrierCycles,
+  ProfileLoadLinkedIssues,
+  ProfileMaximum
+} KPROFILE_SOURCE;
+
+
+namespace {
+
+// Signatures for the native functions we need to access the sampling profiler.
+typedef NTSTATUS (NTAPI *ZwSetIntervalProfileFunc)(ULONG, KPROFILE_SOURCE);
+typedef NTSTATUS (NTAPI *ZwQueryIntervalProfileFunc)(KPROFILE_SOURCE, PULONG);
+
+typedef NTSTATUS (NTAPI *ZwCreateProfileFunc)(PHANDLE profile,
+                                              HANDLE process,
+                                              PVOID code_start,
+                                              ULONG code_size,
+                                              ULONG eip_bucket_shift,
+                                              PULONG buckets,
+                                              ULONG buckets_byte_size,
+                                              KPROFILE_SOURCE source,
+                                              DWORD_PTR processor_mask);
+
+typedef NTSTATUS (NTAPI *ZwStartProfileFunc)(HANDLE);
+typedef NTSTATUS (NTAPI *ZwStopProfileFunc)(HANDLE);
+
+// This class is used to lazy-initialize pointers to the native
+// functions we need to access.
+class ProfilerFuncs {
+ public:
+  ProfilerFuncs();
+
+  ZwSetIntervalProfileFunc ZwSetIntervalProfile;
+  ZwQueryIntervalProfileFunc ZwQueryIntervalProfile;
+  ZwCreateProfileFunc ZwCreateProfile;
+  ZwStartProfileFunc ZwStartProfile;
+  ZwStopProfileFunc ZwStopProfile;
+
+  // True iff all of the function pointers above were successfully initialized.
+  bool initialized_;
+};
+
+ProfilerFuncs::ProfilerFuncs()
+    : ZwSetIntervalProfile(NULL),
+      ZwQueryIntervalProfile(NULL),
+      ZwCreateProfile(NULL),
+      ZwStartProfile(NULL),
+      ZwStopProfile(NULL),
+      initialized_(false) {
+  HMODULE ntdll = ::GetModuleHandle(L"ntdll.dll");
+  if (ntdll != NULL) {
+    ZwSetIntervalProfile = reinterpret_cast<ZwSetIntervalProfileFunc>(
+        ::GetProcAddress(ntdll, "ZwSetIntervalProfile"));
+    ZwQueryIntervalProfile = reinterpret_cast<ZwQueryIntervalProfileFunc>(
+        ::GetProcAddress(ntdll, "ZwQueryIntervalProfile"));
+    ZwCreateProfile = reinterpret_cast<ZwCreateProfileFunc>(
+        ::GetProcAddress(ntdll, "ZwCreateProfile"));
+    ZwStartProfile = reinterpret_cast<ZwStartProfileFunc>(
+        ::GetProcAddress(ntdll, "ZwStartProfile"));
+    ZwStopProfile = reinterpret_cast<ZwStopProfileFunc>(
+        ::GetProcAddress(ntdll, "ZwStopProfile"));
+
+    if (ZwSetIntervalProfile &&
+        ZwQueryIntervalProfile &&
+        ZwCreateProfile &&
+        ZwStartProfile &&
+        ZwStopProfile) {
+      initialized_ = true;
+    }
+  }
+}
+
+base::LazyInstance<ProfilerFuncs>::Leaky funcs = LAZY_INSTANCE_INITIALIZER;
+
+}  // namespace
+
+
+namespace base {
+namespace win {
+
+SamplingProfiler::SamplingProfiler() : is_started_(false) {
+}
+
+SamplingProfiler::~SamplingProfiler() {
+  if (is_started_) {
+    CHECK(Stop()) <<
+        "Unable to stop sampling profiler, this will cause memory corruption.";
+  }
+}
+
+bool SamplingProfiler::Initialize(HANDLE process,
+                                  void* start,
+                                  size_t size,
+                                  size_t log2_bucket_size) {
+  // You only get to initialize each instance once.
+  DCHECK(!profile_handle_.IsValid());
+  DCHECK(!is_started_);
+  DCHECK(start != NULL);
+  DCHECK_NE(0U, size);
+  // Bucket sizes range from 2^2 to 2^31 bytes, as documented in the header.
+  DCHECK_LE(2U, log2_bucket_size);
+  DCHECK_GE(31U, log2_bucket_size);
+  // |size| is handed to ZwCreateProfile as a ULONG below.
+  DCHECK_LE(size, static_cast<size_t>(ULONG_MAX));
+
+  // Bail if the native functions weren't found.
+  if (!funcs.Get().initialized_)
+    return false;
+
+  // Get our affinity mask for the call below. This is done before the
+  // buckets are allocated so that a failure leaves this instance untouched.
+  DWORD_PTR process_affinity = 0;
+  DWORD_PTR system_affinity = 0;
+  if (!::GetProcessAffinityMask(process, &process_affinity, &system_affinity)) {
+    LOG(ERROR) << "Failed to get process affinity mask.";
+    return false;
+  }
+
+  // Shift a size_t, not an int: 1 << 31 overflows int (and the result would
+  // then be sign-extended into size_t). The bucket count is rounded up
+  // without forming size + bucket_size - 1, which can overflow size_t.
+  const size_t bucket_size = static_cast<size_t>(1) << log2_bucket_size;
+  const size_t num_buckets =
+      size / bucket_size + (size % bucket_size != 0 ? 1 : 0);
+  DCHECK(num_buckets != 0);
+  buckets_.resize(num_buckets);
+
+  HANDLE profile = NULL;
+  NTSTATUS status =
+      funcs.Get().ZwCreateProfile(&profile,
+                                  process,
+                                  start,
+                                  static_cast<ULONG>(size),
+                                  static_cast<ULONG>(log2_bucket_size),
+                                  &buckets_[0],
+                                  static_cast<ULONG>(
+                                      sizeof(buckets_[0]) * num_buckets),
+                                  ProfileTime,
+                                  process_affinity);
+
+  if (!NT_SUCCESS(status)) {
+    // Might as well deallocate the buckets.
+    buckets_.resize(0);
+    LOG(ERROR) << "Failed to create profile, error 0x" << std::hex << status;
+    return false;
+  }
+
+  DCHECK(profile != NULL);
+  profile_handle_.Set(profile);
+
+  return true;
+}
+
+bool SamplingProfiler::Start() {
+  DCHECK(profile_handle_.IsValid());
+  DCHECK(!is_started_);
+  DCHECK(funcs.Get().initialized_);
+
+  NTSTATUS status = funcs.Get().ZwStartProfile(profile_handle_.Get());
+  if (!NT_SUCCESS(status))
+    return false;
+
+  is_started_ = true;
+
+  return true;
+}
+
+bool SamplingProfiler::Stop() {
+  DCHECK(profile_handle_.IsValid());
+  DCHECK(is_started_);
+  DCHECK(funcs.Get().initialized_);
+
+  NTSTATUS status = funcs.Get().ZwStopProfile(profile_handle_.Get());
+  if (!NT_SUCCESS(status))
+    return false;
+  is_started_ = false;
+
+  return true;
+}
+
+bool SamplingProfiler::SetSamplingInterval(base::TimeDelta sampling_interval) {
+  if (!funcs.Get().initialized_)
+    return false;
+
+  // According to Nebbet, the sampling interval is in units of 100ns.
+  // Refuse intervals that can't be expressed as a ULONG count of 100ns
+  // units instead of letting the conversion to ULONG silently wrap.
+  const int64 microseconds = sampling_interval.InMicroseconds();
+  if (microseconds < 0 || microseconds > static_cast<int64>(ULONG_MAX / 10))
+    return false;
+
+  ULONG interval = static_cast<ULONG>(microseconds * 10);
+  NTSTATUS status = funcs.Get().ZwSetIntervalProfile(interval, ProfileTime);
+  if (!NT_SUCCESS(status))
+    return false;
+
+  return true;
+}
+
+bool SamplingProfiler::GetSamplingInterval(base::TimeDelta* sampling_interval) {
+  DCHECK(sampling_interval != NULL);
+
+  if (!funcs.Get().initialized_)
+    return false;
+
+  ULONG interval = 0;
+  NTSTATUS status = funcs.Get().ZwQueryIntervalProfile(ProfileTime, &interval);
+  if (!NT_SUCCESS(status))
+    return false;
+
+  // According to Nebbet, the sampling interval is in units of 100ns.
+  *sampling_interval = base::TimeDelta::FromMicroseconds(interval / 10);
+
+  return true;
+}
+
+}  // namespace win
+}  // namespace base
diff --git a/base/win/sampling_profiler.h b/base/win/sampling_profiler.h
new file mode 100644
index 0000000..e7e76d8
--- /dev/null
+++ b/base/win/sampling_profiler.h
@@ -0,0 +1,82 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SAMPLING_PROFILER_H_
+#define BASE_WIN_SAMPLING_PROFILER_H_
+
+#include <windows.h>
+
+#include <vector>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+#include "base/time.h"
+#include "base/win/scoped_handle.h"
+
+namespace base {
+namespace win {
+
+// This class exposes the functionality of Windows' built-in sampling profiler.
+// Each profiler instance covers a range of memory, and while the profiler is
+// running, its buckets will count the number of times the program counter
+// lands in the associated range of memory on a sample.
+// The sampling interval is settable, but the setting is system-wide.
+class BASE_EXPORT SamplingProfiler {
+ public:
+  // Create an uninitialized sampling profiler.
+  SamplingProfiler();
+  // Stops the profiler if it is still running.
+  ~SamplingProfiler();
+
+  // Initializes the profiler to cover the memory range |start| through
+  // |start| + |size|, in the process |process_handle| with bucket size
+  // 2^|log2_bucket_size|. |log2_bucket_size| must be in the range 2-31,
+  // for bucket sizes of 4 bytes to 2 gigabytes.
+  // The process handle must grant at least PROCESS_QUERY_INFORMATION.
+  // The memory range should be executable code, e.g. the text segment
+  // of an executable (whether DLL or EXE).
+  // Each instance may only be initialized once.
+  // Returns true on success.
+  bool Initialize(HANDLE process_handle,
+                  void* start,
+                  size_t size,
+                  size_t log2_bucket_size);
+
+  // Start this profiler, which must be initialized and not started.
+  // Returns true on success.
+  bool Start();
+  // Stop this profiler, which must be started. Returns true on success.
+  bool Stop();
+
+  // Get and set the sampling interval.
+  // Note that this is a system-wide setting.
+  // Both return false if the native profiling functions are unavailable or
+  // the underlying native call fails.
+  static bool SetSamplingInterval(base::TimeDelta sampling_interval);
+  static bool GetSamplingInterval(base::TimeDelta* sampling_interval);
+
+  // Accessors.
+  bool is_started() const { return is_started_; }
+
+  // It is safe to read the counts in the sampling buckets at any time.
+  // Note however that there's no guarantee that you'll read consistent counts
+  // until the profiler has been stopped, as the counts may be updating on other
+  // CPU cores.
+  const std::vector<ULONG>& buckets() const { return buckets_; }
+
+ private:
+  // Handle to the corresponding kernel object.
+  ScopedHandle profile_handle_;
+  // True iff this profiler is started.
+  bool is_started_;
+  // One count per bucket; handed to ZwCreateProfile as its bucket array.
+  std::vector<ULONG> buckets_;
+
+  DISALLOW_COPY_AND_ASSIGN(SamplingProfiler);
+};
+
+}  // namespace win
+}  // namespace base
+
+#endif  // BASE_WIN_SAMPLING_PROFILER_H_
diff --git a/base/win/sampling_profiler_unittest.cc b/base/win/sampling_profiler_unittest.cc
new file mode 100644
index 0000000..d022026
--- /dev/null
+++ b/base/win/sampling_profiler_unittest.cc
@@ -0,0 +1,137 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/sampling_profiler.h"
+
+#include "base/logging.h"
+#include "base/test/test_timeouts.h"
+#include "base/win/pe_image.h"
+#include "base/win/scoped_handle.h"
+#include "base/win/windows_version.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+// The address of our image base.
+extern "C" IMAGE_DOS_HEADER __ImageBase;
+
+namespace base {
+namespace win {
+
+namespace {
+
+class SamplingProfilerTest : public testing::Test {
+ public:
+  SamplingProfilerTest() : code_start(NULL), code_size(0) {
+  }
+
+  virtual void SetUp() {
+    process.Set(::OpenProcess(PROCESS_QUERY_INFORMATION,
+                              FALSE,
+                              ::GetCurrentProcessId()));
+    ASSERT_TRUE(process.IsValid());
+
+    PEImage image(&__ImageBase);
+
+    // Get the address of the .text section, which is the first section output
+    // by the VS tools.
+    ASSERT_TRUE(image.GetNumSections() > 0);
+    const IMAGE_SECTION_HEADER* text_section = image.GetSectionHeader(0);
+    ASSERT_EQ(0, strncmp(".text",
+                         reinterpret_cast<const char*>(text_section->Name),
+                         arraysize(text_section->Name)));
+    ASSERT_NE(0U, text_section->Characteristics & IMAGE_SCN_MEM_EXECUTE);
+
+    code_start = reinterpret_cast<uint8*>(&__ImageBase) +
+        text_section->VirtualAddress;
+    code_size = text_section->Misc.VirtualSize;
+  }
+
+ protected:
+  ScopedHandle process;
+  void* code_start;
+  size_t code_size;
+};
+
+// Restores the system-wide sampling interval it was constructed with when it
+// goes out of scope, so that a failing ASSERT_* in a test does not leave the
+// whole system sampling at the rate the test selected.
+class ScopedSamplingIntervalRestorer {
+ public:
+  explicit ScopedSamplingIntervalRestorer(base::TimeDelta interval)
+      : interval_(interval) {
+  }
+
+  ~ScopedSamplingIntervalRestorer() {
+    EXPECT_TRUE(SamplingProfiler::SetSamplingInterval(interval_));
+  }
+
+ private:
+  base::TimeDelta interval_;
+
+  DISALLOW_COPY_AND_ASSIGN(ScopedSamplingIntervalRestorer);
+};
+
+}  // namespace
+
+TEST_F(SamplingProfilerTest, Initialize) {
+  SamplingProfiler profiler;
+
+  ASSERT_TRUE(profiler.Initialize(process.Get(), code_start, code_size, 8));
+}
+
+TEST_F(SamplingProfilerTest, Sample) {
+  if (base::win::GetVersion() == base::win::VERSION_WIN8) {
+    LOG(INFO) << "Not running test on Windows 8";
+    return;
+  }
+  SamplingProfiler profiler;
+
+  // Initialize with a huge bucket size, aiming for a single bucket.
+  ASSERT_TRUE(
+      profiler.Initialize(process.Get(), code_start, code_size, 31));
+
+  ASSERT_EQ(1U, profiler.buckets().size());
+  ASSERT_EQ(0U, profiler.buckets()[0]);
+
+  // We use a roomy timeout to make sure this test is not flaky.
+  // On the buildbots, there may not be a whole lot of CPU time
+  // allotted to our process in this wall-clock time duration,
+  // and samples will only accrue while this thread is busy on
+  // a CPU core.
+  base::TimeDelta spin_time = TestTimeouts::action_timeout();
+
+  base::TimeDelta save_sampling_interval;
+  ASSERT_TRUE(SamplingProfiler::GetSamplingInterval(&save_sampling_interval));
+  // The interval is a system-wide setting: put the saved value back when
+  // this test exits, even if one of the assertions below fails.
+  ScopedSamplingIntervalRestorer restore_interval(save_sampling_interval);
+
+  // Sample every 0.5 millisecs.
+  ASSERT_TRUE(SamplingProfiler::SetSamplingInterval(
+      base::TimeDelta::FromMicroseconds(500)));
+
+  // Start the profiler.
+  ASSERT_TRUE(profiler.Start());
+
+  // Get a volatile pointer to our bucket to make sure that the compiler
+  // doesn't optimize out the test in the loop that follows.
+  volatile const ULONG* bucket_ptr = &profiler.buckets()[0];
+
+  // Spin for spin_time wall-clock seconds, or until we get some samples.
+  // Note that sleeping isn't going to do us any good, the samples only
+  // accrue while we're executing code.
+  base::Time start = base::Time::Now();
+  base::TimeDelta elapsed;
+  do {
+    elapsed = base::Time::Now() - start;
+  } while ((elapsed < spin_time) && *bucket_ptr == 0);
+
+  // Stop the profiler.
+  ASSERT_TRUE(profiler.Stop());
+
+  // Check that we got some samples.
+  ASSERT_NE(0U, profiler.buckets()[0]);
+}
+
+}  // namespace win
+}  // namespace base
diff --git a/base/win/scoped_bstr.cc b/base/win/scoped_bstr.cc
new file mode 100644
index 0000000..63ade0c
--- /dev/null
+++ b/base/win/scoped_bstr.cc
@@ -0,0 +1,78 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/scoped_bstr.h"
+
+#include "base/logging.h"
+
+namespace base {
+namespace win {
+
+ScopedBstr::ScopedBstr(const char16* non_bstr)
+    : bstr_(SysAllocString(non_bstr)) {
+}
+
+ScopedBstr::~ScopedBstr() {
+  // The wrapper must add nothing beyond the BSTR it holds.
+  COMPILE_ASSERT(sizeof(ScopedBstr) == sizeof(BSTR), ScopedBstrSize);
+  SysFreeString(bstr_);
+}
+
+void ScopedBstr::Reset(BSTR bstr) {
+  // Re-setting the string we already own must not free it.
+  if (bstr == bstr_)
+    return;
+  // SysFreeString() ignores NULL, so an empty instance needs no special case.
+  SysFreeString(bstr_);
+  bstr_ = bstr;
+}
+
+BSTR ScopedBstr::Release() {
+  BSTR released = bstr_;
+  bstr_ = NULL;
+  return released;
+}
+
+void ScopedBstr::Swap(ScopedBstr& bstr2) {
+  BSTR mine = bstr_;
+  bstr_ = bstr2.bstr_;
+  bstr2.bstr_ = mine;
+}
+
+BSTR* ScopedBstr::Receive() {
+  // Handing out the slot while a string is still owned would leak it.
+  DCHECK(!bstr_) << "BSTR leak.";
+  return &bstr_;
+}
+
+BSTR ScopedBstr::Allocate(const char16* str) {
+  BSTR fresh = SysAllocString(str);
+  Reset(fresh);
+  return bstr_;
+}
+
+BSTR ScopedBstr::AllocateBytes(size_t bytes) {
+  BSTR fresh = SysAllocStringByteLen(NULL, static_cast<UINT>(bytes));
+  Reset(fresh);
+  return bstr_;
+}
+
+void ScopedBstr::SetByteLen(size_t bytes) {
+  DCHECK(bstr_ != NULL) << "attempting to modify a NULL bstr";
+  // The byte length lives in the 32-bit word just before the character
+  // data; there is no public API to change it, so write it directly.
+  uint32* length_prefix = reinterpret_cast<uint32*>(bstr_) - 1;
+  *length_prefix = static_cast<uint32>(bytes);
+}
+
+size_t ScopedBstr::Length() const {
+  return SysStringLen(bstr_);
+}
+
+size_t ScopedBstr::ByteLength() const {
+  return SysStringByteLen(bstr_);
+}
+
+}  // namespace win
+}  // namespace base
diff --git a/base/win/scoped_bstr.h b/base/win/scoped_bstr.h
new file mode 100644
index 0000000..ed46d63
--- /dev/null
+++ b/base/win/scoped_bstr.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_BSTR_H_
+#define BASE_WIN_SCOPED_BSTR_H_
+
+#include <windows.h>
+#include <oleauto.h>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+#include "base/logging.h"
+#include "base/string16.h"
+
+namespace base {
+namespace win {
+
+// Manages a BSTR string pointer.
+// The class interface is based on scoped_ptr.
+class BASE_EXPORT ScopedBstr {
+ public:
+  ScopedBstr() : bstr_(NULL) {
+  }
+
+  // Constructor to create a new BSTR.
+  //
+  // NOTE: Do not pass a BSTR to this constructor expecting ownership to
+  // be transferred - even though it compiles! ;-)
+  explicit ScopedBstr(const char16* non_bstr);
+  ~ScopedBstr();
+
+  // Give ScopedBstr ownership over an already allocated BSTR or NULL.
+  // If you need to allocate a new BSTR instance, use |Allocate| instead.
+  void Reset(BSTR bstr = NULL);
+
+  // Releases ownership of the BSTR to the caller.
+  BSTR Release();
+
+  // Creates a new BSTR from a 16-bit C-style string.
+  //
+  // If you already have a BSTR and want to transfer ownership to the
+  // ScopedBstr instance, call |Reset| instead.
+  //
+  // Returns a pointer to the new BSTR, or NULL if allocation failed.
+  BSTR Allocate(const char16* str);
+
+  // Allocates a new BSTR with the specified number of bytes.
+  // Returns a pointer to the new BSTR, or NULL if allocation failed.
+  BSTR AllocateBytes(size_t bytes);
+
+  // Sets the allocated length field of the already-allocated BSTR to be
+  // |bytes|. This is useful when the BSTR was preallocated with e.g.
+  // SysAllocStringLen or SysAllocStringByteLen (call |AllocateBytes|) and then
+  // not all the bytes are being used.
+  //
+  // Note that if you want to set the length to a specific number of
+  // characters, you need to multiply by sizeof(wchar_t). Oddly, there's no
+  // public API to set the length, so we do this ourselves by hand.
+  //
+  // NOTE: The actual allocated size of the BSTR MUST be >= bytes. That
+  // responsibility is with the caller.
+  void SetByteLen(size_t bytes);
+
+  // Swap values of two ScopedBstr's.
+  void Swap(ScopedBstr& bstr2);
+
+  // Retrieves the pointer address.
+  // Used to receive BSTRs as out arguments (and take ownership).
+  // The function DCHECKs on the current value being NULL.
+  // Usage: GetBstr(bstr.Receive());
+  BSTR* Receive();
+
+  // Returns number of chars in the BSTR.
+  size_t Length() const;
+
+  // Returns the byte length stored in the BSTR (via SysStringByteLen). After
+  // SetByteLen() this is the value set there, which may be below the allocation.
+  size_t ByteLength() const;
+
+  operator BSTR() const {
+    return bstr_;
+  }
+
+ protected:
+  BSTR bstr_;
+
+ private:
+  // Forbid comparison of ScopedBstr types. You should never have the same
+  // BSTR owned by two different ScopedBstrs.
+  bool operator==(const ScopedBstr& bstr2) const;
+  bool operator!=(const ScopedBstr& bstr2) const;
+  DISALLOW_COPY_AND_ASSIGN(ScopedBstr);
+};
+
+}  // namespace win
+}  // namespace base
+
+#endif  // BASE_WIN_SCOPED_BSTR_H_
diff --git a/base/win/scoped_bstr_unittest.cc b/base/win/scoped_bstr_unittest.cc
new file mode 100644
index 0000000..5f6f7df
--- /dev/null
+++ b/base/win/scoped_bstr_unittest.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/scoped_bstr.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace base {
+namespace win {
+
+namespace {
+
+const wchar_t kTestString1[] = L"123";
+const wchar_t kTestString2[] = L"456789";
+// Lengths in characters, excluding the terminating NUL.
+const size_t test1_len = arraysize(kTestString1) - 1;
+const size_t test2_len = arraysize(kTestString2) - 1;
+
+void DumbBstrTests() {
+  ScopedBstr b;
+  EXPECT_TRUE(b == NULL);
+  EXPECT_EQ(0U, b.Length());
+  EXPECT_EQ(0U, b.ByteLength());
+  b.Reset(NULL);
+  EXPECT_TRUE(b == NULL);
+  EXPECT_TRUE(b.Release() == NULL);
+  ScopedBstr b2;
+  b.Swap(b2);
+  EXPECT_TRUE(b2 == NULL);
+}
+
+void GiveMeABstr(BSTR* ret) {
+  *ret = SysAllocString(kTestString1);
+}
+
+void BasicBstrTests() {
+  ScopedBstr b1(kTestString1);
+  EXPECT_EQ(test1_len, b1.Length());
+  EXPECT_EQ(test1_len * sizeof(kTestString1[0]), b1.ByteLength());
+
+  ScopedBstr b2;
+  b1.Swap(b2);
+  EXPECT_EQ(test1_len, b2.Length());
+  EXPECT_EQ(0U, b1.Length());
+  EXPECT_EQ(0, lstrcmp(b2, kTestString1));
+  BSTR tmp = b2.Release();
+  EXPECT_TRUE(tmp != NULL);
+  EXPECT_EQ(0, lstrcmp(tmp, kTestString1));
+  EXPECT_TRUE(b2 == NULL);
+  SysFreeString(tmp);
+
+  GiveMeABstr(b2.Receive());
+  EXPECT_TRUE(b2 != NULL);
+  b2.Reset();
+  EXPECT_TRUE(b2.AllocateBytes(100) != NULL);
+  EXPECT_EQ(100U, b2.ByteLength());
+  EXPECT_EQ(100 / sizeof(kTestString1[0]), b2.Length());
+  lstrcpy(static_cast<BSTR>(b2), kTestString1);
+  EXPECT_EQ(test1_len, static_cast<size_t>(lstrlen(b2)));
+  EXPECT_EQ(100 / sizeof(kTestString1[0]), b2.Length());
+  // b2 now holds kTestString1, so scale by that string's element size.
+  b2.SetByteLen(lstrlen(b2) * sizeof(kTestString1[0]));
+  EXPECT_EQ(b2.Length(), static_cast<size_t>(lstrlen(b2)));
+
+  EXPECT_TRUE(b1.Allocate(kTestString2) != NULL);
+  EXPECT_EQ(test2_len, b1.Length());
+  b1.SetByteLen((test2_len - 1) * sizeof(kTestString2[0]));
+  EXPECT_EQ(test2_len - 1, b1.Length());
+}
+
+}  // namespace
+
+TEST(ScopedBstrTest, ScopedBstr) {
+  DumbBstrTests();
+  BasicBstrTests();
+}
+
+}  // namespace win
+}  // namespace base
diff --git a/base/win/scoped_co_mem.h b/base/win/scoped_co_mem.h
new file mode 100644
index 0000000..572999a
--- /dev/null
+++ b/base/win/scoped_co_mem.h
@@ -0,0 +1,75 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_CO_MEM_H_
+#define BASE_WIN_SCOPED_CO_MEM_H_
+
+#include <objbase.h>
+
+#include "base/basictypes.h"
+#include "base/logging.h"
+
+namespace base {
+namespace win {
+
+// Simple scoped memory releaser class for COM allocated memory.
+// Example:
+//   base::win::ScopedCoMem<ITEMIDLIST> file_item;
+//   SHGetSomeInfo(&file_item, ...);
+//   ...
+//   return;  <-- memory released
+template<typename T>
+class ScopedCoMem {
+ public:
+  ScopedCoMem() : mem_ptr_(NULL) {}
+  ~ScopedCoMem() {
+    Reset(NULL);
+  }
+
+  // Returns the address of the held pointer so that an API can fill it in.
+  // The held pointer must be NULL, otherwise its memory would leak.
+  T** operator&() {  // NOLINT
+    DCHECK(mem_ptr_ == NULL);  // To catch memory leaks.
+    return &mem_ptr_;
+  }
+
+  // Implicit access to the held pointer; usable on const instances too.
+  operator T*() const {
+    return mem_ptr_;
+  }
+
+  T* operator->() {
+    DCHECK(mem_ptr_ != NULL);
+    return mem_ptr_;
+  }
+
+  const T* operator->() const {
+    DCHECK(mem_ptr_ != NULL);
+    return mem_ptr_;
+  }
+
+  // Explicit accessor for the held pointer (may be NULL).
+  T* get() const {
+    return mem_ptr_;
+  }
+
+  // Frees the currently held memory (if any) with CoTaskMemFree and takes
+  // ownership of |ptr|. |ptr| will itself be released with CoTaskMemFree,
+  // so it must come from the COM task allocator (or be NULL).
+  void Reset(T* ptr) {
+    if (mem_ptr_)
+      CoTaskMemFree(mem_ptr_);
+    mem_ptr_ = ptr;
+  }
+
+ private:
+  T* mem_ptr_;
+
+  DISALLOW_COPY_AND_ASSIGN(ScopedCoMem);
+};
+
+}  // namespace win
+}  // namespace base
+
+#endif  // BASE_WIN_SCOPED_CO_MEM_H_
diff --git a/base/win/scoped_com_initializer.h b/base/win/scoped_com_initializer.h
new file mode 100644
index 0000000..392c351
--- /dev/null
+++ b/base/win/scoped_com_initializer.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_COM_INITIALIZER_H_
+#define BASE_WIN_SCOPED_COM_INITIALIZER_H_
+
+#include <objbase.h>
+
+#include "base/basictypes.h"
+#include "base/logging.h"
+#include "build/build_config.h"
+
+namespace base {
+namespace win {
+
+// Scoped helper that calls CoInitializeEx() when constructed and, if that
+// call succeeded, the matching CoUninitialize() when destroyed. The default
+// constructor requests a single-threaded apartment (STA); pass kMTA for a
+// multithreaded apartment. Must be destroyed on the thread that created it.
+class ScopedCOMInitializer {
+ public:
+  // Tag passed to the constructor to request an MTA instead of an STA.
+  enum SelectMTA { kMTA };
+
+  // Initializes COM for this thread as an STA.
+  ScopedCOMInitializer() {
+    Initialize(COINIT_APARTMENTTHREADED);
+  }
+
+  // Initializes COM for this thread as an MTA.
+  explicit ScopedCOMInitializer(SelectMTA /* mta */) {
+    Initialize(COINIT_MULTITHREADED);
+  }
+
+  ~ScopedCOMInitializer() {
+#ifndef NDEBUG
+    // Using the windows API directly to avoid dependency on platform_thread.
+    DCHECK_EQ(GetCurrentThreadId(), thread_id_);
+#endif
+    // Only balance a CoInitializeEx() call that actually succeeded.
+    if (!succeeded())
+      return;
+    CoUninitialize();
+  }
+
+  // True if CoInitializeEx() returned a success HRESULT (S_FALSE included).
+  bool succeeded() const { return SUCCEEDED(hr_); }
+
+ private:
+  // Records the creating thread (debug only) and initializes COM with the
+  // requested concurrency model, remembering the result in |hr_|.
+  void Initialize(COINIT init) {
+#ifndef NDEBUG
+    thread_id_ = GetCurrentThreadId();
+#endif
+    hr_ = CoInitializeEx(NULL, init);
+#ifndef NDEBUG
+    // S_FALSE signals a redundant initialization on this thread; it is logged
+    // rather than treated as a failure.
+    if (hr_ != S_FALSE) {
+      DCHECK_NE(RPC_E_CHANGED_MODE, hr_) << "Invalid COM thread model change";
+    } else {
+      LOG(ERROR) << "Multiple CoInitialize() calls for thread " << thread_id_;
+    }
+#endif
+  }
+
+  // Result of CoInitializeEx(); decides whether CoUninitialize() is needed.
+  HRESULT hr_;
+#ifndef NDEBUG
+  // In debug builds we use this variable to catch a potential bug where a
+  // ScopedCOMInitializer instance is deleted on a different thread than it
+  // was initially created on. If that ever happens it can have bad
+  // consequences and the cause can be tricky to track down.
+  DWORD thread_id_;
+#endif
+
+  DISALLOW_COPY_AND_ASSIGN(ScopedCOMInitializer);
+};
+
+}  // namespace win
+}  // namespace base
+
+#endif  // BASE_WIN_SCOPED_COM_INITIALIZER_H_
diff --git a/base/win/scoped_comptr.h b/base/win/scoped_comptr.h
new file mode 100644
index 0000000..9d5301f
--- /dev/null
+++ b/base/win/scoped_comptr.h
@@ -0,0 +1,174 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_COMPTR_H_
+#define BASE_WIN_SCOPED_COMPTR_H_
+
+#include <unknwn.h>
+
+#include "base/logging.h"
+#include "base/memory/ref_counted.h"
+
+namespace base {
+namespace win {
+
+// A fairly minimalistic smart class for COM interface pointers.
+// Uses scoped_refptr for the basic smart pointer functionality
+// and adds a few IUnknown specific services.
+template <class Interface, const IID* interface_id = &__uuidof(Interface)>
+class ScopedComPtr : public scoped_refptr<Interface> {
+ public:
+  // Utility template to prevent users of ScopedComPtr from calling AddRef
+  // and/or Release() without going through the ScopedComPtr class.
+  class BlockIUnknownMethods : public Interface {
+   private:
+    STDMETHOD(QueryInterface)(REFIID iid, void** object) = 0;
+    STDMETHOD_(ULONG, AddRef)() = 0;
+    STDMETHOD_(ULONG, Release)() = 0;
+  };
+
+  typedef scoped_refptr<Interface> ParentClass;
+
+  ScopedComPtr() {
+  }
+
+  explicit ScopedComPtr(Interface* p) : ParentClass(p) {
+  }
+
+  ScopedComPtr(const ScopedComPtr<Interface, interface_id>& p)
+      : ParentClass(p) {
+  }
+
+  ~ScopedComPtr() {
+    // We don't want the smart pointer class to be bigger than the pointer
+    // it wraps.
+    COMPILE_ASSERT(sizeof(ScopedComPtr<Interface, interface_id>) ==
+                   sizeof(Interface*), ScopedComPtrSize);
+  }
+
+  // Explicit Release() of the held object. Useful for reuse of the
+  // ScopedComPtr instance.
+  // Note that this function equates to IUnknown::Release and should not
+  // be confused with e.g. scoped_ptr::release().
+  void Release() {
+    if (ptr_ != NULL) {
+      ptr_->Release();
+      ptr_ = NULL;
+    }
+  }
+
+  // Sets the internal pointer to NULL and returns the held object without
+  // releasing the reference.
+  Interface* Detach() {
+    Interface* p = ptr_;
+    ptr_ = NULL;
+    return p;
+  }
+
+  // Accepts an interface pointer that has already been addref-ed.
+  void Attach(Interface* p) {
+    DCHECK(!ptr_);
+    ptr_ = p;
+  }
+
+  // Retrieves the pointer address.
+  // Used to receive object pointers as out arguments (and take ownership).
+  // The function DCHECKs on the current value being NULL.
+  // Usage: Foo(p.Receive());
+  Interface** Receive() {
+    DCHECK(!ptr_) << "Object leak. Pointer must be NULL";
+    return &ptr_;
+  }
+
+  // A convenience for whenever a void pointer is needed as an out argument.
+  void** ReceiveVoid() {
+    return reinterpret_cast<void**>(Receive());
+  }
+
+  template <class Query>
+  HRESULT QueryInterface(Query** p) {
+    DCHECK(p != NULL);
+    DCHECK(ptr_ != NULL);
+    // IUnknown already has a template version of QueryInterface
+    // so the iid parameter is implicit here. The only thing this
+    // function adds are the DCHECKs.
+    return ptr_->QueryInterface(p);
+  }
+
+  // QI for times when the IID is not associated with the type.
+  HRESULT QueryInterface(const IID& iid, void** obj) {
+    DCHECK(obj != NULL);
+    DCHECK(ptr_ != NULL);
+    return ptr_->QueryInterface(iid, obj);
+  }
+
+  // Queries |object| for the interface this object wraps and returns the
+  // error code from the object->QueryInterface operation.
+  HRESULT QueryFrom(IUnknown* object) {
+    DCHECK(object != NULL);
+    return object->QueryInterface(Receive());
+  }
+
+  // Convenience wrapper around CoCreateInstance
+  HRESULT CreateInstance(const CLSID& clsid, IUnknown* outer = NULL,
+                         DWORD context = CLSCTX_ALL) {
+    DCHECK(!ptr_);
+    HRESULT hr = ::CoCreateInstance(clsid, outer, context, *interface_id,
+                                    reinterpret_cast<void**>(&ptr_));
+    return hr;
+  }
+
+  // Checks if the identity of |other| and this object is the same.
+  bool IsSameObject(IUnknown* other) {
+    if (!other && !ptr_)
+      return true;
+
+    if (!other || !ptr_)
+      return false;
+
+    ScopedComPtr<IUnknown> my_identity;
+    QueryInterface(my_identity.Receive());
+
+    ScopedComPtr<IUnknown> other_identity;
+    other->QueryInterface(other_identity.Receive());
+
+    return static_cast<IUnknown*>(my_identity) ==
+           static_cast<IUnknown*>(other_identity);
+  }
+
+  // Provides direct access to the interface.
+  // Here we use a well known trick to make sure we block access to
+  // IUnknown methods so that something bad like this doesn't happen:
+  //    ScopedComPtr<IUnknown> p(Foo());
+  //    p->Release();
+  //    ... later the destructor runs, which will Release() again.
+  // and to get the benefit of the DCHECKs we add to QueryInterface.
+  // There's still a way to call these methods if you absolutely must
+  // by statically casting the ScopedComPtr instance to the wrapped interface
+  // and then making the call... but generally that shouldn't be necessary.
+  BlockIUnknownMethods* operator->() const {
+    DCHECK(ptr_ != NULL);
+    return reinterpret_cast<BlockIUnknownMethods*>(ptr_);
+  }
+
+  // Pull in operator=() from the parent class.
+  using scoped_refptr<Interface>::operator=;
+
+  // static methods
+
+  static const IID& iid() {
+    return *interface_id;
+  }
+
+ protected:
+  // |ptr_| is a member of the dependent base class scoped_refptr<Interface>,
+  // so standard two-phase name lookup does not find it unqualified in the
+  // member functions above. Bring it into this class's scope explicitly.
+  using ParentClass::ptr_;
+};
+
+}  // namespace win
+}  // namespace base
+
+#endif  // BASE_WIN_SCOPED_COMPTR_H_
diff --git a/base/win/scoped_comptr_unittest.cc b/base/win/scoped_comptr_unittest.cc
new file mode 100644
index 0000000..d8d12be
--- /dev/null
+++ b/base/win/scoped_comptr_unittest.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/scoped_comptr.h"
+
+#include <shlobj.h>
+
+#include "base/memory/scoped_ptr.h"
+#include "base/win/scoped_com_initializer.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace base {
+namespace win {
+
+namespace {
+
+// Stand-in for a COM object: it only counts AddRef()/Release() calls (and
+// never frees itself) so tests can assert exact reference-count traffic.
+struct Dummy {
+  Dummy() : adds(0), releases(0) {}
+
+  void AddRef() { adds += 1; }
+  void Release() { releases += 1; }
+
+  int adds;      // Number of AddRef() calls seen.
+  int releases;  // Number of Release() calls seen.
+};
+
+// Arbitrary IID used to instantiate ScopedComPtr<Dummy, &dummy_iid>.
+// The extern declaration gives the const object external linkage so its
+// address is usable as a template argument.
+extern const IID dummy_iid;
+const IID dummy_iid = {
+  0x12345678u, 0x1234u, 0x5678u,
+  { 01, 23, 45, 67, 89, 01, 23, 45 }
+};
+
+} // namespace
+
+// Exercises ScopedComPtr against real COM objects: iid(), CreateInstance,
+// Attach/Detach, Receive, the explicit-IID QueryInterface, copy construction,
+// assignment from a raw pointer, IsSameObject, QueryFrom and Release.
+TEST(ScopedComPtrTest, ScopedComPtr) {
+ // iid() for IUnknown must match the system IID_IUnknown.
+ EXPECT_TRUE(memcmp(&ScopedComPtr<IUnknown>::iid(), &IID_IUnknown,
+ sizeof(IID)) == 0);
+
+ base::win::ScopedCOMInitializer com_initializer;
+ EXPECT_TRUE(com_initializer.succeeded());
+
+ // Create a shell link object, then move it from |unk| into |unk2|.
+ ScopedComPtr<IUnknown> unk;
+ EXPECT_TRUE(SUCCEEDED(unk.CreateInstance(CLSID_ShellLink)));
+ ScopedComPtr<IUnknown> unk2;
+ unk2.Attach(unk.Detach());
+ EXPECT_TRUE(unk == NULL);
+ EXPECT_TRUE(unk2 != NULL);
+
+ // Receive() lets a COM API write the interface pointer directly.
+ ScopedComPtr<IMalloc> mem_alloc;
+ EXPECT_TRUE(SUCCEEDED(CoGetMalloc(1, mem_alloc.Receive())));
+
+ // QueryInterface overload that takes an explicit IID.
+ ScopedComPtr<IUnknown> qi_test;
+ EXPECT_HRESULT_SUCCEEDED(mem_alloc.QueryInterface(IID_IUnknown,
+ reinterpret_cast<void**>(qi_test.Receive())));
+ EXPECT_TRUE(qi_test.get() != NULL);
+ qi_test.Release();
+
+ // test ScopedComPtr& constructor
+ ScopedComPtr<IMalloc> copy1(mem_alloc);
+ EXPECT_TRUE(copy1.IsSameObject(mem_alloc));
+ EXPECT_FALSE(copy1.IsSameObject(unk2)); // unk2 is valid but different
+ EXPECT_FALSE(copy1.IsSameObject(unk)); // unk is NULL
+
+ // Detach() hands the caller the reference copy1 held. The =(T*)
+ // assignment is expected to take its own reference, so the detached one is
+ // released by hand afterwards.
+ IMalloc* naked_copy = copy1.Detach();
+ copy1 = naked_copy; // Test the =(T*) operator.
+ naked_copy->Release();
+
+ // After Release() the wrapper is NULL.
+ copy1.Release();
+ EXPECT_FALSE(copy1.IsSameObject(unk2)); // unk2 is valid, copy1 is not
+
+ // test Interface* constructor
+ ScopedComPtr<IMalloc> copy2(static_cast<IMalloc*>(mem_alloc));
+ EXPECT_TRUE(copy2.IsSameObject(mem_alloc));
+
+ // QueryFrom() fills an empty wrapper by querying another object.
+ EXPECT_TRUE(SUCCEEDED(unk.QueryFrom(mem_alloc)));
+ EXPECT_TRUE(unk != NULL);
+ unk.Release();
+ EXPECT_TRUE(unk == NULL);
+ EXPECT_TRUE(unk.IsSameObject(copy1)); // both are NULL
+}
+
+// Checks that ScopedComPtr can be stored in a std::vector and that each
+// construction, copy, assignment and destruction balances AddRef()/Release()
+// on the pointee. Dummy only counts calls, so exact counts can be asserted.
+TEST(ScopedComPtrTest, ScopedComPtrVector) {
+ // Verify we don't get error C2558.
+ // (MSVC C2558: no usable copy constructor for the element type.)
+ typedef ScopedComPtr<Dummy, &dummy_iid> Ptr;
+ std::vector<Ptr> bleh;
+
+ // Dummy::Release() never deletes, so ownership stays with |p|.
+ scoped_ptr<Dummy> p(new Dummy);
+ {
+ Ptr p2(p.get());
+ EXPECT_EQ(p->adds, 1);
+ EXPECT_EQ(p->releases, 0);
+ Ptr p3 = p2;
+ EXPECT_EQ(p->adds, 2);
+ EXPECT_EQ(p->releases, 0);
+ // Reassigning the same object: one AddRef for the new value and one
+ // Release for the old one.
+ p3 = p2;
+ EXPECT_EQ(p->adds, 3);
+ EXPECT_EQ(p->releases, 1);
+ // To avoid hitting a reallocation.
+ bleh.reserve(1);
+ bleh.push_back(p2);
+ EXPECT_EQ(p->adds, 4);
+ EXPECT_EQ(p->releases, 1);
+ EXPECT_EQ(bleh[0], p.get());
+ bleh.pop_back();
+ EXPECT_EQ(p->adds, 4);
+ EXPECT_EQ(p->releases, 2);
+ }
+ // Leaving the scope destroys p2 and p3, balancing the four AddRefs.
+ EXPECT_EQ(p->adds, 4);
+ EXPECT_EQ(p->releases, 4);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/scoped_gdi_object.h b/base/win/scoped_gdi_object.h
new file mode 100644
index 0000000..d44310a
--- /dev/null
+++ b/base/win/scoped_gdi_object.h
@@ -0,0 +1,77 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_GDI_OBJECT_H_
+#define BASE_WIN_SCOPED_GDI_OBJECT_H_
+
+#include <windows.h>
+
+#include "base/basictypes.h"
+#include "base/logging.h"
+
+namespace base {
+namespace win {
+
+// Like ScopedHandle but for GDI objects. Owns a GDI handle of type |T| and
+// frees it with DeleteObject() when destroyed or when a different handle is
+// Set(); the HICON specialization of Close() uses DestroyIcon() instead.
+template<class T>
+class ScopedGDIObject {
+ public:
+  ScopedGDIObject() : object_(NULL) {}
+  explicit ScopedGDIObject(T object) : object_(object) {}
+
+  ~ScopedGDIObject() {
+    Close();
+  }
+
+  // Returns the owned handle without giving up ownership.
+  T Get() const {
+    return object_;
+  }
+
+  // Takes ownership of |object|. The currently owned handle is freed first
+  // unless it is the very same handle.
+  void Set(T object) {
+    if (object_ && object != object_)
+      Close();
+    object_ = object;
+  }
+
+  ScopedGDIObject& operator=(T object) {
+    Set(object);
+    return *this;
+  }
+
+  // Gives up ownership and returns the handle; the caller must free it.
+  T release() {
+    T object = object_;
+    object_ = NULL;
+    return object;
+  }
+
+  operator T() const { return object_; }
+
+ private:
+  // Frees |object_| if set. It does not reset |object_|: every caller either
+  // overwrites it right afterwards (Set) or is the destructor.
+  void Close() {
+    if (object_)
+      DeleteObject(object_);
+  }
+
+  T object_;
+  DISALLOW_COPY_AND_ASSIGN(ScopedGDIObject);
+};
+
+// An explicit specialization for HICON because we have to call DestroyIcon()
+// instead of DeleteObject() for HICON.
+// It must be inline: an explicit specialization of a member function is an
+// ordinary function definition, so without inline every translation unit
+// that includes this header would emit its own definition (ODR violation /
+// duplicate symbol at link time).
+template<>
+inline void ScopedGDIObject<HICON>::Close() {
+  if (object_)
+    DestroyIcon(object_);
+}
+
+// Typedefs for some common use cases, one per GDI handle type.
+typedef ScopedGDIObject<HBITMAP> ScopedBitmap;
+typedef ScopedGDIObject<HFONT> ScopedHFONT;
+typedef ScopedGDIObject<HICON> ScopedHICON;  // Freed with DestroyIcon().
+typedef ScopedGDIObject<HRGN> ScopedRegion;
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_SCOPED_GDI_OBJECT_H_
diff --git a/base/win/scoped_handle.cc b/base/win/scoped_handle.cc
new file mode 100644
index 0000000..03d026a
--- /dev/null
+++ b/base/win/scoped_handle.cc
@@ -0,0 +1,83 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/scoped_handle.h"
+
+#include <map>
+#include <set>
+
+#include "base/debug/alias.h"
+#include "base/lazy_instance.h"
+#include "base/synchronization/lock.h"
+#include "base/win/windows_version.h"
+
+namespace {
+
+// What is remembered for each tracked handle: the owning object, two
+// program-counter values supplied by the caller, and the id of the thread
+// that started tracking it.
+struct Info {
+  const void* owner;
+  const void* pc1;
+  const void* pc2;
+  DWORD thread_id;
+};
+
+// Tracked handle value -> its Info.
+typedef std::map<HANDLE, Info> HandleMap;
+
+// Owners that registered with INVALID_HANDLE_VALUE, i.e. without a handle
+// value that could be used as a key.
+typedef std::set<const void*> OwnerSet;
+
+// Leaky lazy instances: created on first use and never destroyed.
+base::LazyInstance<HandleMap>::Leaky g_handle_map = LAZY_INSTANCE_INITIALIZER;
+base::LazyInstance<OwnerSet>::Leaky g_owner_set = LAZY_INSTANCE_INITIALIZER;
+
+// Held by StartTracking()/StopTracking() while they touch the two sets above.
+base::LazyInstance<base::Lock>::Leaky g_lock = LAZY_INSTANCE_INITIALIZER;
+
+}  // namespace
+
+namespace base {
+namespace win {
+
+// Static.
+// Records that |owner| now owns |handle|. INVALID_HANDLE_VALUE cannot be used
+// as a key, so for it only |owner| is remembered. Crashes if |handle| is
+// already tracked.
+void VerifierTraits::StartTracking(HANDLE handle, const void* owner,
+                                   const void* pc1, const void* pc2) {
+  // Grab the thread id before the lock.
+  const DWORD thread_id = GetCurrentThreadId();
+
+  AutoLock lock(g_lock.Get());
+
+  if (handle == INVALID_HANDLE_VALUE) {
+    // Cannot track this handle.
+    g_owner_set.Get().insert(owner);
+    return;
+  }
+
+  const Info handle_info = { owner, pc1, pc2, thread_id };
+  const std::pair<HandleMap::iterator, bool> inserted =
+      g_handle_map.Get().insert(HandleMap::value_type(handle, handle_info));
+  if (inserted.second)
+    return;
+
+  // Someone already owns |handle|. Copy its entry to the stack and alias it
+  // so the optimizer keeps it (presumably for crash dumps), then crash.
+  Info other = inserted.first->second;
+  debug::Alias(&other);
+  CHECK(false);
+}
+
+// Static.
+// Forgets the entry for |handle|. A handle that is not in the map is accepted
+// only if |owner| registered via the INVALID_HANDLE_VALUE path; a handle owned
+// by someone else, or an unknown owner, is fatal.
+void VerifierTraits::StopTracking(HANDLE handle, const void* owner,
+                                  const void* pc1, const void* pc2) {
+  AutoLock lock(g_lock.Get());
+  HandleMap& handles = g_handle_map.Get();
+  HandleMap::iterator entry = handles.find(handle);
+
+  if (entry == handles.end()) {
+    std::set<const void*>& owners = g_owner_set.Get();
+    std::set<const void*>::iterator registered = owners.find(owner);
+    if (registered != owners.end()) {
+      owners.erase(registered);
+      return;
+    }
+    CHECK(false);
+  }
+
+  Info other = entry->second;
+  if (other.owner != owner) {
+    // Keep the conflicting entry on the stack for diagnosis, then crash.
+    debug::Alias(&other);
+    CHECK(false);
+  }
+
+  handles.erase(entry);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/scoped_handle.h b/base/win/scoped_handle.h
new file mode 100644
index 0000000..b5d9b5c
--- /dev/null
+++ b/base/win/scoped_handle.h
@@ -0,0 +1,188 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_HANDLE_H_
+#define BASE_WIN_SCOPED_HANDLE_H_
+
+#include <windows.h>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+#include "base/location.h"
+#include "base/logging.h"
+#include "base/move.h"
+
+namespace base {
+namespace win {
+
+// TODO(rvargas): remove this with the rest of the verifier.
+// BASE_WIN_GET_CALLER expands to the return address of the function it is
+// written in (MSVC: _ReturnAddress(); GCC: __builtin_return_address(0)
+// passed through __builtin_extract_return_addr). GenericScopedHandle below
+// passes it to its Verifier as |pc1|. No definition is provided for other
+// compilers.
+#if defined(COMPILER_MSVC)
+// MSDN says to #include <intrin.h>, but that breaks the VS2005 build.
+// So the intrinsic is declared by hand instead.
+extern "C" {
+ void* _ReturnAddress();
+}
+#define BASE_WIN_GET_CALLER _ReturnAddress()
+#elif defined(COMPILER_GCC)
+#define BASE_WIN_GET_CALLER __builtin_extract_return_addr(\\
+ __builtin_return_address(0))
+#endif
+
+// Generic wrapper for raw handles that takes care of closing handles
+// automatically. The class interface follows the style of
+// the ScopedStdioHandle class with a few additions:
+// - IsValid() method can tolerate multiple invalid handle values such as NULL
+// and INVALID_HANDLE_VALUE (-1) for Win32 handles.
+// - Receive() method allows to receive a handle value from a function that
+// takes a raw handle pointer only.
+// |Traits| supplies the Handle type plus CloseHandle(), IsHandleValid() and
+// NullHandle(). |Verifier| is told via StartTracking()/StopTracking() when
+// this object acquires or gives up a handle. The type is move-only
+// (C++03 move emulation via MOVE_ONLY_TYPE_FOR_CPP_03).
+template <class Traits, class Verifier>
+class GenericScopedHandle {
+ MOVE_ONLY_TYPE_FOR_CPP_03(GenericScopedHandle, RValue)
+
+ public:
+ typedef typename Traits::Handle Handle;
+
+ // Creates an empty wrapper holding Traits::NullHandle().
+ GenericScopedHandle() : handle_(Traits::NullHandle()) {}
+
+ // Takes ownership of |handle| via Set(); an invalid value is not stored.
+ explicit GenericScopedHandle(Handle handle) : handle_(Traits::NullHandle()) {
+ Set(handle);
+ }
+
+ // Move constructor for C++03 move emulation of this type.
+ GenericScopedHandle(RValue other) : handle_(Traits::NullHandle()) {
+ Set(other.object->Take());
+ }
+
+ // Closes the owned handle, if it is valid.
+ ~GenericScopedHandle() {
+ Close();
+ }
+
+ // Returns true if the held value is valid according to Traits.
+ bool IsValid() const {
+ return Traits::IsHandleValid(handle_);
+ }
+
+ // Move operator= for C++03 move emulation of this type.
+ // Moving an object into itself is a no-op.
+ GenericScopedHandle& operator=(RValue other) {
+ if (this != other.object) {
+ Set(other.object->Take());
+ }
+ return *this;
+ }
+
+ // Replaces the held handle with |handle|. Nothing happens if it is the
+ // same value. Otherwise the current handle is closed first, and |handle| is
+ // stored and reported to the Verifier only if it is valid.
+ void Set(Handle handle) {
+ if (handle_ != handle) {
+ Close();
+
+ if (Traits::IsHandleValid(handle)) {
+ handle_ = handle;
+ // pc1 = return address of Set(); pc2 = current program counter.
+ Verifier::StartTracking(handle, this, BASE_WIN_GET_CALLER,
+ tracked_objects::GetProgramCounter());
+ }
+ }
+ }
+
+ // Returns the handle without giving up ownership.
+ Handle Get() const {
+ return handle_;
+ }
+
+ // Implicit conversion; ownership is not transferred.
+ operator Handle() const {
+ return handle_;
+ }
+
+ // Returns the address of the stored handle so a function can write a raw
+ // handle into it; the wrapper then owns that handle. DCHECKs that no valid
+ // handle is currently held.
+ // NOTE(review): the Verifier learns only the owner here, not the value.
+ // Take() and Close() call StopTracking() only for valid handles, so if no
+ // valid handle is ever written this registration is never removed.
+ Handle* Receive() {
+ DCHECK(!Traits::IsHandleValid(handle_)) << "Handle must be NULL";
+
+ // We cannot track this case :(. Just tell the verifier about it.
+ Verifier::StartTracking(INVALID_HANDLE_VALUE, this, BASE_WIN_GET_CALLER,
+ tracked_objects::GetProgramCounter());
+ return &handle_;
+ }
+
+ // Transfers ownership away from this object.
+ // Leaves Traits::NullHandle() behind; the caller must close the result.
+ Handle Take() {
+ Handle temp = handle_;
+ handle_ = Traits::NullHandle();
+ if (Traits::IsHandleValid(temp)) {
+ Verifier::StopTracking(temp, this, BASE_WIN_GET_CALLER,
+ tracked_objects::GetProgramCounter());
+ }
+ return temp;
+ }
+
+ // Explicitly closes the owned handle.
+ // Tracking stops before the handle is closed; a failed close is fatal.
+ void Close() {
+ if (Traits::IsHandleValid(handle_)) {
+ Verifier::StopTracking(handle_, this, BASE_WIN_GET_CALLER,
+ tracked_objects::GetProgramCounter());
+
+ if (!Traits::CloseHandle(handle_))
+ CHECK(false);
+
+ handle_ = Traits::NullHandle();
+ }
+ }
+
+ private:
+ Handle handle_;
+};
+
+// BASE_WIN_GET_CALLER is only needed inside GenericScopedHandle above.
+#undef BASE_WIN_GET_CALLER
+
+// Traits for Win32 handles that are released with the ::CloseHandle() API.
+class HandleTraits {
+ public:
+  typedef HANDLE Handle;
+
+  // Releases |handle|; returns false if ::CloseHandle() reported failure.
+  static bool CloseHandle(HANDLE handle) {
+    const BOOL closed = ::CloseHandle(handle);
+    return closed != FALSE;
+  }
+
+  // Both NULL and INVALID_HANDLE_VALUE mean "no handle".
+  static bool IsHandleValid(HANDLE handle) {
+    const bool is_null = (handle == NULL);
+    const bool is_invalid = (handle == INVALID_HANDLE_VALUE);
+    return !(is_null || is_invalid);
+  }
+
+  // The value held by an empty wrapper.
+  static HANDLE NullHandle() { return NULL; }
+
+ private:
+  DISALLOW_IMPLICIT_CONSTRUCTORS(HandleTraits);
+};
+
+// Do-nothing verifier: both tracking hooks ignore their arguments.
+class DummyVerifierTraits {
+ public:
+  typedef HANDLE Handle;
+
+  static void StartTracking(HANDLE /* handle */, const void* /* owner */,
+                            const void* /* pc1 */, const void* /* pc2 */) {
+  }
+  static void StopTracking(HANDLE /* handle */, const void* /* owner */,
+                           const void* /* pc1 */, const void* /* pc2 */) {
+  }
+
+ private:
+  DISALLOW_IMPLICIT_CONSTRUCTORS(DummyVerifierTraits);
+};
+
+// Performs actual run-time tracking (implemented in scoped_handle.cc):
+// StartTracking() records which object owns a handle and StopTracking()
+// checks that the same object gives it up.
+class BASE_EXPORT VerifierTraits {
+ public:
+  typedef HANDLE Handle;
+
+  // |owner| took |handle|; |pc1| and |pc2| identify the calling code.
+  static void StartTracking(HANDLE handle,
+                            const void* owner,
+                            const void* pc1,
+                            const void* pc2);
+
+  // |owner| is giving up |handle|.
+  static void StopTracking(HANDLE handle,
+                           const void* owner,
+                           const void* pc1,
+                           const void* pc2);
+
+ private:
+  DISALLOW_IMPLICIT_CONSTRUCTORS(VerifierTraits);
+};
+
+// Default scoped Win32 handle: closed with ::CloseHandle() and tracked by
+// VerifierTraits.
+typedef GenericScopedHandle<HandleTraits, VerifierTraits> ScopedHandle;
+
+} // namespace win
+} // namespace base
+
+#endif  // BASE_WIN_SCOPED_HANDLE_H_
diff --git a/base/win/scoped_hdc.h b/base/win/scoped_hdc.h
new file mode 100644
index 0000000..9aead96
--- /dev/null
+++ b/base/win/scoped_hdc.h
@@ -0,0 +1,76 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_HDC_H_
+#define BASE_WIN_SCOPED_HDC_H_
+
+#include <windows.h>
+
+#include "base/basictypes.h"
+#include "base/logging.h"
+#include "base/win/scoped_handle.h"
+
+namespace base {
+namespace win {
+
+// Like ScopedHandle but for HDC. Only use this on HDCs returned from
+// GetDC; the DC is returned with ReleaseDC() on destruction.
+class ScopedGetDC {
+ public:
+  explicit ScopedGetDC(HWND hwnd)
+      : hwnd_(hwnd),
+        hdc_(GetDC(hwnd)) {
+    if (!hwnd_) {
+      // If GetDC(NULL) returns NULL, something really bad has happened, like
+      // GDI handle exhaustion. In this case Chrome is going to behave badly no
+      // matter what, so we may as well just force a crash now.
+      CHECK(hdc_);
+      return;
+    }
+    DCHECK(IsWindow(hwnd_));
+    DCHECK(hdc_);
+  }
+
+  ~ScopedGetDC() {
+    if (!hdc_)
+      return;
+    ReleaseDC(hwnd_, hdc_);
+  }
+
+  operator HDC() { return hdc_; }
+
+ private:
+  HWND hwnd_;  // Window passed to GetDC(); may be NULL.
+  HDC hdc_;    // DC obtained from GetDC(hwnd_).
+
+  DISALLOW_COPY_AND_ASSIGN(ScopedGetDC);
+};
+
+// Like ScopedHandle but for HDC. Only use this on HDCs returned from
+// CreateCompatibleDC, CreateDC and CreateIC, which are freed with DeleteDC().
+class CreateDCTraits {
+ public:
+  typedef HDC Handle;
+
+  // Frees the DC; returns false if ::DeleteDC() failed.
+  static bool CloseHandle(HDC handle) { return !!::DeleteDC(handle); }
+
+  // Only NULL means "no DC".
+  static bool IsHandleValid(HDC handle) { return NULL != handle; }
+
+  // The value held by an empty wrapper.
+  static HDC NullHandle() { return NULL; }
+
+ private:
+  DISALLOW_IMPLICIT_CONSTRUCTORS(CreateDCTraits);
+};
+
+// Scoped HDC released with DeleteDC() and tracked by VerifierTraits.
+typedef GenericScopedHandle<CreateDCTraits, VerifierTraits> ScopedCreateDC;
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_SCOPED_HDC_H_
diff --git a/base/win/scoped_hglobal.h b/base/win/scoped_hglobal.h
new file mode 100644
index 0000000..891e6cd
--- /dev/null
+++ b/base/win/scoped_hglobal.h
@@ -0,0 +1,52 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_HGLOBAL_H_
+#define BASE_WIN_SCOPED_HGLOBAL_H_
+
+#include <assert.h>
+#include <windows.h>
+
+#include "base/basictypes.h"
+
+namespace base {
+namespace win {
+
+// Like ScopedHandle except for HGLOBAL. The memory block is locked with
+// GlobalLock() for this object's lifetime and unlocked with GlobalUnlock()
+// on destruction. The HGLOBAL itself is never freed here.
+// (operator-> uses assert(), so <assert.h> is included explicitly.)
+template<class T>
+class ScopedHGlobal {
+ public:
+  explicit ScopedHGlobal(HGLOBAL glob)
+      : glob_(glob),
+        data_(static_cast<T*>(GlobalLock(glob))) {
+  }
+
+  ~ScopedHGlobal() {
+    GlobalUnlock(glob_);
+  }
+
+  // Returns the locked data pointer (NULL after release(), or if
+  // GlobalLock() returned NULL).
+  T* get() const { return data_; }
+
+  // Size of the memory block as reported by GlobalSize().
+  size_t Size() const { return GlobalSize(glob_); }
+
+  T* operator->() const {
+    assert(data_ != 0);
+    return data_;
+  }
+
+  // Returns the data pointer and forgets it. The block is still unlocked
+  // when this object is destroyed.
+  T* release() {
+    T* data = data_;
+    data_ = NULL;
+    return data;
+  }
+
+ private:
+  HGLOBAL glob_;
+
+  T* data_;
+
+  DISALLOW_COPY_AND_ASSIGN(ScopedHGlobal);
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_SCOPED_HGLOBAL_H_
diff --git a/base/win/scoped_process_information.cc b/base/win/scoped_process_information.cc
new file mode 100644
index 0000000..4adb8d4
--- /dev/null
+++ b/base/win/scoped_process_information.cc
@@ -0,0 +1,126 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/scoped_process_information.h"
+
+#include "base/logging.h"
+#include "base/win/scoped_handle.h"
+
+namespace base {
+namespace win {
+
+namespace {
+
+// Closes |handle| unless it is NULL. A failing ::CloseHandle() is fatal.
+void CheckAndCloseHandle(HANDLE handle) {
+  if (handle && !::CloseHandle(handle))
+    CHECK(false);
+}
+
+// Duplicates |source| (same process, same access) into |*target| and
+// returns true on success. On failure |*target| is left untouched and false
+// is returned. A NULL |source| is a successful no-op: |*target| is not
+// written.
+bool CheckAndDuplicateHandle(HANDLE source, HANDLE* target) {
+  if (source == NULL)
+    return true;
+
+  const HANDLE current_process = ::GetCurrentProcess();
+  HANDLE duplicate = NULL;
+  const BOOL ok = ::DuplicateHandle(current_process, source, current_process,
+                                    &duplicate, 0, FALSE,
+                                    DUPLICATE_SAME_ACCESS);
+  if (ok) {
+    *target = duplicate;
+    return true;
+  }
+  DPLOG(ERROR) << "Failed to duplicate a handle.";
+  return false;
+}
+
+} // namespace
+
+// Starts with a value-initialized (all-zero) PROCESS_INFORMATION.
+ScopedProcessInformation::ScopedProcessInformation()
+    : process_information_() {}
+
+// Closes any handles still owned.
+ScopedProcessInformation::~ScopedProcessInformation() { Close(); }
+
+PROCESS_INFORMATION* ScopedProcessInformation::Receive() {
+  // Handing out storage that still holds handles would leak them.
+  DCHECK(!IsValid()) << "process_information_ must be NULL";
+  PROCESS_INFORMATION* const storage = &process_information_;
+  return storage;
+}
+
+// Any non-zero member (either handle or either id) counts as valid.
+bool ScopedProcessInformation::IsValid() const {
+  const PROCESS_INFORMATION& info = process_information_;
+  if (info.hThread != NULL || info.hProcess != NULL)
+    return true;
+  return info.dwProcessId != 0 || info.dwThreadId != 0;
+}
+
+
+// Closes the thread handle, then the process handle (each only if set), then
+// zeroes all members via Reset().
+void ScopedProcessInformation::Close() {
+  PROCESS_INFORMATION& info = process_information_;
+  CheckAndCloseHandle(info.hThread);
+  CheckAndCloseHandle(info.hProcess);
+  Reset();
+}
+
+// Exchanges the held PROCESS_INFORMATION with |other|'s (|other| != NULL).
+void ScopedProcessInformation::Swap(ScopedProcessInformation* other) {
+  DCHECK(other);
+  const PROCESS_INFORMATION mine = process_information_;
+  process_information_ = other->process_information_;
+  other->process_information_ = mine;
+}
+
+// Duplicates |other|'s handles into this (empty) instance and copies its ids.
+// On failure nothing is stored here, and a handle that was already
+// duplicated is closed again by its ScopedHandle.
+// The duplicates are received into raw HANDLEs and wrapped only afterwards.
+// ScopedHandle::Receive() registers its owner with the handle verifier
+// without a handle value, and that registration is never removed if no valid
+// handle gets written (failed duplication or NULL source handle).
+bool ScopedProcessInformation::DuplicateFrom(
+    const ScopedProcessInformation& other) {
+  DCHECK(!IsValid()) << "target ScopedProcessInformation must be NULL";
+  DCHECK(other.IsValid()) << "source ScopedProcessInformation must be valid";
+
+  HANDLE raw_process = NULL;
+  if (!CheckAndDuplicateHandle(other.process_handle(), &raw_process))
+    return false;
+  // Owns the duplicate so it is closed if the thread duplication fails.
+  // A NULL value is simply not tracked.
+  ScopedHandle duplicate_process(raw_process);
+
+  HANDLE raw_thread = NULL;
+  if (!CheckAndDuplicateHandle(other.thread_handle(), &raw_thread))
+    return false;
+  ScopedHandle duplicate_thread(raw_thread);
+
+  process_information_.dwProcessId = other.process_id();
+  process_information_.dwThreadId = other.thread_id();
+  process_information_.hProcess = duplicate_process.Take();
+  process_information_.hThread = duplicate_thread.Take();
+  return true;
+}
+
+// Hands the whole structure to the caller and zeroes it here; nothing is
+// closed.
+PROCESS_INFORMATION ScopedProcessInformation::Take() {
+  const PROCESS_INFORMATION taken = process_information_;
+  Reset();
+  return taken;
+}
+
+// Gives away the process handle and clears the process id; the thread
+// members are left alone.
+HANDLE ScopedProcessInformation::TakeProcessHandle() {
+  PROCESS_INFORMATION& info = process_information_;
+  const HANDLE taken = info.hProcess;
+  info.dwProcessId = 0;
+  info.hProcess = NULL;
+  return taken;
+}
+
+// Gives away the thread handle and clears the thread id; the process members
+// are left alone.
+HANDLE ScopedProcessInformation::TakeThreadHandle() {
+  PROCESS_INFORMATION& info = process_information_;
+  const HANDLE taken = info.hThread;
+  info.dwThreadId = 0;
+  info.hThread = NULL;
+  return taken;
+}
+
+// Zeroes both handles and both ids without closing anything.
+void ScopedProcessInformation::Reset() {
+  // Value-initialization zero-fills hProcess, hThread, dwProcessId and
+  // dwThreadId.
+  const PROCESS_INFORMATION empty = PROCESS_INFORMATION();
+  process_information_ = empty;
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/scoped_process_information.h b/base/win/scoped_process_information.h
new file mode 100644
index 0000000..cfa7dc9
--- /dev/null
+++ b/base/win/scoped_process_information.h
@@ -0,0 +1,93 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_PROCESS_INFORMATION_H_
+#define BASE_WIN_SCOPED_PROCESS_INFORMATION_H_
+
+#include <windows.h>
+
+#include "base/basictypes.h"
+#include "base/base_export.h"
+
+namespace base {
+namespace win {
+
+// Manages the closing of process and thread handles from PROCESS_INFORMATION
+// structures. Allows clients to take ownership of either handle independently.
+class BASE_EXPORT ScopedProcessInformation {
+ public:
+  // Creates an instance holding a null PROCESS_INFORMATION.
+  ScopedProcessInformation();
+
+  // Closes the held thread and process handles, if any.
+  ~ScopedProcessInformation();
+
+  // Returns a pointer that may be passed to API calls such as CreateProcess.
+  // DCHECKs that the object is currently empty (see IsValid()).
+  // HANDLEs stored in the returned PROCESS_INFORMATION will be owned by this
+  // instance.
+  PROCESS_INFORMATION* Receive();
+
+  // Returns true iff any member of the held PROCESS_INFORMATION is set: the
+  // thread handle, the process handle, the process id or the thread id.
+  // For example, after only TakeProcessHandle() the thread half still makes
+  // this return true.
+  bool IsValid() const;
+
+  // Closes the held thread and process handles, if any, and resets the held
+  // PROCESS_INFORMATION to null.
+  void Close();
+
+  // Swaps contents with the other ScopedProcessInformation. |other| must not
+  // be NULL.
+  void Swap(ScopedProcessInformation* other);
+
+  // Populates this instance with duplicate handles and the thread/process IDs
+  // from |other|. This instance must be empty and |other| valid (DCHECKed).
+  // Returns false in case of failure, in which case this instance will be
+  // completely unpopulated.
+  bool DuplicateFrom(const ScopedProcessInformation& other);
+
+  // Transfers ownership of the held PROCESS_INFORMATION, if any, away from this
+  // instance. Resets the held PROCESS_INFORMATION to null.
+  PROCESS_INFORMATION Take();
+
+  // Transfers ownership of the held process handle, if any, away from this
+  // instance. The hProcess and dwProcessId members of the held
+  // PROCESS_INFORMATION will be reset.
+  HANDLE TakeProcessHandle();
+
+  // Transfers ownership of the held thread handle, if any, away from this
+  // instance. The hThread and dwThreadId members of the held
+  // PROCESS_INFORMATION will be reset.
+  HANDLE TakeThreadHandle();
+
+  // Returns the held process handle, if any, while retaining ownership.
+  HANDLE process_handle() const {
+    return process_information_.hProcess;
+  }
+
+  // Returns the held thread handle, if any, while retaining ownership.
+  HANDLE thread_handle() const {
+    return process_information_.hThread;
+  }
+
+  // Returns the held process id, if any (0 otherwise).
+  DWORD process_id() const {
+    return process_information_.dwProcessId;
+  }
+
+  // Returns the held thread id, if any (0 otherwise).
+  DWORD thread_id() const {
+    return process_information_.dwThreadId;
+  }
+
+ private:
+  // Resets the held PROCESS_INFORMATION to null without closing anything.
+  void Reset();
+
+  PROCESS_INFORMATION process_information_;
+
+  DISALLOW_COPY_AND_ASSIGN(ScopedProcessInformation);
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_SCOPED_PROCESS_INFORMATION_H_
diff --git a/base/win/scoped_process_information_unittest.cc b/base/win/scoped_process_information_unittest.cc
new file mode 100644
index 0000000..906c156
--- /dev/null
+++ b/base/win/scoped_process_information_unittest.cc
@@ -0,0 +1,181 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <windows.h>
+
+#include <string>
+
+#include "base/command_line.h"
+#include "base/process_util.h"
+#include "base/test/multiprocess_test.h"
+#include "base/win/scoped_process_information.h"
+#include "testing/multiprocess_func_list.h"
+
+// Fixture that launches child processes running the MULTIPROCESS_TEST_MAIN
+// entry points defined in this file.
+class ScopedProcessInformationTest : public base::MultiProcessTest {
+ protected:
+  // Launches the child whose entry point is |main_id| and writes the
+  // resulting handles and ids into |process_information|.
+  void DoCreateProcess(const std::string& main_id,
+                       PROCESS_INFORMATION* process_information);
+};
+
+// Child entry points; the distinct exit codes show which child ran.
+MULTIPROCESS_TEST_MAIN(ReturnSeven) { return 7; }
+
+MULTIPROCESS_TEST_MAIN(ReturnNine) { return 9; }
+
+// Launches the child for |main_id| and fills |process_handle| with the
+// PROCESS_INFORMATION from ::CreateProcess().
+// CreateProcessW may modify the contents of its lpCommandLine buffer, so it
+// receives the std::wstring's own mutable storage (&cmd_line[0]) rather than a
+// const_cast of c_str(). |cmd_line| is assumed to be non-empty, since it holds
+// at least the program path.
+void ScopedProcessInformationTest::DoCreateProcess(
+    const std::string& main_id, PROCESS_INFORMATION* process_handle) {
+  std::wstring cmd_line =
+      this->MakeCmdLine(main_id, false).GetCommandLineString();
+  STARTUPINFO startup_info = {};
+  startup_info.cb = sizeof(startup_info);
+
+  EXPECT_TRUE(::CreateProcess(NULL,
+                              &cmd_line[0],
+                              NULL, NULL, FALSE, 0, NULL, NULL,
+                              &startup_info, process_handle));
+}
+
+// Taking only the process handle clears the process half but leaves the
+// thread handle and id, so the object is still valid.
+TEST_F(ScopedProcessInformationTest, TakeProcess) {
+ base::win::ScopedProcessInformation process_info;
+ DoCreateProcess("ReturnSeven", process_info.Receive());
+ int exit_code = 0;
+ ASSERT_TRUE(base::WaitForExitCode(process_info.TakeProcessHandle(),
+ &exit_code));
+ ASSERT_EQ(7, exit_code);
+ ASSERT_TRUE(process_info.IsValid());
+ ASSERT_EQ(0u, process_info.process_id());
+ ASSERT_TRUE(process_info.process_handle() == NULL);
+ // The thread half is untouched; the destructor closes the thread handle.
+ ASSERT_NE(0u, process_info.thread_id());
+ ASSERT_FALSE(process_info.thread_handle() == NULL);
+}
+
+// Taking (and closing) only the thread handle clears the thread half but
+// leaves the process handle and id.
+TEST_F(ScopedProcessInformationTest, TakeThread) {
+ base::win::ScopedProcessInformation process_info;
+ DoCreateProcess("ReturnSeven", process_info.Receive());
+ ASSERT_TRUE(::CloseHandle(process_info.TakeThreadHandle()));
+ ASSERT_TRUE(process_info.IsValid());
+ ASSERT_NE(0u, process_info.process_id());
+ ASSERT_FALSE(process_info.process_handle() == NULL);
+ ASSERT_EQ(0u, process_info.thread_id());
+ ASSERT_TRUE(process_info.thread_handle() == NULL);
+}
+
+// Once both handles have been taken, every member is zero and the object is
+// no longer valid.
+TEST_F(ScopedProcessInformationTest, TakeBoth) {
+ base::win::ScopedProcessInformation process_info;
+ DoCreateProcess("ReturnSeven", process_info.Receive());
+ int exit_code = 0;
+ ASSERT_TRUE(base::WaitForExitCode(process_info.TakeProcessHandle(),
+ &exit_code));
+ ASSERT_EQ(7, exit_code);
+ ASSERT_TRUE(::CloseHandle(process_info.TakeThreadHandle()));
+ ASSERT_FALSE(process_info.IsValid());
+ ASSERT_EQ(0u, process_info.process_id());
+ ASSERT_TRUE(process_info.process_handle() == NULL);
+ ASSERT_EQ(0u, process_info.thread_id());
+ ASSERT_TRUE(process_info.thread_handle() == NULL);
+}
+
+// Without any Take*() call, everything CreateProcess returned is still
+// held; the destructor closes both handles.
+TEST_F(ScopedProcessInformationTest, TakeNothing) {
+ base::win::ScopedProcessInformation process_info;
+ DoCreateProcess("ReturnSeven", process_info.Receive());
+ ASSERT_TRUE(process_info.IsValid());
+ ASSERT_NE(0u, process_info.thread_id());
+ ASSERT_FALSE(process_info.thread_handle() == NULL);
+ ASSERT_NE(0u, process_info.process_id());
+ ASSERT_FALSE(process_info.process_handle() == NULL);
+}
+
+// Take() moves the entire PROCESS_INFORMATION out: the source becomes empty
+// and the receiving instance can use both handles.
+TEST_F(ScopedProcessInformationTest, TakeWholeStruct) {
+ base::win::ScopedProcessInformation process_info;
+ DoCreateProcess("ReturnSeven", process_info.Receive());
+ base::win::ScopedProcessInformation other;
+ *other.Receive() = process_info.Take();
+
+ ASSERT_FALSE(process_info.IsValid());
+ ASSERT_EQ(0u, process_info.process_id());
+ ASSERT_TRUE(process_info.process_handle() == NULL);
+ ASSERT_EQ(0u, process_info.thread_id());
+ ASSERT_TRUE(process_info.thread_handle() == NULL);
+
+ // Validate that what was taken is good.
+ ASSERT_NE(0u, other.thread_id());
+ ASSERT_NE(0u, other.process_id());
+ int exit_code = 0;
+ ASSERT_TRUE(base::WaitForExitCode(other.TakeProcessHandle(),
+ &exit_code));
+ ASSERT_EQ(7, exit_code);
+ ASSERT_TRUE(::CloseHandle(other.TakeThreadHandle()));
+}
+
+// DuplicateFrom() produces an independent copy: the same ids, plus handles
+// that can be waited on and closed separately from the originals.
+TEST_F(ScopedProcessInformationTest, Duplicate) {
+  base::win::ScopedProcessInformation process_info;
+  DoCreateProcess("ReturnSeven", process_info.Receive());
+  base::win::ScopedProcessInformation duplicate;
+  // Fail fast if duplication failed; otherwise the checks below would fail
+  // later with less obvious handle errors.
+  ASSERT_TRUE(duplicate.DuplicateFrom(process_info));
+  ASSERT_TRUE(duplicate.IsValid());
+
+  ASSERT_TRUE(process_info.IsValid());
+  ASSERT_NE(0u, process_info.process_id());
+  ASSERT_EQ(duplicate.process_id(), process_info.process_id());
+  ASSERT_NE(0u, process_info.thread_id());
+  ASSERT_EQ(duplicate.thread_id(), process_info.thread_id());
+
+  // Validate that we have separate handles that are good.
+  int exit_code = 0;
+  ASSERT_TRUE(base::WaitForExitCode(process_info.TakeProcessHandle(),
+                                    &exit_code));
+  ASSERT_EQ(7, exit_code);
+
+  exit_code = 0;
+  ASSERT_TRUE(base::WaitForExitCode(duplicate.TakeProcessHandle(),
+                                    &exit_code));
+  ASSERT_EQ(7, exit_code);
+
+  ASSERT_TRUE(::CloseHandle(process_info.TakeThreadHandle()));
+  ASSERT_TRUE(::CloseHandle(duplicate.TakeThreadHandle()));
+}
+
+// Swap() exchanges every member (both handles, both ids) between two
+// instances; each handle then waits on the other child's exit code.
+TEST_F(ScopedProcessInformationTest, Swap) {
+ base::win::ScopedProcessInformation seven_to_nine_info;
+ DoCreateProcess("ReturnSeven", seven_to_nine_info.Receive());
+ base::win::ScopedProcessInformation nine_to_seven_info;
+ DoCreateProcess("ReturnNine", nine_to_seven_info.Receive());
+
+ // Snapshot both instances before swapping.
+ HANDLE seven_process = seven_to_nine_info.process_handle();
+ DWORD seven_process_id = seven_to_nine_info.process_id();
+ HANDLE seven_thread = seven_to_nine_info.thread_handle();
+ DWORD seven_thread_id = seven_to_nine_info.thread_id();
+ HANDLE nine_process = nine_to_seven_info.process_handle();
+ DWORD nine_process_id = nine_to_seven_info.process_id();
+ HANDLE nine_thread = nine_to_seven_info.thread_handle();
+ DWORD nine_thread_id = nine_to_seven_info.thread_id();
+
+ seven_to_nine_info.Swap(&nine_to_seven_info);
+
+ ASSERT_EQ(seven_process, nine_to_seven_info.process_handle());
+ ASSERT_EQ(seven_process_id, nine_to_seven_info.process_id());
+ ASSERT_EQ(seven_thread, nine_to_seven_info.thread_handle());
+ ASSERT_EQ(seven_thread_id, nine_to_seven_info.thread_id());
+ ASSERT_EQ(nine_process, seven_to_nine_info.process_handle());
+ ASSERT_EQ(nine_process_id, seven_to_nine_info.process_id());
+ ASSERT_EQ(nine_thread, seven_to_nine_info.thread_handle());
+ ASSERT_EQ(nine_thread_id, seven_to_nine_info.thread_id());
+
+ // The swapped handles really belong to the other child.
+ int exit_code = 0;
+ ASSERT_TRUE(base::WaitForExitCode(seven_to_nine_info.TakeProcessHandle(),
+ &exit_code));
+ ASSERT_EQ(9, exit_code);
+
+ ASSERT_TRUE(base::WaitForExitCode(nine_to_seven_info.TakeProcessHandle(),
+ &exit_code));
+ ASSERT_EQ(7, exit_code);
+
+}
+
+// A default-constructed instance holds nothing and reports itself invalid.
+TEST_F(ScopedProcessInformationTest, InitiallyInvalid) {
+ base::win::ScopedProcessInformation process_info;
+ ASSERT_FALSE(process_info.IsValid());
+}
diff --git a/base/win/scoped_select_object.h b/base/win/scoped_select_object.h
new file mode 100644
index 0000000..347de79
--- /dev/null
+++ b/base/win/scoped_select_object.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_SELECT_OBJECT_H_
+#define BASE_WIN_SCOPED_SELECT_OBJECT_H_
+
+#include <windows.h>
+
+#include "base/basictypes.h"
+#include "base/logging.h"
+
+namespace base {
+namespace win {
+
+// Helper class for deselecting object from DC: |object| is selected into
+// |hdc| on construction, and the object it replaced is selected back on
+// destruction.
+class ScopedSelectObject {
+ public:
+  ScopedSelectObject(HDC hdc, HGDIOBJ object)
+      : hdc_(hdc),
+        oldobj_(SelectObject(hdc, object)) {
+    DCHECK(hdc_);
+    DCHECK(object);
+    DCHECK(oldobj_ != NULL && oldobj_ != HGDI_ERROR);
+  }
+
+  ~ScopedSelectObject() {
+    HGDIOBJ object = SelectObject(hdc_, oldobj_);
+    // Failure is signalled by HGDI_ERROR when a region is restored and by
+    // NULL for every other object type.
+    DCHECK((GetObjectType(oldobj_) != OBJ_REGION && object != NULL) ||
+           (GetObjectType(oldobj_) == OBJ_REGION && object != HGDI_ERROR));
+  }
+
+ private:
+  HDC hdc_;         // DC whose selection is being changed.
+  HGDIOBJ oldobj_;  // Object selected before the constructor ran.
+
+  DISALLOW_COPY_AND_ASSIGN(ScopedSelectObject);
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_SCOPED_SELECT_OBJECT_H_
diff --git a/base/win/scoped_variant.cc b/base/win/scoped_variant.cc
new file mode 100644
index 0000000..f57ab93
--- /dev/null
+++ b/base/win/scoped_variant.cc
@@ -0,0 +1,293 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/scoped_variant.h"
+#include "base/logging.h"
+
+namespace base {
+namespace win {
+
+// Global, const instance of an empty variant.
+const VARIANT ScopedVariant::kEmptyVariant = { VT_EMPTY };
+
+ScopedVariant::~ScopedVariant() {
+  // ScopedVariant must add no storage beyond the VARIANT it wraps.
+  COMPILE_ASSERT(sizeof(ScopedVariant) == sizeof(VARIANT), ScopedVariantSize);
+  ::VariantClear(&var_);
+}
+
+ScopedVariant::ScopedVariant(const wchar_t* str) {
+  // Set() DCHECKs that no leakable value is overwritten, so start out empty.
+  var_.vt = VT_EMPTY;
+  Set(str);
+}
+
+ScopedVariant::ScopedVariant(const wchar_t* str, UINT length) {
+  var_.vt = VT_BSTR;
+  var_.bstrVal = ::SysAllocStringLen(str, length);
+}
+
+ScopedVariant::ScopedVariant(int value, VARTYPE vt) {
+  var_.vt = vt;
+  var_.lVal = value;
+}
+
+ScopedVariant::ScopedVariant(double value, VARTYPE vt) {
+  DCHECK(vt == VT_R8 || vt == VT_DATE);
+  var_.vt = vt;
+  var_.dblVal = value;
+}
+
+ScopedVariant::ScopedVariant(IDispatch* dispatch) {
+  var_.vt = VT_EMPTY;
+  Set(dispatch);
+}
+
+ScopedVariant::ScopedVariant(IUnknown* unknown) {
+  var_.vt = VT_EMPTY;
+  Set(unknown);
+}
+
+ScopedVariant::ScopedVariant(SAFEARRAY* safearray) {
+  var_.vt = VT_EMPTY;
+  Set(safearray);
+}
+
+ScopedVariant::ScopedVariant(const VARIANT& var) {
+  var_.vt = VT_EMPTY;
+  Set(var);
+}
+
+void ScopedVariant::Reset(const VARIANT& var) {
+  // Adopting our own storage must not free it first.
+  if (&var != &var_) {
+    ::VariantClear(&var_);
+    var_ = var;
+  }
+}
+
+VARIANT ScopedVariant::Release() {
+  VARIANT var = var_;
+  // The caller now owns the value; keep the destructor from freeing it.
+  var_.vt = VT_EMPTY;
+  return var;
+}
+
+void ScopedVariant::Swap(ScopedVariant& var) {
+  VARIANT tmp = var_;
+  var_ = var.var_;
+  var.var_ = tmp;
+}
+
+VARIANT* ScopedVariant::Receive() {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "variant leak. type: " << var_.vt;
+  return &var_;
+}
+
+VARIANT ScopedVariant::Copy() const {
+  VARIANT ret = { VT_EMPTY };
+  if (FAILED(::VariantCopy(&ret, &var_)))
+    DLOG(ERROR) << "VariantCopy failed";
+  return ret;
+}
+
+int ScopedVariant::Compare(const VARIANT& var, bool ignore_case) const {
+  ULONG flags = ignore_case ? NORM_IGNORECASE : 0;
+  // VarCmp's operands are input-only but are typed as non-const pointers.
+  HRESULT hr = ::VarCmp(const_cast<VARIANT*>(&var_), const_cast<VARIANT*>(&var),
+                        LOCALE_USER_DEFAULT, flags);
+  int ret = 0;
+
+  switch (hr) {
+    case VARCMP_LT:
+      ret = -1;
+      break;
+
+    case VARCMP_GT:
+    case VARCMP_NULL:
+      // Comparisons involving VT_NULL are reported as "greater", never equal.
+      ret = 1;
+      break;
+
+    case VARCMP_EQ:
+      break;
+
+    default:
+      // VarCmp failed (e.g. the types can't be compared). Keep reporting
+      // "equal" as before, but don't let the failure go unnoticed.
+      DLOG(ERROR) << "VarCmp failed: 0x" << std::hex << hr;
+      break;
+  }
+
+  return ret;
+}
+
+void ScopedVariant::Set(const wchar_t* str) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_BSTR;
+  var_.bstrVal = ::SysAllocString(str);
+}
+
+void ScopedVariant::Set(int8 i8) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_I1;
+  var_.cVal = i8;
+}
+
+void ScopedVariant::Set(uint8 ui8) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_UI1;
+  var_.bVal = ui8;
+}
+
+void ScopedVariant::Set(int16 i16) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_I2;
+  var_.iVal = i16;
+}
+
+void ScopedVariant::Set(uint16 ui16) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_UI2;
+  var_.uiVal = ui16;
+}
+
+void ScopedVariant::Set(int32 i32) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_I4;
+  var_.lVal = i32;
+}
+
+void ScopedVariant::Set(uint32 ui32) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_UI4;
+  var_.ulVal = ui32;
+}
+
+void ScopedVariant::Set(int64 i64) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_I8;
+  var_.llVal = i64;
+}
+
+void ScopedVariant::Set(uint64 ui64) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_UI8;
+  var_.ullVal = ui64;
+}
+
+void ScopedVariant::Set(float r32) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_R4;
+  var_.fltVal = r32;
+}
+
+void ScopedVariant::Set(double r64) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_R8;
+  var_.dblVal = r64;
+}
+
+void ScopedVariant::SetDate(DATE date) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_DATE;
+  var_.date = date;
+}
+
+void ScopedVariant::Set(IDispatch* disp) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_DISPATCH;
+  var_.pdispVal = disp;
+  // Hold our own reference; it is released when the variant is cleared.
+  if (disp)
+    disp->AddRef();
+}
+
+void ScopedVariant::Set(bool b) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_BOOL;
+  var_.boolVal = b ? VARIANT_TRUE : VARIANT_FALSE;
+}
+
+void ScopedVariant::Set(IUnknown* unk) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  var_.vt = VT_UNKNOWN;
+  var_.punkVal = unk;
+  // Hold our own reference; it is released when the variant is cleared.
+  if (unk)
+    unk->AddRef();
+}
+
+void ScopedVariant::Set(SAFEARRAY* array) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  // Takes ownership of |array| without copying it; it is destroyed when the
+  // variant is cleared. On success, the element type is written to var_.vt.
+  if (SUCCEEDED(::SafeArrayGetVartype(array, &var_.vt))) {
+    var_.vt |= VT_ARRAY;
+    var_.parray = array;
+  } else {
+    DCHECK(!array) << "Unable to determine safearray vartype";
+    var_.vt = VT_EMPTY;
+  }
+}
+
+void ScopedVariant::Set(const VARIANT& var) {
+  DCHECK(!IsLeakableVarType(var_.vt)) << "leaking variant: " << var_.vt;
+  if (FAILED(::VariantCopy(&var_, &var))) {
+    DLOG(ERROR) << "VariantCopy failed";
+    var_.vt = VT_EMPTY;
+  }
+}
+
+ScopedVariant& ScopedVariant::operator=(const VARIANT& var) {
+  if (&var != &var_) {
+    ::VariantClear(&var_);
+    Set(var);
+  }
+  return *this;
+}
+
+bool ScopedVariant::IsLeakableVarType(VARTYPE vt) {
+  bool leakable = false;
+  switch (vt & VT_TYPEMASK) {
+    case VT_BSTR:
+    case VT_DISPATCH:
+    // we treat VT_VARIANT as leakable to err on the safe side.
+    case VT_VARIANT:
+    case VT_UNKNOWN:
+    case VT_SAFEARRAY:
+
+    // very rarely used stuff (if ever):
+    case VT_VOID:
+    case VT_PTR:
+    case VT_CARRAY:
+    case VT_USERDEFINED:
+    case VT_LPSTR:
+    case VT_LPWSTR:
+    case VT_RECORD:
+    case VT_INT_PTR:
+    case VT_UINT_PTR:
+    case VT_FILETIME:
+    case VT_BLOB:
+    case VT_STREAM:
+    case VT_STORAGE:
+    case VT_STREAMED_OBJECT:
+    case VT_STORED_OBJECT:
+    case VT_BLOB_OBJECT:
+    case VT_VERSIONED_STREAM:
+    case VT_BSTR_BLOB:
+      leakable = true;
+      break;
+  }
+
+  // Any VT_ARRAY value owns a SAFEARRAY, whatever its element type.
+  if (!leakable && (vt & VT_ARRAY) != 0) {
+    leakable = true;
+  }
+
+  return leakable;
+}
+
+}  // namespace win
+}  // namespace base
diff --git a/base/win/scoped_variant.h b/base/win/scoped_variant.h
new file mode 100644
index 0000000..b6e6579
--- /dev/null
+++ b/base/win/scoped_variant.h
@@ -0,0 +1,173 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SCOPED_VARIANT_H_
+#define BASE_WIN_SCOPED_VARIANT_H_
+
+#include <windows.h>
+#include <oleauto.h>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+
+namespace base {
+namespace win {
+
+// Scoped VARIANT class for automatically freeing a COM VARIANT at the
+// end of a scope. Additionally provides a few functions to make the
+// encapsulated VARIANT easier to use.
+// Instead of inheriting from VARIANT, we take the containment approach
+// in order to have more control over the usage of the variant and guard
+// against memory leaks.
+class BASE_EXPORT ScopedVariant {
+ public:
+  // A global, always-VT_EMPTY variant (the default argument of Reset()).
+  static const VARIANT kEmptyVariant;
+
+  // Default constructor.
+  ScopedVariant() {
+    // This is equivalent to what VariantInit does, but less code.
+    var_.vt = VT_EMPTY;
+  }
+
+  // Constructor to create a new VT_BSTR VARIANT.
+  // NOTE: Do not pass a BSTR to this constructor expecting ownership to
+  // be transferred; |str| is copied.
+  explicit ScopedVariant(const wchar_t* str);
+
+  // Creates a new VT_BSTR variant of a specified length.
+  ScopedVariant(const wchar_t* str, UINT length);
+
+  // Creates a new integral type variant and assigns the value to
+  // VARIANT.lVal (32 bit sized field).
+  explicit ScopedVariant(int value, VARTYPE vt = VT_I4);
+
+  // Creates a new double-precision type variant. |vt| must be either VT_R8
+  // or VT_DATE.
+  explicit ScopedVariant(double value, VARTYPE vt = VT_R8);
+
+  // Creates a VT_DISPATCH variant that holds its own reference to |dispatch|.
+  explicit ScopedVariant(IDispatch* dispatch);
+
+  // Creates a VT_UNKNOWN variant that holds its own reference to |unknown|.
+  explicit ScopedVariant(IUnknown* unknown);
+
+  // Creates a VT_ARRAY variant that takes ownership of |safearray|.
+  explicit ScopedVariant(SAFEARRAY* safearray);
+
+  // Copies the variant.
+  explicit ScopedVariant(const VARIANT& var);
+
+  ~ScopedVariant();
+
+  inline VARTYPE type() const {
+    return var_.vt;
+  }
+
+  // Give ScopedVariant ownership over an already allocated VARIANT.
+  // The held value is freed first; passing the held VARIANT itself is a no-op.
+  void Reset(const VARIANT& var = kEmptyVariant);
+
+  // Releases ownership of the VARIANT to the caller.
+  VARIANT Release();
+
+  // Swaps the contents of this object and |var|.
+  void Swap(ScopedVariant& var);
+
+  // Returns a copy of the variant, which the caller owns.
+  VARIANT Copy() const;
+
+  // The return value is 0 if the variants are equal, 1 if this object is
+  // greater than |var|, -1 if it is smaller. Comparisons involving VT_NULL
+  // return 1, and 0 is also returned if the variants cannot be compared.
+  int Compare(const VARIANT& var, bool ignore_case = false) const;
+
+  // Retrieves the pointer address.
+  // Used to receive a VARIANT as an out argument (and take ownership).
+  // DCHECKs that the current value is not leakable (see IsLeakableVarType()).
+  // Usage: GetVariant(var.Receive());
+  VARIANT* Receive();
+
+  // Setters. Each DCHECKs that the current value is not leakable (see
+  // IsLeakableVarType()); call Reset() first when replacing such a value.
+
+  // Creates a VT_BSTR holding a copy of |str|.
+  void Set(const wchar_t* str);
+
+  // Setters for simple types.
+  void Set(int8 i8);
+  void Set(uint8 ui8);
+  void Set(int16 i16);
+  void Set(uint16 ui16);
+  void Set(int32 i32);
+  void Set(uint32 ui32);
+  void Set(int64 i64);
+  void Set(uint64 ui64);
+  void Set(float r32);
+  void Set(double r64);
+  void Set(bool b);
+
+  // Creates a copy of |var| and assigns as this instance's value.
+  // Note that this is different from the Reset() method that's used to
+  // free the current value and assume ownership.
+  void Set(const VARIANT& var);
+
+  // COM object setters. The variant takes its own reference (AddRef).
+  void Set(IDispatch* disp);
+  void Set(IUnknown* unk);
+
+  // SAFEARRAY support. Takes ownership of |array|.
+  void Set(SAFEARRAY* array);
+
+  // Special setter for DATE since DATE is a double and we already have
+  // a setter for double.
+  void SetDate(DATE date);
+
+  // Allows const access to the contained variant without DCHECKs etc.
+  // This support is necessary for the V_XYZ (e.g. V_BSTR) set of macros to
+  // work properly but still doesn't allow modifications since we want control
+  // over that.
+  const VARIANT* operator&() const {
+    return &var_;
+  }
+
+  // Like other scoped classes (e.g scoped_refptr, ScopedComPtr, ScopedBstr)
+  // we support the assignment operator for the type we wrap.
+  ScopedVariant& operator=(const VARIANT& var);
+
+  // A hack to pass a pointer to the variant where the accepting
+  // function treats the variant as an input-only, read-only value
+  // but the function prototype requires a non const variant pointer.
+  // There's no DCHECK or anything here. Callers must know what they're doing.
+  VARIANT* AsInput() const {
+    // The nature of this function is const, so we declare
+    // it as such and cast away the constness here.
+    return const_cast<VARIANT*>(&var_);
+  }
+
+  // Allows the ScopedVariant instance to be passed to functions either by value
+  // or by const reference.
+  operator const VARIANT&() const {
+    return var_;
+  }
+
+  // Returns true if a value of type |vt| may own memory or a reference that
+  // must be freed. Used as a debug check to see if we're leaking anything.
+  static bool IsLeakableVarType(VARTYPE vt);
+
+ protected:
+  VARIANT var_;
+
+ private:
+  // Comparison operators for ScopedVariant are not supported at this point.
+  // Use the Compare method instead.
+  bool operator==(const ScopedVariant& var) const;
+  bool operator!=(const ScopedVariant& var) const;
+  DISALLOW_COPY_AND_ASSIGN(ScopedVariant);
+};
+
+}  // namespace win
+}  // namespace base
+
+#endif  // BASE_WIN_SCOPED_VARIANT_H_
diff --git a/base/win/scoped_variant_unittest.cc b/base/win/scoped_variant_unittest.cc
new file mode 100644
index 0000000..1f017cf
--- /dev/null
+++ b/base/win/scoped_variant_unittest.cc
@@ -0,0 +1,263 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/scoped_variant.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace base {
+namespace win {
+
+namespace {
+
+static const wchar_t kTestString1[] = L"Used to create BSTRs";
+static const wchar_t kTestString2[] = L"Also used to create BSTRs";
+
+// Mimics a COM out-parameter: stores a new BSTR of kTestString1 in |ret|.
+void GiveMeAVariant(VARIANT* ret) {
+  EXPECT_TRUE(ret != NULL);
+  ret->vt = VT_BSTR;
+  V_BSTR(ret) = ::SysAllocString(kTestString1);
+}
+
+// A dummy IDispatch implementation (if you can call it that).
+// The class does nothing intelligent really. Only increments a counter
+// when AddRef is called and decrements it when Release is called.
+class FakeComObject : public IDispatch {
+ public:
+  FakeComObject() : ref_(0) {
+  }
+
+  STDMETHOD_(ULONG, AddRef)() {
+    ref_++;
+    return ref_;
+  }
+
+  STDMETHOD_(ULONG, Release)() {
+    ref_--;
+    return ref_;
+  }
+
+  STDMETHOD(QueryInterface)(REFIID, void**) {
+    return E_NOTIMPL;
+  }
+
+  STDMETHOD(GetTypeInfoCount)(UINT*) {
+    return E_NOTIMPL;
+  }
+
+  STDMETHOD(GetTypeInfo)(UINT, LCID, ITypeInfo**) {
+    return E_NOTIMPL;
+  }
+
+  STDMETHOD(GetIDsOfNames)(REFIID, LPOLESTR*, UINT, LCID, DISPID*) {
+    return E_NOTIMPL;
+  }
+
+  STDMETHOD(Invoke)(DISPID, REFIID, LCID, WORD, DISPPARAMS*, VARIANT*,
+                    EXCEPINFO*, UINT*) {
+    return E_NOTIMPL;
+  }
+
+  // A way to check the internal reference count of the class.
+  int ref_count() const {
+    return ref_;
+  }
+
+ protected:
+  int ref_;
+};
+
+}  // namespace
+
+TEST(ScopedVariantTest, ScopedVariant) {
+  ScopedVariant var;
+  EXPECT_TRUE(var.type() == VT_EMPTY);
+  // V_BSTR(&var) = NULL;  <- NOTE: Assignment like that is not supported
+
+  ScopedVariant var_bstr(L"VT_BSTR");
+  EXPECT_EQ(VT_BSTR, V_VT(&var_bstr));
+  EXPECT_TRUE(V_BSTR(&var_bstr) != NULL);  // can't use EXPECT_NE for BSTR
+  var_bstr.Reset();
+  EXPECT_NE(VT_BSTR, V_VT(&var_bstr));
+  var_bstr.Set(kTestString2);
+  EXPECT_EQ(VT_BSTR, V_VT(&var_bstr));
+
+  VARIANT tmp = var_bstr.Release();
+  EXPECT_EQ(VT_EMPTY, V_VT(&var_bstr));
+  EXPECT_EQ(VT_BSTR, V_VT(&tmp));
+  EXPECT_EQ(0, lstrcmpW(V_BSTR(&tmp), kTestString2));
+
+  var.Reset(tmp);
+  EXPECT_EQ(VT_BSTR, V_VT(&var));
+  EXPECT_EQ(0, lstrcmpW(V_BSTR(&var), kTestString2));
+
+  var_bstr.Swap(var);
+  EXPECT_EQ(VT_EMPTY, V_VT(&var));
+  EXPECT_EQ(VT_BSTR, V_VT(&var_bstr));
+  EXPECT_EQ(0, lstrcmpW(V_BSTR(&var_bstr), kTestString2));
+  var_bstr.Reset();
+
+  // Test the Compare and Copy routines.
+  GiveMeAVariant(var_bstr.Receive());
+  ScopedVariant var_bstr2(V_BSTR(&var_bstr));
+  EXPECT_EQ(0, var_bstr.Compare(var_bstr2));
+  var_bstr2.Reset();
+  EXPECT_NE(0, var_bstr.Compare(var_bstr2));
+  var_bstr2.Reset(var_bstr.Copy());
+  EXPECT_EQ(0, var_bstr.Compare(var_bstr2));
+  var_bstr2.Reset();
+  var_bstr2.Set(V_BSTR(&var_bstr));
+  EXPECT_EQ(0, var_bstr.Compare(var_bstr2));
+  var_bstr2.Reset();
+  var_bstr.Reset();
+
+  // Test for the SetDate setter.
+  SYSTEMTIME sys_time;
+  ::GetSystemTime(&sys_time);
+  DATE date;
+  ::SystemTimeToVariantTime(&sys_time, &date);
+  var.Reset();
+  var.SetDate(date);
+  EXPECT_EQ(VT_DATE, var.type());
+  EXPECT_EQ(date, V_DATE(&var));
+
+  // Simple setter tests.  These do not require resetting the variant
+  // after each test since the variant type is not "leakable" (i.e. doesn't
+  // need to be freed explicitly).
+
+  // Cast explicitly: a plain char would be promoted to int and select the
+  // int32 overload instead of int8.
+  var.Set(static_cast<int8>('v'));
+  EXPECT_EQ(VT_I1, var.type());
+  EXPECT_EQ('v', V_I1(&var));
+
+  var.Set(static_cast<short>(123));
+  EXPECT_EQ(VT_I2, var.type());
+  EXPECT_EQ(123, V_I2(&var));
+
+  var.Set(static_cast<int32>(123));
+  EXPECT_EQ(VT_I4, var.type());
+  EXPECT_EQ(123, V_I4(&var));
+
+  var.Set(static_cast<int64>(123));
+  EXPECT_EQ(VT_I8, var.type());
+  EXPECT_EQ(123, V_I8(&var));
+
+  var.Set(static_cast<uint8>(123));
+  EXPECT_EQ(VT_UI1, var.type());
+  EXPECT_EQ(123, V_UI1(&var));
+
+  var.Set(static_cast<unsigned short>(123));
+  EXPECT_EQ(VT_UI2, var.type());
+  EXPECT_EQ(123, V_UI2(&var));
+
+  var.Set(static_cast<uint32>(123));
+  EXPECT_EQ(VT_UI4, var.type());
+  EXPECT_EQ(123, V_UI4(&var));
+
+  var.Set(static_cast<uint64>(123));
+  EXPECT_EQ(VT_UI8, var.type());
+  EXPECT_EQ(123, V_UI8(&var));
+
+  var.Set(123.123f);
+  EXPECT_EQ(VT_R4, var.type());
+  EXPECT_EQ(123.123f, V_R4(&var));
+
+  var.Set(static_cast<double>(123.123));
+  EXPECT_EQ(VT_R8, var.type());
+  EXPECT_EQ(123.123, V_R8(&var));
+
+  var.Set(true);
+  EXPECT_EQ(VT_BOOL, var.type());
+  EXPECT_EQ(VARIANT_TRUE, V_BOOL(&var));
+  var.Set(false);
+  EXPECT_EQ(VT_BOOL, var.type());
+  EXPECT_EQ(VARIANT_FALSE, V_BOOL(&var));
+
+  // Com interface tests
+
+  var.Set(static_cast<IDispatch*>(NULL));
+  EXPECT_EQ(VT_DISPATCH, var.type());
+  EXPECT_EQ(NULL, V_DISPATCH(&var));
+  var.Reset();
+
+  var.Set(static_cast<IUnknown*>(NULL));
+  EXPECT_EQ(VT_UNKNOWN, var.type());
+  EXPECT_EQ(NULL, V_UNKNOWN(&var));
+  var.Reset();
+
+  FakeComObject faker;
+  EXPECT_EQ(0, faker.ref_count());
+  var.Set(static_cast<IDispatch*>(&faker));
+  EXPECT_EQ(VT_DISPATCH, var.type());
+  EXPECT_EQ(&faker, V_DISPATCH(&var));
+  EXPECT_EQ(1, faker.ref_count());
+  var.Reset();
+  EXPECT_EQ(0, faker.ref_count());
+
+  var.Set(static_cast<IUnknown*>(&faker));
+  EXPECT_EQ(VT_UNKNOWN, var.type());
+  EXPECT_EQ(&faker, V_UNKNOWN(&var));
+  EXPECT_EQ(1, faker.ref_count());
+  var.Reset();
+  EXPECT_EQ(0, faker.ref_count());
+
+  {
+    ScopedVariant disp_var(&faker);
+    EXPECT_EQ(VT_DISPATCH, disp_var.type());
+    EXPECT_EQ(&faker, V_DISPATCH(&disp_var));
+    EXPECT_EQ(1, faker.ref_count());
+  }
+  EXPECT_EQ(0, faker.ref_count());
+
+  {
+    ScopedVariant ref1(&faker);
+    EXPECT_EQ(1, faker.ref_count());
+    ScopedVariant ref2(static_cast<const VARIANT&>(ref1));
+    EXPECT_EQ(2, faker.ref_count());
+    ScopedVariant ref3;
+    ref3 = static_cast<const VARIANT&>(ref2);
+    EXPECT_EQ(3, faker.ref_count());
+  }
+  EXPECT_EQ(0, faker.ref_count());
+
+  {
+    ScopedVariant unk_var(static_cast<IUnknown*>(&faker));
+    EXPECT_EQ(VT_UNKNOWN, unk_var.type());
+    EXPECT_EQ(&faker, V_UNKNOWN(&unk_var));
+    EXPECT_EQ(1, faker.ref_count());
+  }
+  EXPECT_EQ(0, faker.ref_count());
+
+  VARIANT raw;
+  raw.vt = VT_UNKNOWN;
+  raw.punkVal = &faker;
+  EXPECT_EQ(0, faker.ref_count());
+  var.Set(raw);
+  EXPECT_EQ(1, faker.ref_count());
+  var.Reset();
+  EXPECT_EQ(0, faker.ref_count());
+
+  {
+    ScopedVariant number(123);
+    EXPECT_EQ(VT_I4, number.type());
+    EXPECT_EQ(123, V_I4(&number));
+  }
+
+  // SAFEARRAY tests
+  var.Set(static_cast<SAFEARRAY*>(NULL));
+  EXPECT_EQ(VT_EMPTY, var.type());
+
+  SAFEARRAY* sa = ::SafeArrayCreateVector(VT_UI1, 0, 100);
+  ASSERT_TRUE(sa != NULL);
+
+  var.Set(sa);
+  EXPECT_TRUE(ScopedVariant::IsLeakableVarType(var.type()));
+  EXPECT_EQ(VT_ARRAY | VT_UI1, var.type());
+  EXPECT_EQ(sa, V_ARRAY(&var));
+  // The array is destroyed in the destructor of var.
+}
+
+}  // namespace win
+}  // namespace base
diff --git a/base/win/shortcut.cc b/base/win/shortcut.cc
new file mode 100644
index 0000000..8afd55d
--- /dev/null
+++ b/base/win/shortcut.cc
@@ -0,0 +1,240 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/shortcut.h"
+
+#include <shellapi.h>
+#include <shlobj.h>
+#include <propkey.h>
+
+#include "base/file_util.h"
+#include "base/threading/thread_restrictions.h"
+#include "base/win/scoped_comptr.h"
+#include "base/win/win_util.h"
+#include "base/win/windows_version.h"
+
+namespace base {
+namespace win {
+
+namespace {
+
+// Initializes |i_shell_link| and |i_persist_file| (releasing them first if they
+// are already initialized).
+// If |shortcut| is not NULL, loads |shortcut| into |i_persist_file|.
+// If any of the above steps fail, both |i_shell_link| and |i_persist_file| will
+// be released.
+void InitializeShortcutInterfaces(
+    const wchar_t* shortcut,
+    ScopedComPtr<IShellLink>* i_shell_link,
+    ScopedComPtr<IPersistFile>* i_persist_file) {
+  i_shell_link->Release();
+  i_persist_file->Release();
+  if (FAILED(i_shell_link->CreateInstance(CLSID_ShellLink, NULL,
+                                          CLSCTX_INPROC_SERVER)) ||
+      FAILED(i_persist_file->QueryFrom(*i_shell_link)) ||
+      (shortcut && FAILED((*i_persist_file)->Load(shortcut, STGM_READWRITE)))) {
+    i_shell_link->Release();
+    i_persist_file->Release();
+  }
+}
+
+}  // namespace
+
+bool CreateOrUpdateShortcutLink(const FilePath& shortcut_path,
+                                const ShortcutProperties& properties,
+                                ShortcutOperation operation) {
+  base::ThreadRestrictions::AssertIOAllowed();
+
+  // A target is required unless |operation| is SHORTCUT_UPDATE_EXISTING.
+  if (operation != SHORTCUT_UPDATE_EXISTING &&
+      !(properties.options & ShortcutProperties::PROPERTIES_TARGET)) {
+    NOTREACHED();
+    return false;
+  }
+
+  bool shortcut_existed = file_util::PathExists(shortcut_path);
+
+  ScopedComPtr<IShellLink> i_shell_link;
+  ScopedComPtr<IPersistFile> i_persist_file;
+  switch (operation) {
+    case SHORTCUT_CREATE_ALWAYS:
+      InitializeShortcutInterfaces(NULL, &i_shell_link, &i_persist_file);
+      break;
+    case SHORTCUT_UPDATE_EXISTING:
+      InitializeShortcutInterfaces(shortcut_path.value().c_str(), &i_shell_link,
+                                   &i_persist_file);
+      break;
+    case SHORTCUT_REPLACE_EXISTING:
+      InitializeShortcutInterfaces(shortcut_path.value().c_str(), &i_shell_link,
+                                   &i_persist_file);
+      // Confirm |shortcut_path| exists and is a shortcut by verifying
+      // |i_persist_file| was successfully initialized in the call above. If so,
+      // re-initialize the interfaces to begin writing a new shortcut (to
+      // overwrite the current one if successful).
+      if (i_persist_file.get())
+        InitializeShortcutInterfaces(NULL, &i_shell_link, &i_persist_file);
+      break;
+    default:
+      NOTREACHED();
+      // Leaves the interfaces unset, so the check below returns false.
+  }
+
+  // Return false immediately upon failure to initialize shortcut interfaces.
+  if (!i_persist_file.get())
+    return false;
+
+  if ((properties.options & ShortcutProperties::PROPERTIES_TARGET) &&
+      FAILED(i_shell_link->SetPath(properties.target.value().c_str()))) {
+    return false;
+  }
+
+  if ((properties.options & ShortcutProperties::PROPERTIES_WORKING_DIR) &&
+      FAILED(i_shell_link->SetWorkingDirectory(
+          properties.working_dir.value().c_str()))) {
+    return false;
+  }
+
+  if ((properties.options & ShortcutProperties::PROPERTIES_ARGUMENTS) &&
+      FAILED(i_shell_link->SetArguments(properties.arguments.c_str()))) {
+    return false;
+  }
+
+  if ((properties.options & ShortcutProperties::PROPERTIES_DESCRIPTION) &&
+      FAILED(i_shell_link->SetDescription(properties.description.c_str()))) {
+    return false;
+  }
+
+  if ((properties.options & ShortcutProperties::PROPERTIES_ICON) &&
+      FAILED(i_shell_link->SetIconLocation(properties.icon.value().c_str(),
+                                           properties.icon_index))) {
+    return false;
+  }
+
+  bool has_app_id =
+      (properties.options & ShortcutProperties::PROPERTIES_APP_ID) != 0;
+  bool has_dual_mode =
+      (properties.options & ShortcutProperties::PROPERTIES_DUAL_MODE) != 0;
+  // The app id and dual mode properties are stored in the link's
+  // IPropertyStore; they are silently skipped before Windows 7.
+  if ((has_app_id || has_dual_mode) &&
+      GetVersion() >= VERSION_WIN7) {
+    ScopedComPtr<IPropertyStore> property_store;
+    if (FAILED(property_store.QueryFrom(i_shell_link)) || !property_store.get())
+      return false;
+
+    if (has_app_id &&
+        !SetAppIdForPropertyStore(property_store, properties.app_id.c_str())) {
+      return false;
+    }
+    if (has_dual_mode &&
+        !SetBooleanValueForPropertyStore(property_store,
+                                         PKEY_AppUserModel_IsDualMode,
+                                         properties.dual_mode)) {
+      return false;
+    }
+  }
+
+  HRESULT result = i_persist_file->Save(shortcut_path.value().c_str(), TRUE);
+
+  // Release the interfaces in case the SHChangeNotify call below depends on
+  // the operations above being fully completed.
+  i_persist_file.Release();
+  i_shell_link.Release();
+
+  // If we successfully created/updated the shortcut, notify the shell that we
+  // have done so.
+  const bool succeeded = SUCCEEDED(result);
+  if (succeeded) {
+    if (shortcut_existed) {
+      // TODO(gab): SHCNE_UPDATEITEM might be sufficient here; further testing
+      // required.
+      SHChangeNotify(SHCNE_ASSOCCHANGED, SHCNF_IDLIST, NULL, NULL);
+    } else {
+      SHChangeNotify(SHCNE_CREATE, SHCNF_PATH, shortcut_path.value().c_str(),
+                     NULL);
+    }
+  }
+
+  return succeeded;
+}
+
+bool ResolveShortcut(const FilePath& shortcut_path,
+                     FilePath* target_path,
+                     string16* args) {
+  base::ThreadRestrictions::AssertIOAllowed();
+
+  HRESULT result;
+  ScopedComPtr<IShellLink> i_shell_link;
+
+  // Get pointer to the IShellLink interface.
+  result = i_shell_link.CreateInstance(CLSID_ShellLink, NULL,
+                                       CLSCTX_INPROC_SERVER);
+  if (FAILED(result))
+    return false;
+
+  ScopedComPtr<IPersistFile> persist;
+  // Query IShellLink for the IPersistFile interface.
+  result = persist.QueryFrom(i_shell_link);
+  if (FAILED(result))
+    return false;
+
+  // Load the shell link.
+  result = persist->Load(shortcut_path.value().c_str(), STGM_READ);
+  if (FAILED(result))
+    return false;
+
+  WCHAR temp[MAX_PATH];
+  if (target_path) {
+    // Try to find the target of a shortcut.
+    result = i_shell_link->Resolve(0, SLR_NO_UI);
+    if (FAILED(result))
+      return false;
+
+    result = i_shell_link->GetPath(temp, MAX_PATH, NULL, SLGP_UNCPRIORITY);
+    if (FAILED(result))
+      return false;
+
+    *target_path = FilePath(temp);
+  }
+
+  if (args) {
+    result = i_shell_link->GetArguments(temp, MAX_PATH);
+    if (FAILED(result))
+      return false;
+
+    *args = string16(temp);
+  }
+  return true;
+}
+
+bool TaskbarPinShortcutLink(const wchar_t* shortcut) {
+  base::ThreadRestrictions::AssertIOAllowed();
+
+  // "Pin to taskbar" is only supported on Win7 and later.
+  if (GetVersion() < VERSION_WIN7)
+    return false;
+
+  // ShellExecute() returns an HINSTANCE only for backward compatibility. It
+  // must be cast to INT_PTR (an int would truncate it on 64-bit builds), and
+  // values greater than 32 indicate success.
+  INT_PTR result = reinterpret_cast<INT_PTR>(
+      ShellExecute(NULL, L"taskbarpin", shortcut, NULL, NULL, 0));
+  return result > 32;
+}
+
+bool TaskbarUnpinShortcutLink(const wchar_t* shortcut) {
+  base::ThreadRestrictions::AssertIOAllowed();
+
+  // "Unpin from taskbar" is only supported on Win7 and later.
+  if (GetVersion() < VERSION_WIN7)
+    return false;
+
+  // See TaskbarPinShortcutLink() for why the result is cast to INT_PTR.
+  INT_PTR result = reinterpret_cast<INT_PTR>(
+      ShellExecute(NULL, L"taskbarunpin", shortcut, NULL, NULL, 0));
+  return result > 32;
+}
+
+}  // namespace win
+}  // namespace base
diff --git a/base/win/shortcut.h b/base/win/shortcut.h
new file mode 100644
index 0000000..c1e7d5c
--- /dev/null
+++ b/base/win/shortcut.h
@@ -0,0 +1,147 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_SHORTCUT_H_
+#define BASE_WIN_SHORTCUT_H_
+
+#include <windows.h>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+#include "base/file_path.h"
+#include "base/logging.h"
+#include "base/string16.h"
+
+namespace base {
+namespace win {
+
+enum ShortcutOperation {
+  // Create a new shortcut (overwriting if necessary).
+  SHORTCUT_CREATE_ALWAYS = 0,
+  // Overwrite an existing shortcut (fails if the shortcut doesn't exist).
+  SHORTCUT_REPLACE_EXISTING,
+  // Update specified properties only on an existing shortcut.
+  SHORTCUT_UPDATE_EXISTING,
+};
+
+// Properties for shortcuts. Properties set will be applied to the shortcut on
+// creation/update, others will be ignored.
+// Callers are encouraged to use the setters provided which take care of
+// setting |options| as desired.
+struct ShortcutProperties {
+  enum IndividualProperties {
+    PROPERTIES_TARGET = 1 << 0,
+    PROPERTIES_WORKING_DIR = 1 << 1,
+    PROPERTIES_ARGUMENTS = 1 << 2,
+    PROPERTIES_DESCRIPTION = 1 << 3,
+    PROPERTIES_ICON = 1 << 4,
+    PROPERTIES_APP_ID = 1 << 5,
+    PROPERTIES_DUAL_MODE = 1 << 6,
+  };
+
+  ShortcutProperties() : icon_index(-1), dual_mode(false), options(0U) {}
+
+  void set_target(const FilePath& target_in) {
+    target = target_in;
+    options |= PROPERTIES_TARGET;
+  }
+
+  void set_working_dir(const FilePath& working_dir_in) {
+    working_dir = working_dir_in;
+    options |= PROPERTIES_WORKING_DIR;
+  }
+
+  void set_arguments(const string16& arguments_in) {
+    // Size restriction as per MSDN at http://goo.gl/TJ7q5.
+    DCHECK(arguments_in.size() < MAX_PATH);
+    arguments = arguments_in;
+    options |= PROPERTIES_ARGUMENTS;
+  }
+
+  void set_description(const string16& description_in) {
+    // Size restriction as per MSDN at http://goo.gl/OdNQq.
+    DCHECK(description_in.size() < MAX_PATH);
+    description = description_in;
+    options |= PROPERTIES_DESCRIPTION;
+  }
+
+  void set_icon(const FilePath& icon_in, int icon_index_in) {
+    icon = icon_in;
+    icon_index = icon_index_in;
+    options |= PROPERTIES_ICON;
+  }
+
+  void set_app_id(const string16& app_id_in) {
+    app_id = app_id_in;
+    options |= PROPERTIES_APP_ID;
+  }
+
+  void set_dual_mode(bool dual_mode_in) {
+    dual_mode = dual_mode_in;
+    options |= PROPERTIES_DUAL_MODE;
+  }
+
+  // The target to launch from this shortcut. This is mandatory when creating
+  // a shortcut.
+  FilePath target;
+  // The name of the working directory when launching the shortcut.
+  FilePath working_dir;
+  // The arguments to be applied to |target| when launching from this shortcut.
+  // The length of this string must be less than MAX_PATH.
+  string16 arguments;
+  // The localized description of the shortcut.
+  // The length of this string must be less than MAX_PATH.
+  string16 description;
+  // The path to the icon (can be a dll or exe, in which case |icon_index| is
+  // the resource id).
+  FilePath icon;
+  int icon_index;
+  // The app model id for the shortcut (Win7+).
+  string16 app_id;
+  // Whether this is a dual mode shortcut (Win8+).
+  bool dual_mode;
+  // Bitfield made of IndividualProperties. Properties set in |options| will be
+  // set on the shortcut, others will be ignored.
+  uint32 options;
+};
+
+// This method creates (or updates) a shortcut link at |shortcut_path| using the
+// information given through |properties|.
+// Ensure you have initialized COM before calling into this function.
+// |operation|: a choice from the ShortcutOperation enum.
+// If |operation| is SHORTCUT_REPLACE_EXISTING or SHORTCUT_UPDATE_EXISTING and
+// |shortcut_path| does not exist, this method is a no-op and returns false.
+// A target must be set in |properties| unless |operation| is
+// SHORTCUT_UPDATE_EXISTING. The app id and dual mode properties are only
+// applied on Windows 7 and later. Returns true if the shortcut was saved.
+BASE_EXPORT bool CreateOrUpdateShortcutLink(
+    const FilePath& shortcut_path,
+    const ShortcutProperties& properties,
+    ShortcutOperation operation);
+
+// Resolve Windows shortcut (.LNK file)
+// This methods tries to resolve a shortcut .LNK file. The path of the shortcut
+// to resolve is in |shortcut_path|. If |target_path| is not NULL, the target
+// will be resolved and placed in |target_path|. If |args| is not NULL, the
+// arguments will be retrieved and placed in |args|. The function returns true
+// if all requested fields are found successfully.
+// Callers can safely use the same variable for both |shortcut_path| and
+// |target_path|.
+BASE_EXPORT bool ResolveShortcut(const FilePath& shortcut_path,
+                                 FilePath* target_path,
+                                 string16* args);
+
+// Pins a shortcut to the taskbar on Windows 7 and later; returns false on
+// earlier versions. The shortcut file must already exist and be a shortcut
+// that points to an executable.
+BASE_EXPORT bool TaskbarPinShortcutLink(const wchar_t* shortcut);
+
+// Unpins a shortcut from the taskbar on Windows 7 and later; returns false on
+// earlier versions. The shortcut must exist and already be pinned.
+BASE_EXPORT bool TaskbarUnpinShortcutLink(const wchar_t* shortcut);
+
+}  // namespace win
+}  // namespace base
+
+#endif  // BASE_WIN_SHORTCUT_H_
diff --git a/base/win/shortcut_unittest.cc b/base/win/shortcut_unittest.cc
new file mode 100644
index 0000000..e13bc28
--- /dev/null
+++ b/base/win/shortcut_unittest.cc
@@ -0,0 +1,250 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/shortcut.h"
+
+#include <string>
+
+#include "base/file_path.h"
+#include "base/file_util.h"
+#include "base/scoped_temp_dir.h"
+#include "base/test/test_file_util.h"
+#include "base/test/test_shortcut_win.h"
+#include "base/win/scoped_com_initializer.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace {
+
+const char kFileContents[] = "This is a target.";
+const char kFileContents2[] = "This is another target.";
+
+class ShortcutTest : public testing::Test {
+ protected:
+  virtual void SetUp() OVERRIDE {
+    ASSERT_TRUE(temp_dir_.CreateUniqueTempDir());
+    ASSERT_TRUE(temp_dir_2_.CreateUniqueTempDir());
+
+    link_file_ = temp_dir_.path().Append(L"My Link.lnk");
+
+    // Shortcut 1's properties (setup failures below are fatal).
+    {
+      const FilePath target_file(temp_dir_.path().Append(L"Target 1.txt"));
+      ASSERT_EQ(static_cast<int>(arraysize(kFileContents)),
+                file_util::WriteFile(target_file, kFileContents,
+                                     arraysize(kFileContents)));
+      link_properties_.set_target(target_file);
+      link_properties_.set_working_dir(temp_dir_.path());
+      link_properties_.set_arguments(L"--magic --awesome");
+      link_properties_.set_description(L"Chrome is awesome.");
+      link_properties_.set_icon(link_properties_.target, 4);
+      link_properties_.set_app_id(L"Chrome");
+      link_properties_.set_dual_mode(false);
+    }
+
+    // Shortcut 2's properties (all different from properties of shortcut 1).
+    {
+      const FilePath target_file_2(temp_dir_.path().Append(L"Target 2.txt"));
+      ASSERT_EQ(static_cast<int>(arraysize(kFileContents2)),
+                file_util::WriteFile(target_file_2, kFileContents2,
+                                     arraysize(kFileContents2)));
+      FilePath icon_path_2;
+      ASSERT_TRUE(file_util::CreateTemporaryFileInDir(temp_dir_.path(),
+                                                      &icon_path_2));
+      link_properties_2_.set_target(target_file_2);
+      link_properties_2_.set_working_dir(temp_dir_2_.path());
+      link_properties_2_.set_arguments(L"--super --crazy");
+      link_properties_2_.set_description(L"The best in the west.");
+      link_properties_2_.set_icon(icon_path_2, 0);
+      link_properties_2_.set_app_id(L"Chrome.UserLevelCrazySuffix");
+      link_properties_2_.set_dual_mode(true);
+    }
+  }
+
+  base::win::ScopedCOMInitializer com_initializer_;
+  ScopedTempDir temp_dir_;
+  ScopedTempDir temp_dir_2_;
+
+  // The link file to be created/updated in the shortcut tests below.
+  FilePath link_file_;
+
+  // Properties for the created shortcut.
+  base::win::ShortcutProperties link_properties_;
+
+  // Properties for the updated shortcut.
+  base::win::ShortcutProperties link_properties_2_;
+};
+
+}  // namespace
+
+TEST_F(ShortcutTest, CreateAndResolveShortcut) {
+  base::win::ShortcutProperties only_target_properties;
+  only_target_properties.set_target(link_properties_.target);
+
+  ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+      link_file_, only_target_properties, base::win::SHORTCUT_CREATE_ALWAYS));
+
+  FilePath resolved_name;
+  ASSERT_TRUE(base::win::ResolveShortcut(link_file_, &resolved_name, NULL));
+  // Zero-filled so EXPECT_STREQ stays in bounds if ReadFile() fails.
+  char read_contents[arraysize(kFileContents)] = {0};
+  file_util::ReadFile(resolved_name, read_contents, arraysize(read_contents));
+  EXPECT_STREQ(kFileContents, read_contents);
+}
+
+TEST_F(ShortcutTest, ResolveShortcutWithArgs) {
+  ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+      link_file_, link_properties_, base::win::SHORTCUT_CREATE_ALWAYS));
+
+  FilePath resolved_name;
+  string16 args;
+  ASSERT_TRUE(base::win::ResolveShortcut(link_file_, &resolved_name, &args));
+
+  char read_contents[arraysize(kFileContents)] = {0};  // Safe if read fails.
+  file_util::ReadFile(resolved_name, read_contents, arraysize(read_contents));
+  EXPECT_STREQ(kFileContents, read_contents);
+  EXPECT_EQ(link_properties_.arguments, args);
+}
+
+TEST_F(ShortcutTest, CreateShortcutWithOnlySomeProperties) {
+ base::win::ShortcutProperties target_and_args_properties;
+ target_and_args_properties.set_target(link_properties_.target);
+ target_and_args_properties.set_arguments(link_properties_.arguments);
+ // Create a shortcut from properties that set only target and arguments.
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, target_and_args_properties,
+ base::win::SHORTCUT_CREATE_ALWAYS));
+ // Check the shortcut against the partially-set properties above.
+ base::win::ValidateShortcut(link_file_, target_and_args_properties);
+}
+
+TEST_F(ShortcutTest, CreateShortcutVerifyProperties) {
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_, base::win::SHORTCUT_CREATE_ALWAYS));
+ // Compare the new shortcut against every property in |link_properties_|.
+ base::win::ValidateShortcut(link_file_, link_properties_);
+}
+
+TEST_F(ShortcutTest, UpdateShortcutVerifyProperties) {
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_, base::win::SHORTCUT_CREATE_ALWAYS));
+ // Update in place using |link_properties_2_|, which sets every property.
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_2_, base::win::SHORTCUT_UPDATE_EXISTING));
+ // The shortcut should now match |link_properties_2_|.
+ base::win::ValidateShortcut(link_file_, link_properties_2_);
+}
+
+TEST_F(ShortcutTest, UpdateShortcutUpdateOnlyTargetAndResolve) {
+  ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+      link_file_, link_properties_, base::win::SHORTCUT_CREATE_ALWAYS));
+
+  base::win::ShortcutProperties update_only_target_properties;
+  update_only_target_properties.set_target(link_properties_2_.target);
+
+  ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+      link_file_, update_only_target_properties,
+      base::win::SHORTCUT_UPDATE_EXISTING));
+
+  base::win::ShortcutProperties expected_properties = link_properties_;
+  expected_properties.set_target(link_properties_2_.target);  // Only change.
+  base::win::ValidateShortcut(link_file_, expected_properties);
+
+  FilePath resolved_name;
+  ASSERT_TRUE(base::win::ResolveShortcut(link_file_, &resolved_name, NULL));
+
+  char read_contents[arraysize(kFileContents2)] = {0};  // Safe if read fails.
+  file_util::ReadFile(resolved_name, read_contents, arraysize(read_contents));
+  EXPECT_STREQ(kFileContents2, read_contents);
+}
+
+TEST_F(ShortcutTest, UpdateShortcutMakeDualMode) {
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_, base::win::SHORTCUT_CREATE_ALWAYS));
+
+ base::win::ShortcutProperties make_dual_mode_properties;
+ make_dual_mode_properties.set_dual_mode(true);
+ // Update only the dual-mode flag, turning it on.
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, make_dual_mode_properties,
+ base::win::SHORTCUT_UPDATE_EXISTING));
+ // Apart from dual mode, |link_properties_| should be unchanged.
+ base::win::ShortcutProperties expected_properties = link_properties_;
+ expected_properties.set_dual_mode(true);
+ base::win::ValidateShortcut(link_file_, expected_properties);
+}
+
+TEST_F(ShortcutTest, UpdateShortcutRemoveDualMode) {
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_2_, base::win::SHORTCUT_CREATE_ALWAYS));
+
+ base::win::ShortcutProperties remove_dual_mode_properties;
+ remove_dual_mode_properties.set_dual_mode(false);
+ // Update only the dual-mode flag, turning it off.
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, remove_dual_mode_properties,
+ base::win::SHORTCUT_UPDATE_EXISTING));
+ // Apart from dual mode, |link_properties_2_| should be unchanged.
+ base::win::ShortcutProperties expected_properties = link_properties_2_;
+ expected_properties.set_dual_mode(false);
+ base::win::ValidateShortcut(link_file_, expected_properties);
+}
+
+TEST_F(ShortcutTest, UpdateShortcutClearArguments) {
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_, base::win::SHORTCUT_CREATE_ALWAYS));
+
+ base::win::ShortcutProperties clear_arguments_properties;
+ clear_arguments_properties.set_arguments(string16());
+ // Update only the arguments, setting them to an empty string.
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, clear_arguments_properties,
+ base::win::SHORTCUT_UPDATE_EXISTING));
+ // Only the arguments should differ from |link_properties_|.
+ base::win::ShortcutProperties expected_properties = link_properties_;
+ expected_properties.set_arguments(string16());
+ base::win::ValidateShortcut(link_file_, expected_properties);
+}
+
+TEST_F(ShortcutTest, FailUpdateShortcutThatDoesNotExist) {
+ ASSERT_FALSE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_, base::win::SHORTCUT_UPDATE_EXISTING));
+ ASSERT_FALSE(file_util::PathExists(link_file_));  // Nothing was created.
+}
+
+TEST_F(ShortcutTest, TruncateShortcutAllProperties) {
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_, base::win::SHORTCUT_CREATE_ALWAYS));
+
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_2_, base::win::SHORTCUT_REPLACE_EXISTING));
+ // After the replace, the shortcut should match |link_properties_2_|.
+ base::win::ValidateShortcut(link_file_, link_properties_2_);
+}
+
+TEST_F(ShortcutTest, TruncateShortcutSomeProperties) {
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_, base::win::SHORTCUT_CREATE_ALWAYS));
+ // Replace the shortcut using only a new target and description.
+ base::win::ShortcutProperties new_properties;
+ new_properties.set_target(link_properties_2_.target);
+ new_properties.set_description(link_properties_2_.description);
+ ASSERT_TRUE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, new_properties, base::win::SHORTCUT_REPLACE_EXISTING));
+
+ // Expect only properties in |new_properties| to be set, all other properties
+ // should have been overwritten.
+ base::win::ShortcutProperties expected_properties = new_properties;
+ expected_properties.set_working_dir(FilePath());
+ expected_properties.set_arguments(string16());
+ expected_properties.set_icon(FilePath(), 0);
+ expected_properties.set_app_id(string16());
+ expected_properties.set_dual_mode(false);
+ base::win::ValidateShortcut(link_file_, expected_properties);
+}
+
+TEST_F(ShortcutTest, FailTruncateShortcutThatDoesNotExist) {
+ ASSERT_FALSE(base::win::CreateOrUpdateShortcutLink(
+ link_file_, link_properties_, base::win::SHORTCUT_REPLACE_EXISTING));
+ ASSERT_FALSE(file_util::PathExists(link_file_));  // Nothing was created.
+}
diff --git a/base/win/startup_information.cc b/base/win/startup_information.cc
new file mode 100644
index 0000000..aff52eb
--- /dev/null
+++ b/base/win/startup_information.cc
@@ -0,0 +1,109 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/startup_information.h"
+
+#include "base/logging.h"
+#include "base/win/windows_version.h"
+
+namespace {
+
+// kernel32 attribute-list exports used for STARTUPINFOEX. They are looked up
+// at runtime by the StartupInformation constructor (presumably so the binary
+// still loads on Windows versions that lack them).
+typedef BOOL (WINAPI *InitializeProcThreadAttributeListFunction)(
+    LPPROC_THREAD_ATTRIBUTE_LIST attribute_list,
+    DWORD attribute_count,
+    DWORD flags,
+    PSIZE_T size);
+InitializeProcThreadAttributeListFunction
+    initialize_proc_thread_attribute_list = NULL;
+
+typedef BOOL (WINAPI *UpdateProcThreadAttributeFunction)(
+    LPPROC_THREAD_ATTRIBUTE_LIST attribute_list, DWORD flags,
+    DWORD_PTR attribute, PVOID value, SIZE_T size, PVOID previous_value,
+    PSIZE_T return_size);
+UpdateProcThreadAttributeFunction update_proc_thread_attribute_list = NULL;
+
+typedef VOID (WINAPI *DeleteProcThreadAttributeListFunction)(
+    LPPROC_THREAD_ATTRIBUTE_LIST attribute_list);
+DeleteProcThreadAttributeListFunction delete_proc_thread_attribute_list =
+    NULL;
+
+}  // namespace
+
+namespace base {
+namespace win {
+
+StartupInformation::StartupInformation() {
+  memset(&startup_info_, 0, sizeof(startup_info_));
+
+  // STARTUPINFOEX is only used from Vista on; earlier versions get a plain
+  // STARTUPINFO and never resolve the attribute-list functions.
+  if (base::win::GetVersion() < base::win::VERSION_VISTA) {
+    startup_info_.StartupInfo.cb = sizeof(STARTUPINFO);
+    return;
+  }
+  startup_info_.StartupInfo.cb = sizeof(startup_info_);
+
+  // Resolve the kernel32 entry points on first use (and again if any of them
+  // is still missing).
+  if (initialize_proc_thread_attribute_list &&
+      update_proc_thread_attribute_list && delete_proc_thread_attribute_list)
+    return;
+  HMODULE kernel32 = ::GetModuleHandleW(L"kernel32.dll");
+  initialize_proc_thread_attribute_list =
+      reinterpret_cast<InitializeProcThreadAttributeListFunction>(
+          ::GetProcAddress(kernel32, "InitializeProcThreadAttributeList"));
+  update_proc_thread_attribute_list =
+      reinterpret_cast<UpdateProcThreadAttributeFunction>(
+          ::GetProcAddress(kernel32, "UpdateProcThreadAttribute"));
+  delete_proc_thread_attribute_list =
+      reinterpret_cast<DeleteProcThreadAttributeListFunction>(
+          ::GetProcAddress(kernel32, "DeleteProcThreadAttributeList"));
+}
+
+StartupInformation::~StartupInformation() {
+  if (LPPROC_THREAD_ATTRIBUTE_LIST list = startup_info_.lpAttributeList) {
+    delete_proc_thread_attribute_list(list);  // Tear down before freeing.
+    delete [] reinterpret_cast<BYTE*>(list);
+  }
+}
+
+bool StartupInformation::InitializeProcThreadAttributeList(
+    DWORD attribute_count) {
+  // Require Vista+, no prior list, and every kernel32 function we later call.
+  if (startup_info_.StartupInfo.cb != sizeof(startup_info_) ||
+      startup_info_.lpAttributeList || !initialize_proc_thread_attribute_list ||
+      !update_proc_thread_attribute_list || !delete_proc_thread_attribute_list)
+    return false;
+  // The first call only reports the buffer size the list needs.
+  SIZE_T size = 0;
+  initialize_proc_thread_attribute_list(NULL, attribute_count, 0, &size);
+  if (size == 0)
+    return false;
+  startup_info_.lpAttributeList =
+      reinterpret_cast<LPPROC_THREAD_ATTRIBUTE_LIST>(new BYTE[size]);
+  if (!initialize_proc_thread_attribute_list(startup_info_.lpAttributeList,
+                                             attribute_count, 0, &size)) {
+    delete [] reinterpret_cast<BYTE*>(startup_info_.lpAttributeList);
+    startup_info_.lpAttributeList = NULL;
+    return false;
+  }
+  return true;
+}
+
+bool StartupInformation::UpdateProcThreadAttribute(
+    DWORD_PTR attribute,
+    void* value,
+    size_t size) {
+  // Attributes can only be added after InitializeProcThreadAttributeList().
+  LPPROC_THREAD_ATTRIBUTE_LIST list = startup_info_.lpAttributeList;
+  return list && update_proc_thread_attribute_list(list, 0, attribute, value,
+                                                   size, NULL, NULL) != FALSE;
+}
+
+} // namespace win
+} // namespace base
+
diff --git a/base/win/startup_information.h b/base/win/startup_information.h
new file mode 100644
index 0000000..7cef81f
--- /dev/null
+++ b/base/win/startup_information.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_STARTUP_INFORMATION_H_
+#define BASE_WIN_STARTUP_INFORMATION_H_
+
+#include <windows.h>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+
+namespace base {
+namespace win {
+
+// Manages a STARTUPINFOEXW and the lifetime of its optional attribute list.
+class BASE_EXPORT StartupInformation {
+ public:
+ StartupInformation();
+ // Frees the attribute list if one was successfully initialized.
+ ~StartupInformation();
+ // Allocates the attribute list for |attribute_count| entries. Returns false
+ // before Vista, if a list already exists, or if initialization fails.
+ bool InitializeProcThreadAttributeList(DWORD attribute_count);
+ // Sets one entry in the attribute list (via ::UpdateProcThreadAttribute).
+ // Returns false if no list has been initialized.
+ bool UpdateProcThreadAttribute(DWORD_PTR attribute,
+ void* value,
+ size_t size);
+ // Returns the STARTUPINFOW header, e.g. to pass to CreateProcess().
+ LPSTARTUPINFOW startup_info() { return &startup_info_.StartupInfo; }
+ const LPSTARTUPINFOW startup_info() const {
+ return const_cast<const LPSTARTUPINFOW>(&startup_info_.StartupInfo);
+ }
+ // NOTE(review): the const overload above yields a mutable STARTUPINFOW*.
+ bool has_extended_startup_info() const {  // True once a list exists.
+ return !!startup_info_.lpAttributeList;
+ }
+
+ private:
+ STARTUPINFOEXW startup_info_;
+ DISALLOW_COPY_AND_ASSIGN(StartupInformation);
+};
+
+} // namespace win
+} // namespace base
+
+#endif  // BASE_WIN_STARTUP_INFORMATION_H_
+
diff --git a/base/win/startup_information_unittest.cc b/base/win/startup_information_unittest.cc
new file mode 100644
index 0000000..1903564
--- /dev/null
+++ b/base/win/startup_information_unittest.cc
@@ -0,0 +1,76 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <windows.h>
+
+#include <string>
+
+#include "base/command_line.h"
+#include "base/test/multiprocess_test.h"
+#include "base/win/scoped_handle.h"
+#include "base/win/scoped_process_information.h"
+#include "base/win/startup_information.h"
+#include "base/win/windows_version.h"
+#include "testing/multiprocess_func_list.h"
+
+const wchar_t kSectionName[] = L"EventTestSection";  // Shared mapping name.
+const size_t kSectionSize = 4096;  // Bytes mapped by parent and child.
+
+MULTIPROCESS_TEST_MAIN(FireInheritedEvents) {
+  // The child only reads the handle values, so FILE_MAP_READ access suffices.
+  HANDLE section = ::OpenFileMappingW(FILE_MAP_READ, false, kSectionName);
+  HANDLE* events = reinterpret_cast<HANDLE*>(
+      ::MapViewOfFile(section, FILE_MAP_READ, 0, 0, kSectionSize));
+  if (!section || !events)
+    return -1;
+  // events[1] was not explicitly inherited, so it must be invalid here.
+  if (::SetEvent(events[1]))
+    return -1;
+  // events[0] was explicitly inherited, so signaling it must succeed.
+  return ::SetEvent(events[0]) ? 0 : -1;
+}
+// Fixture providing MultiProcessTest helpers such as MakeCmdLine().
+class StartupInformationTest : public base::MultiProcessTest {};
+
+// Verify that only the explicitly specified event is inherited.
+TEST_F(StartupInformationTest, InheritStdOut) {
+  if (base::win::GetVersion() < base::win::VERSION_VISTA)
+    return;
+
+  base::win::ScopedProcessInformation process_info;
+  base::win::StartupInformation startup_info;
+
+  HANDLE section = ::CreateFileMappingW(INVALID_HANDLE_VALUE, NULL,
+                                        PAGE_READWRITE, 0, kSectionSize,
+                                        kSectionName);
+  ASSERT_TRUE(section);
+  HANDLE* events = reinterpret_cast<HANDLE*>(::MapViewOfFile(section,
+      FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, kSectionSize));
+  ASSERT_TRUE(events);
+
+  // Make two inheritable events.
+  SECURITY_ATTRIBUTES security_attributes = { sizeof(security_attributes),
+                                              NULL, true };
+  events[0] = ::CreateEvent(&security_attributes, false, false, NULL);
+  ASSERT_TRUE(events[0]);
+  events[1] = ::CreateEvent(&security_attributes, false, false, NULL);
+  ASSERT_TRUE(events[1]);
+
+  // Allow only events[0] to be inherited by the child process.
+  ASSERT_TRUE(startup_info.InitializeProcThreadAttributeList(1));
+  ASSERT_TRUE(startup_info.UpdateProcThreadAttribute(
+      PROC_THREAD_ATTRIBUTE_HANDLE_LIST, &events[0], sizeof(events[0])));
+  std::wstring cmd_line =
+      this->MakeCmdLine("FireInheritedEvents", false).GetCommandLineString();
+  ASSERT_TRUE(::CreateProcess(NULL, const_cast<wchar_t*>(cmd_line.c_str()),
+                              NULL, NULL, true, EXTENDED_STARTUPINFO_PRESENT,
+                              NULL, NULL, startup_info.startup_info(),
+                              process_info.Receive())) << ::GetLastError();
+  // Only the first event should be signaled by the child.
+  EXPECT_EQ(WAIT_OBJECT_0, ::WaitForMultipleObjects(2, events, false, 4000));
+  ::CloseHandle(events[0]);
+  ::CloseHandle(events[1]);
+  ::CloseHandle(section);  // |events| stays mapped for |startup_info|.
+}
+
diff --git a/base/win/text_services_message_filter.cc b/base/win/text_services_message_filter.cc
new file mode 100644
index 0000000..7ce233d
--- /dev/null
+++ b/base/win/text_services_message_filter.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/text_services_message_filter.h"
+
+namespace base {
+namespace win {
+
+TextServicesMessageFilter::TextServicesMessageFilter()
+    : client_id_(TF_CLIENTID_NULL), is_initialized_(false) {
+  // The TSF COM objects are created later, by Init().
+}
+
+TextServicesMessageFilter::~TextServicesMessageFilter() {
+  // Deactivate only if Init() activated the thread manager.
+  if (is_initialized_) thread_mgr_->Deactivate();
+}
+
+bool TextServicesMessageFilter::Init() {
+  // Create the TSF thread manager and fetch the message-pump and keystroke
+  // interfaces from it. Any failure leaves the filter uninitialized.
+  if (FAILED(thread_mgr_.CreateInstance(CLSID_TF_ThreadMgr)) ||
+      FAILED(message_pump_.QueryFrom(thread_mgr_)) ||
+      FAILED(keystroke_mgr_.QueryFrom(thread_mgr_))) {
+    return false;
+  }
+  // Activation assigns |client_id_|; the destructor only deactivates once
+  // this has succeeded.
+  if (FAILED(thread_mgr_->Activate(&client_id_)))
+    return false;
+
+  is_initialized_ = true;
+  return true;
+}
+
+// Wraps ITfMessagePump::PeekMessage with the Win32 PeekMessage signature and
+// obtains messages from the application message queue. Falls back to plain
+// ::PeekMessage when Init() never obtained the TSF message pump.
+BOOL TextServicesMessageFilter::DoPeekMessage(MSG* msg, HWND window_handle,
+                                              UINT msg_filter_min,
+                                              UINT msg_filter_max,
+                                              UINT remove_msg) {
+  if (!message_pump_)
+    return ::PeekMessage(msg, window_handle, msg_filter_min, msg_filter_max,
+                         remove_msg);
+  BOOL result = FALSE;
+  if (FAILED(message_pump_->PeekMessage(msg, window_handle, msg_filter_min,
+                                        msg_filter_max, remove_msg, &result)))
+    return FALSE;
+  return result;
+}
+
+// Forwards WM_KEYDOWN/WM_KEYUP to the TSF keystroke manager (used for IME
+// composition). Returns true if the keystroke manager consumed |msg|.
+bool TextServicesMessageFilter::ProcessMessage(const MSG& msg) {
+  if (!keystroke_mgr_)
+    return false;  // Init() never obtained the keystroke manager.
+  if (msg.message == WM_KEYDOWN) {
+    // Deliver the key only if TestKeyDown() says a text service will eat it.
+    BOOL eaten = FALSE;
+    HRESULT hr = keystroke_mgr_->TestKeyDown(msg.wParam, msg.lParam, &eaten);
+    if (FAILED(hr) || !eaten)
+      return false;
+    eaten = FALSE;
+    hr = keystroke_mgr_->KeyDown(msg.wParam, msg.lParam, &eaten);
+    return SUCCEEDED(hr) && !!eaten;
+  }
+  if (msg.message == WM_KEYUP) {
+    BOOL eaten = FALSE;
+    HRESULT hr = keystroke_mgr_->TestKeyUp(msg.wParam, msg.lParam, &eaten);
+    if (FAILED(hr) || !eaten)
+      return false;
+    eaten = FALSE;
+    hr = keystroke_mgr_->KeyUp(msg.wParam, msg.lParam, &eaten);
+    return SUCCEEDED(hr) && !!eaten;
+  }
+  return false;
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/text_services_message_filter.h b/base/win/text_services_message_filter.h
new file mode 100644
index 0000000..facd613
--- /dev/null
+++ b/base/win/text_services_message_filter.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_TEXT_SERVICES_MESSAGE_FILTER_H_
+#define BASE_WIN_TEXT_SERVICES_MESSAGE_FILTER_H_
+
+#include <msctf.h>
+#include <Windows.h>
+
+#include "base/memory/scoped_ptr.h"
+#include "base/message_pump_win.h"
+#include "base/win/metro.h"
+#include "base/win/scoped_comptr.h"
+
+namespace base {
+namespace win {
+
+// TextServicesMessageFilter is a MessagePumpForUI::MessageFilter that routes
+// message peeking and key events through the Text Services Framework (TSF).
+class BASE_EXPORT TextServicesMessageFilter
+ : public base::MessagePumpForUI::MessageFilter {
+ public:
+ TextServicesMessageFilter();
+ virtual ~TextServicesMessageFilter();
+ virtual BOOL DoPeekMessage(MSG* msg,
+ HWND window_handle,
+ UINT msg_filter_min,
+ UINT msg_filter_max,
+ UINT remove_msg) OVERRIDE;
+ virtual bool ProcessMessage(const MSG& msg) OVERRIDE;
+ // Creates and activates the TSF thread manager; returns false on failure.
+ bool Init();
+
+ private:
+ TfClientId client_id_;  // Assigned by ITfThreadMgr::Activate() in Init().
+ bool is_initialized_;  // True once Init() has fully succeeded.
+ base::win::ScopedComPtr<ITfThreadMgr> thread_mgr_;
+ base::win::ScopedComPtr<ITfMessagePump> message_pump_;
+ base::win::ScopedComPtr<ITfKeystrokeMgr> keystroke_mgr_;
+
+ DISALLOW_COPY_AND_ASSIGN(TextServicesMessageFilter);
+};
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_TEXT_SERVICES_MESSAGE_FILTER_H_
diff --git a/base/win/win_util.cc b/base/win/win_util.cc
new file mode 100644
index 0000000..7c106a2
--- /dev/null
+++ b/base/win/win_util.cc
@@ -0,0 +1,227 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/win_util.h"
+
+#include <aclapi.h>
+#include <shobjidl.h> // Must be before propkey.
+#include <initguid.h>
+#include <propkey.h>
+#include <propvarutil.h>
+#include <sddl.h>
+#include <shlobj.h>
+
+#include "base/logging.h"
+#include "base/memory/scoped_ptr.h"
+#include "base/win/registry.h"
+#include "base/string_util.h"
+#include "base/stringprintf.h"
+#include "base/threading/thread_restrictions.h"
+#include "base/win/scoped_handle.h"
+#include "base/win/windows_version.h"
+
+namespace {
+
+// Sets the value of |property_key| to |property_value| in |property_store|
+// and commits the store. Always clears the PropVariant in |property_value|.
+// Returns true only if both SetValue() and Commit() succeeded.
+bool SetPropVariantValueForPropertyStore(
+    IPropertyStore* property_store,
+    const PROPERTYKEY& property_key,
+    PROPVARIANT* property_value) {
+  DCHECK(property_store);
+  HRESULT result = property_store->SetValue(property_key, *property_value);
+  // Commit on any success code, not only S_OK (e.g. INPLACE_S_TRUNCATED).
+  if (SUCCEEDED(result))
+    result = property_store->Commit();
+  PropVariantClear(property_value);
+  return SUCCEEDED(result);
+}
+
+}  // namespace
+
+namespace base {
+namespace win {
+
+static bool g_crash_on_process_detach = false;  // Inspected at DLL detach.
+
+#define NONCLIENTMETRICS_SIZE_PRE_VISTA \
+    SIZEOF_STRUCT_WITH_SPECIFIED_LAST_MEMBER(NONCLIENTMETRICS, lfMessageFont)
+
+// Fills |metrics| via SPI_GETNONCLIENTMETRICS. On pre-Vista systems cbSize
+// covers the struct only up to lfMessageFont. The size is computed once.
+void GetNonClientMetrics(NONCLIENTMETRICS* metrics) {
+  DCHECK(metrics);
+  static const UINT kMetricsSize =
+      (base::win::GetVersion() < base::win::VERSION_VISTA) ?
+      NONCLIENTMETRICS_SIZE_PRE_VISTA : sizeof(NONCLIENTMETRICS);
+  metrics->cbSize = kMetricsSize;
+  const bool success = !!SystemParametersInfo(SPI_GETNONCLIENTMETRICS,
+                                              kMetricsSize, metrics, 0);
+  DCHECK(success);
+}
+
+bool GetUserSidString(std::wstring* user_sid) {
+  // Open the current process token for querying; |token_holder| closes it.
+  HANDLE token = NULL;
+  if (!::OpenProcessToken(::GetCurrentProcess(), TOKEN_QUERY, &token))
+    return false;
+  base::win::ScopedHandle token_holder(token);
+
+  // Room for TOKEN_USER plus the largest possible SID.
+  DWORD buffer_size = sizeof(TOKEN_USER) + SECURITY_MAX_SID_SIZE;
+  scoped_array<BYTE> buffer(new BYTE[buffer_size]);
+  TOKEN_USER* token_user = reinterpret_cast<TOKEN_USER*>(buffer.get());
+  if (!::GetTokenInformation(token, TokenUser, token_user, buffer_size,
+                             &buffer_size)) {
+    return false;
+  }
+  if (!token_user->User.Sid)
+    return false;
+
+  // Convert the SID to its string form. The returned string is released
+  // with LocalFree() once it has been copied into |user_sid|.
+  wchar_t* sid_text = NULL;
+  if (!::ConvertSidToStringSid(token_user->User.Sid, &sid_text))
+    return false;
+
+  user_sid->assign(sid_text);
+  ::LocalFree(sid_text);
+  return true;
+}
+
+bool IsShiftPressed() {
+  return (::GetKeyState(VK_SHIFT) & 0x8000) != 0;  // High bit set: key down.
+}
+
+bool IsCtrlPressed() {
+  return (::GetKeyState(VK_CONTROL) & 0x8000) != 0;  // High bit set: key down.
+}
+
+bool IsAltPressed() {
+  return (::GetKeyState(VK_MENU) & 0x8000) != 0;  // High bit set: key down.
+}
+
+bool UserAccountControlIsEnabled() {
+  // This can be slow if Windows ends up going to disk. Should watch this key
+  // for changes and only read it once, preferably on the file thread.
+  // http://code.google.com/p/chromium/issues/detail?id=61644
+  base::ThreadRestrictions::ScopedAllowIO allow_io;
+  base::win::RegKey key(
+      HKEY_LOCAL_MACHINE,
+      L"SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Policies\\System",
+      KEY_READ);
+  DWORD enable_lua = 0;
+  // A missing or unreadable EnableLUA value counts as "UAC enabled". Users
+  // can set the value to something arbitrary, like 2, which Vista treats as
+  // enabled too, so only an explicit 0 disables UAC.
+  return key.ReadValueDW(L"EnableLUA", &enable_lua) != ERROR_SUCCESS ||
+         enable_lua != 0;
+}
+
+bool SetBooleanValueForPropertyStore(IPropertyStore* property_store,
+                                     const PROPERTYKEY& property_key,
+                                     bool property_bool_value) {
+  // Box the bool in a PROPVARIANT; the helper stores it and then clears it.
+  PROPVARIANT variant;
+  const HRESULT hr = InitPropVariantFromBoolean(property_bool_value, &variant);
+  if (FAILED(hr))
+    return false;
+  return SetPropVariantValueForPropertyStore(property_store, property_key,
+                                             &variant);
+}
+
+bool SetStringValueForPropertyStore(IPropertyStore* property_store,
+                                    const PROPERTYKEY& property_key,
+                                    const wchar_t* property_string_value) {
+  // Box the string in a PROPVARIANT; the helper stores it and then clears it.
+  PROPVARIANT variant;
+  const HRESULT hr = InitPropVariantFromString(property_string_value, &variant);
+  if (FAILED(hr))
+    return false;
+  return SetPropVariantValueForPropertyStore(property_store, property_key,
+                                             &variant);
+}
+
+bool SetAppIdForPropertyStore(IPropertyStore* property_store,
+                              const wchar_t* app_id) {
+  // An app id should be shorter than 64 chars and contain no spaces; the
+  // recommended format is CompanyName.ProductName[.SubProduct.ProductNumber].
+  // See http://msdn.microsoft.com/en-us/library/dd378459%28VS.85%29.aspx
+  DCHECK(lstrlen(app_id) < 64 && wcschr(app_id, L' ') == NULL);
+
+  // Store it under the AppUserModel ID property key.
+  return SetStringValueForPropertyStore(property_store, PKEY_AppUserModel_ID,
+                                        app_id);
+}
+
+static const char16 kAutoRunKeyPath[] =  // Relative to the given root key.
+ L"Software\\Microsoft\\Windows\\CurrentVersion\\Run";
+
+bool AddCommandToAutoRun(HKEY root_key, const string16& name,
+                         const string16& command) {
+  // Store |command| as the string value |name| under the Run key.
+  base::win::RegKey run_key(root_key, kAutoRunKeyPath, KEY_SET_VALUE);
+  return run_key.WriteValue(name.c_str(), command.c_str()) == ERROR_SUCCESS;
+}
+
+bool RemoveCommandFromAutoRun(HKEY root_key, const string16& name) {
+  base::win::RegKey run_key(root_key, kAutoRunKeyPath, KEY_SET_VALUE);
+  return run_key.DeleteValue(name.c_str()) == ERROR_SUCCESS;
+}
+
+bool ReadCommandFromAutoRun(HKEY root_key, const string16& name,
+                            string16* command) {
+  // Query-only access suffices for reading the value.
+  base::win::RegKey run_key(root_key, kAutoRunKeyPath, KEY_QUERY_VALUE);
+  return run_key.ReadValue(name.c_str(), command) == ERROR_SUCCESS;
+}
+
+void SetShouldCrashOnProcessDetach(bool crash) {
+ g_crash_on_process_detach = crash;  // Inspected at DLL process detach.
+}
+// Returns the value last set by SetShouldCrashOnProcessDetach() (or false).
+bool ShouldCrashOnProcessDetach() {
+ return g_crash_on_process_detach;
+}
+
+bool IsMachineATablet() {
+  if (base::win::GetVersion() < base::win::VERSION_WIN7)
+    return false;  // SM_DIGITIZER is only consulted on Windows 7+.
+  const int kMultiTouch = NID_INTEGRATED_TOUCH | NID_MULTI_INPUT | NID_READY;
+  const int kMaxTabletScreenWidth = 1366;
+  const int kMaxTabletScreenHeight = 768;
+  // Require an integrated, ready, multi-input touch digitizer...
+  const int digitizer = GetSystemMetrics(SM_DIGITIZER);
+  if ((digitizer & kMultiTouch) != kMultiTouch)
+    return false;
+  // ...and a screen no larger than the maximum tablet resolution.
+  return GetSystemMetrics(SM_CXSCREEN) <= kMaxTabletScreenWidth &&
+         GetSystemMetrics(SM_CYSCREEN) <= kMaxTabletScreenHeight;
+}
+
+} // namespace win
+} // namespace base
+
+#ifdef _MSC_VER
+//
+// If the COMPILE_ASSERT below fails, please install Visual Studio 2005
+// Service Pack 1.
+extern char VisualStudio2005ServicePack1Detection[10];
+COMPILE_ASSERT(sizeof(&VisualStudio2005ServicePack1Detection) == sizeof(void*),
+ VS2005SP1Detect);
+//
+// Chrome requires at least Service Pack 1 for Visual Studio 2005.
+//
+#endif // _MSC_VER
+// Fail the build if the Windows SDK headers lack COPY_FILE_COPY_SYMLINK.
+#ifndef COPY_FILE_COPY_SYMLINK
+#error You must install the Windows 2008 or Vista Software Development Kit and \
+set it as your default include path to build this library. You can grab it by \
+searching for "download windows sdk 2008" in your favorite web search engine. \
+Also make sure you register the SDK with Visual Studio, by selecting \
+"Integrate Windows SDK with Visual Studio 2005" from the Windows SDK \
+menu (see Start - All Programs - Microsoft Windows SDK - \
+Visual Studio Registration).
+#endif
diff --git a/base/win/win_util.h b/base/win/win_util.h
new file mode 100644
index 0000000..7db98e9
--- /dev/null
+++ b/base/win/win_util.h
@@ -0,0 +1,123 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// =============================================================================
+// PLEASE READ
+//
+// In general, you should not be adding stuff to this file.
+//
+// - If your thing is only used in one place, just put it in a reasonable
+// location in or near that one place. It's nice you want people to be able
+// to re-use your function, but realistically, if it hasn't been needed
+// after so many years of development, it's probably not going to be used
+// in other places in the future unless you know of them now.
+//
+// - If your thing is used by multiple callers and is UI-related, it should
+// probably be in app/win/ instead. Try to put it in the most specific file
+// possible (avoiding the *_util files when practical).
+//
+// =============================================================================
+
+#ifndef BASE_WIN_WIN_UTIL_H_
+#define BASE_WIN_WIN_UTIL_H_
+
+#include <windows.h>
+
+#include <string>
+
+#include "base/base_export.h"
+#include "base/string16.h"
+
+struct IPropertyStore;
+struct _tagpropertykey;
+typedef _tagpropertykey PROPERTYKEY;
+
+namespace base {
+namespace win {
+
+// A Windows message reflected from other windows. This message is sent
+// with the following arguments:
+//   hWnd - Target window
+//   uMsg - kReflectedMessage
+//   wParam - Should be 0
+//   lParam - Pointer to MSG struct containing the original message.
+const int kReflectedMessage = WM_APP + 3;
+
+// Fills |metrics| with the system's non-client metrics
+// (SPI_GETNONCLIENTMETRICS), using a cbSize suited to the running OS.
+BASE_EXPORT void GetNonClientMetrics(NONCLIENTMETRICS* metrics);
+
+// Stores the string form of the current user's SID in |user_sid|.
+// Returns false if the SID could not be retrieved.
+BASE_EXPORT bool GetUserSidString(std::wstring* user_sid);
+
+// Return true if the shift / ctrl / alt key, respectively, is pressed now.
+BASE_EXPORT bool IsShiftPressed();
+BASE_EXPORT bool IsCtrlPressed();
+BASE_EXPORT bool IsAltPressed();
+
+// Returns false if user account control (UAC) has been disabled with the
+// EnableLUA registry flag; returns true otherwise, including when the flag
+// is absent or cannot be read. NOTE: Windows XP ignores EnableLUA, but the
+// value might still exist there and be set to 0 (UAC disabled), in which case
+// this function will return false. You should therefore check this flag only
+// if the OS is Vista or later.
+BASE_EXPORT bool UserAccountControlIsEnabled();
+
+// Sets and commits the boolean value of |property_key| in |property_store|.
+BASE_EXPORT bool SetBooleanValueForPropertyStore(
+    IPropertyStore* property_store,
+    const PROPERTYKEY& property_key,
+    bool property_bool_value);
+
+// Sets and commits the string value of |property_key| in |property_store|.
+BASE_EXPORT bool SetStringValueForPropertyStore(
+    IPropertyStore* property_store,
+    const PROPERTYKEY& property_key,
+    const wchar_t* property_string_value);
+
+// Sets the application id in given IPropertyStore. The function is intended
+// for tagging application/chromium shortcut, browser window and jump list for
+// Win7.
+BASE_EXPORT bool SetAppIdForPropertyStore(IPropertyStore* property_store,
+                                          const wchar_t* app_id);
+
+// Adds the specified |command| using the specified |name| to the AutoRun key.
+// |root_key| could be HKCU or HKLM or the root of any user hive.
+BASE_EXPORT bool AddCommandToAutoRun(HKEY root_key, const string16& name,
+                                     const string16& command);
+
+// Removes the command specified by |name| from the AutoRun key. |root_key|
+// could be HKCU or HKLM or the root of any user hive.
+BASE_EXPORT bool RemoveCommandFromAutoRun(HKEY root_key, const string16& name);
+
+// Reads the command specified by |name| from the AutoRun key. |root_key|
+// could be HKCU or HKLM or the root of any user hive. Used for unit-tests.
+BASE_EXPORT bool ReadCommandFromAutoRun(HKEY root_key,
+                                        const string16& name,
+                                        string16* command);
+
+// Sets whether to crash the process during exit. This is inspected by DLLMain
+// and used to intercept unexpected terminations of the process (via calls to
+// exit(), abort(), _exit(), ExitProcess()) and convert them into crashes.
+// Note that not all mechanisms for terminating the process are covered by
+// this. In particular, TerminateProcess() is not caught.
+BASE_EXPORT void SetShouldCrashOnProcessDetach(bool crash);
+BASE_EXPORT bool ShouldCrashOnProcessDetach();
+
+// A tablet by this definition is something that has integrated multi-touch
+// ready to use and also has screen resolution not greater than 1366x768.
+BASE_EXPORT bool IsMachineATablet();
+
+// Get the size of a struct up to and including the specified member, e.g.
+// to set compatible struct sizes for different versions of certain Windows
+// APIs (SystemParametersInfo). Fully parenthesized for use in expressions.
+#define SIZEOF_STRUCT_WITH_SPECIFIED_LAST_MEMBER(struct_name, member) \
+    (offsetof(struct_name, member) + \
+     (sizeof static_cast<struct_name*>(NULL)->member))
+
+}  // namespace win
+}  // namespace base
+
+#endif // BASE_WIN_WIN_UTIL_H_
diff --git a/base/win/win_util_unittest.cc b/base/win/win_util_unittest.cc
new file mode 100644
index 0000000..b79ed56
--- /dev/null
+++ b/base/win/win_util_unittest.cc
@@ -0,0 +1,58 @@
+// Copyright (c) 2010 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include <windows.h>
+
+#include "base/basictypes.h"
+#include "base/string_util.h"
+#include "base/win/win_util.h"
+#include "base/win/windows_version.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace base {
+namespace win {
+
+namespace {
+
+// RAII guard for the calling thread's locale: captures GetThreadLocale() when
+// constructed and reinstates it via SetThreadLocale() when destroyed.
+// NOTE(review): no test in this file currently instantiates it.
+class ThreadLocaleSaver {
+ public:
+  ThreadLocaleSaver() : saved_lcid_(GetThreadLocale()) {}
+  ~ThreadLocaleSaver() { SetThreadLocale(saved_lcid_); }
+
+ private:
+  LCID saved_lcid_;
+  DISALLOW_COPY_AND_ASSIGN(ThreadLocaleSaver);
+};
+
+} // namespace
+
+// On Vista and later, UAC may be on or off depending on the bot, so there the
+// test only checks that the call does not crash.
+TEST(BaseWinUtilTest, TestIsUACEnabled) {
+  if (GetVersion() < base::win::VERSION_VISTA) {
+    EXPECT_TRUE(UserAccountControlIsEnabled());
+    return;
+  }
+  UserAccountControlIsEnabled();
+}
+
+TEST(BaseWinUtilTest, TestGetUserSidString) {
+  std::wstring user_sid;
+  ASSERT_TRUE(GetUserSidString(&user_sid));  // Stop if no SID was produced.
+  EXPECT_FALSE(user_sid.empty());
+}
+
+TEST(BaseWinUtilTest, TestGetNonClientMetrics) {
+  NONCLIENTMETRICS metrics = {0};  // cbSize starts at 0; the call must set it.
+  GetNonClientMetrics(&metrics);
+  EXPECT_GT(metrics.cbSize, 0u);
+  EXPECT_GT(metrics.iScrollWidth, 0);
+  EXPECT_GT(metrics.iScrollHeight, 0);
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/windows_version.cc b/base/win/windows_version.cc
new file mode 100644
index 0000000..c130e0e
--- /dev/null
+++ b/base/win/windows_version.cc
@@ -0,0 +1,111 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/windows_version.h"
+
+#include <windows.h>
+
+#include "base/logging.h"
+#include "base/utf_string_conversions.h"
+#include "base/win/registry.h"
+
+namespace base {
+namespace win {
+
+// static
+OSInfo* OSInfo::GetInstance() {
+ // Note: we don't use the Singleton class because it depends on AtExitManager,
+ // and it's convenient for other modules to use this class without it. This
+ // pattern is copied from gurl.cc.
+ static OSInfo* info;  // Zero-initialized; never deleted once published.
+ if (!info) {
+ OSInfo* new_info = new OSInfo();  // Racing threads may each build one.
+ if (InterlockedCompareExchangePointer(
+ reinterpret_cast<PVOID*>(&info), new_info, NULL)) {  // Publish iff NULL.
+ delete new_info;  // Lost the race: keep the already-published instance.
+ }
+ }
+ return info;
+}
+
+OSInfo::OSInfo()  // Gathers version, service pack and CPU info once.
+ : version_(VERSION_PRE_XP),
+ architecture_(OTHER_ARCHITECTURE),
+ wow64_status_(GetWOW64StatusForProcess(GetCurrentProcess())) {
+ OSVERSIONINFOEX version_info = { sizeof version_info };
+ GetVersionEx(reinterpret_cast<OSVERSIONINFO*>(&version_info));  // NOTE(review): result unchecked.
+ version_number_.major = version_info.dwMajorVersion;
+ version_number_.minor = version_info.dwMinorVersion;
+ version_number_.build = version_info.dwBuildNumber;
+ if ((version_number_.major == 5) && (version_number_.minor > 0)) {  // 5.0 stays VERSION_PRE_XP.
+ // Treat XP Pro x64, Home Server, and Server 2003 R2 as Server 2003.
+ version_ = (version_number_.minor == 1) ? VERSION_XP : VERSION_SERVER_2003;
+ } else if (version_number_.major == 6) {
+ switch (version_number_.minor) {
+ case 0:
+ // Treat Windows Server 2008 the same as Windows Vista.
+ version_ = VERSION_VISTA;
+ break;
+ case 1:
+ // Treat Windows Server 2008 R2 the same as Windows 7.
+ version_ = VERSION_WIN7;
+ break;
+ default:
+ DCHECK_EQ(version_number_.minor, 2);  // Later 6.x still map to WIN8 below.
+ // Treat Windows Server 2012 the same as Windows 8.
+ version_ = VERSION_WIN8;
+ break;
+ }
+ } else if (version_number_.major > 6) {
+ NOTREACHED();  // Major version newer than any enumerated one.
+ version_ = VERSION_WIN_LAST;
+ }
+ service_pack_.major = version_info.wServicePackMajor;
+ service_pack_.minor = version_info.wServicePackMinor;
+
+ SYSTEM_INFO system_info = { 0 };
+ GetNativeSystemInfo(&system_info);  // Native architecture, not the WOW64 view.
+ switch (system_info.wProcessorArchitecture) {  // Unlisted: OTHER_ARCHITECTURE.
+ case PROCESSOR_ARCHITECTURE_INTEL: architecture_ = X86_ARCHITECTURE; break;
+ case PROCESSOR_ARCHITECTURE_AMD64: architecture_ = X64_ARCHITECTURE; break;
+ case PROCESSOR_ARCHITECTURE_IA64: architecture_ = IA64_ARCHITECTURE; break;
+ }
+ processors_ = system_info.dwNumberOfProcessors;
+ allocation_granularity_ = system_info.dwAllocationGranularity;
+}
+
+// Nothing to release explicitly; members are destroyed automatically.
+OSInfo::~OSInfo() {}
+
+std::string OSInfo::processor_model_name() {
+  if (processor_model_name_.empty()) {  // Not read yet (or the read failed).
+    base::win::RegKey cpu_key(
+        HKEY_LOCAL_MACHINE,
+        L"HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", KEY_READ);
+    string16 model_utf16;
+    cpu_key.ReadValue(L"ProcessorNameString", &model_utf16);
+    processor_model_name_ = UTF16ToUTF8(model_utf16);
+  }
+  return processor_model_name_;
+}
+
+// static
+OSInfo::WOW64Status OSInfo::GetWOW64StatusForProcess(HANDLE process_handle) {
+  // IsWow64Process() is resolved at runtime rather than linked directly; if
+  // kernel32.dll does not export it, report WOW64_DISABLED.
+  typedef BOOL (WINAPI* IsWow64ProcessFunc)(HANDLE, PBOOL);
+  IsWow64ProcessFunc query = reinterpret_cast<IsWow64ProcessFunc>(
+      GetProcAddress(GetModuleHandle(L"kernel32.dll"), "IsWow64Process"));
+  BOOL under_wow64 = FALSE;
+  if (query && !query(process_handle, &under_wow64))
+    return WOW64_UNKNOWN;  // The lookup succeeded but the call failed.
+  return (query && under_wow64) ? WOW64_ENABLED : WOW64_DISABLED;
+}
+
+Version GetVersion() {  // Sugar for the most common OSInfo query.
+  return OSInfo::GetInstance()->version();
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/windows_version.h b/base/win/windows_version.h
new file mode 100644
index 0000000..d466dad
--- /dev/null
+++ b/base/win/windows_version.h
@@ -0,0 +1,110 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#ifndef BASE_WIN_WINDOWS_VERSION_H_
+#define BASE_WIN_WINDOWS_VERSION_H_
+
+#include <string>
+
+#include "base/base_export.h"
+#include "base/basictypes.h"
+
+typedef void* HANDLE;
+
+namespace base {
+namespace win {
+
+// The running version of Windows. This is declared outside OSInfo for
+// syntactic sugar reasons; see the declaration of GetVersion() below.
+// NOTE: Keep these in order so callers can do things like
+// "if (base::win::GetVersion() >= base::win::VERSION_VISTA) ...".
+enum Version {
+ VERSION_PRE_XP = 0, // Not supported (also NT 5.0 and older).
+ VERSION_XP, // NT 5.1.
+ VERSION_SERVER_2003, // NT 5.2; also XP Pro x64 and Server 2003 R2.
+ VERSION_VISTA, // Also includes Windows Server 2008.
+ VERSION_WIN7, // Also includes Windows Server 2008 R2.
+ VERSION_WIN8, // Also includes Windows Server 2012.
+ VERSION_WIN_LAST, // Error condition, e.g. an unrecognized newer major version.
+};
+
+// A singleton that can be used to query various pieces of information about the
+// OS and process state. Note that this doesn't use the base Singleton class, so
+// it can be used without an AtExitManager.
+class BASE_EXPORT OSInfo {
+ public:
+ struct VersionNumber {
+ int major;
+ int minor;
+ int build;
+ };
+
+ struct ServicePack {
+ int major;
+ int minor;
+ };
+
+ // The processor architecture this copy of Windows natively uses. For
+ // example, given an x64-capable processor, we have three possibilities:
+ // 32-bit Chrome running on 32-bit Windows: X86_ARCHITECTURE
+ // 32-bit Chrome running on 64-bit Windows via WOW64: X64_ARCHITECTURE
+ // 64-bit Chrome running on 64-bit Windows: X64_ARCHITECTURE
+ enum WindowsArchitecture {
+ X86_ARCHITECTURE,
+ X64_ARCHITECTURE,
+ IA64_ARCHITECTURE,
+ OTHER_ARCHITECTURE,  // Any architecture not listed above.
+ };
+
+ // Whether a process is running under WOW64 (the wrapper that allows 32-bit
+ // processes to run on 64-bit versions of Windows). The status is
+ // WOW64_DISABLED for both "32-bit Chrome on 32-bit Windows" and "64-bit
+ // Chrome on 64-bit Windows". WOW64_UNKNOWN means "an error occurred", e.g.
+ // the process does not have sufficient access rights to determine this.
+ enum WOW64Status {
+ WOW64_DISABLED,
+ WOW64_ENABLED,
+ WOW64_UNKNOWN,
+ };
+
+ static OSInfo* GetInstance();  // Created on first call; never destroyed.
+
+ Version version() const { return version_; }
+ // The next two return {major, minor, build} and {major, minor} structs.
+ VersionNumber version_number() const { return version_number_; }
+ ServicePack service_pack() const { return service_pack_; }
+ WindowsArchitecture architecture() const { return architecture_; }
+ int processors() const { return processors_; }
+ size_t allocation_granularity() const { return allocation_granularity_; }
+ WOW64Status wow64_status() const { return wow64_status_; }
+ std::string processor_model_name();  // Lazily cached; not synchronized.
+
+ // Like wow64_status(), but for the supplied handle instead of the current
+ // process. This doesn't touch member state, so you can bypass the singleton.
+ static WOW64Status GetWOW64StatusForProcess(HANDLE process_handle);
+
+ private:
+ OSInfo();
+ ~OSInfo();
+
+ Version version_;
+ VersionNumber version_number_;
+ ServicePack service_pack_;
+ WindowsArchitecture architecture_;
+ int processors_;
+ size_t allocation_granularity_;
+ WOW64Status wow64_status_;
+ std::string processor_model_name_;
+
+ DISALLOW_COPY_AND_ASSIGN(OSInfo);
+};
+
+// Shorthand for OSInfo::GetInstance()->version(), by far the most commonly
+// requested value from the singleton above.
+BASE_EXPORT Version GetVersion();
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_WINDOWS_VERSION_H_
diff --git a/base/win/wrapped_window_proc.cc b/base/win/wrapped_window_proc.cc
new file mode 100644
index 0000000..04e59c5
--- /dev/null
+++ b/base/win/wrapped_window_proc.cc
@@ -0,0 +1,63 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/wrapped_window_proc.h"
+
+#include "base/atomicops.h"
+#include "base/logging.h"
+#include "base/process_util.h"
+
+namespace {
+
+base::win::WinProcExceptionFilter s_exception_filter = NULL;
+
+} // namespace.
+
+namespace base {
+namespace win {
+
+WinProcExceptionFilter SetWinProcExceptionFilter(
+    WinProcExceptionFilter filter) {
+  const subtle::AtomicWord previous = subtle::NoBarrier_AtomicExchange(
+      reinterpret_cast<subtle::AtomicWord*>(&s_exception_filter),
+      reinterpret_cast<subtle::AtomicWord>(filter));  // Atomic swap, no fence.
+  return reinterpret_cast<WinProcExceptionFilter>(previous);
+}
+
+int CallExceptionFilter(EXCEPTION_POINTERS* info) {
+  WinProcExceptionFilter filter = s_exception_filter;  // Read the global once.
+  return filter ? filter(info) : EXCEPTION_CONTINUE_SEARCH;
+}
+
+BASE_EXPORT void InitializeWindowClass(
+    const char16* class_name,
+    WNDPROC window_proc,
+    UINT style,
+    int class_extra,
+    int window_extra,
+    HCURSOR cursor,
+    HBRUSH background,
+    const char16* menu_name,
+    HICON large_icon,
+    HICON small_icon,
+    WNDCLASSEX* class_out) {
+  WNDCLASSEX& wc = *class_out;
+  wc.cbSize = sizeof(wc);
+  // The class is owned by whichever module contains |window_proc|.
+  wc.hInstance = base::GetModuleFromAddress(window_proc);
+  wc.lpfnWndProc = window_proc;
+  wc.lpszClassName = class_name;
+  wc.lpszMenuName = menu_name;
+  wc.style = style;
+  wc.cbClsExtra = class_extra;
+  wc.cbWndExtra = window_extra;
+  wc.hCursor = cursor;
+  wc.hbrBackground = background;
+  wc.hIcon = large_icon;
+  wc.hIconSm = small_icon;
+  DCHECK(wc.hInstance != NULL);  // Catches an invalid |window_proc|.
+}
+
+} // namespace win
+} // namespace base
diff --git a/base/win/wrapped_window_proc.h b/base/win/wrapped_window_proc.h
new file mode 100644
index 0000000..b5793f2
--- /dev/null
+++ b/base/win/wrapped_window_proc.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2012 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Provides a way to handle exceptions that happen while a WindowProc is
+// running. The behavior of exceptions generated inside a WindowProc is OS
+// dependent, but it is possible that the OS just ignores the exception and
+// continues execution, which leads to unpredictable behavior for Chrome.
+
+#ifndef BASE_WIN_WRAPPED_WINDOW_PROC_H_
+#define BASE_WIN_WRAPPED_WINDOW_PROC_H_
+
+#include <windows.h>
+
+#include "base/base_export.h"
+#include "base/string16.h"
+
+namespace base {
+namespace win {
+
+// An exception filter for a WindowProc. The return value determines how the
+// exception should be handled, following standard SEH rules. In practice the
+// filter is expected not to return at all (rather than returning
+// EXCEPTION_EXECUTE_HANDLER or similar), since in general we are not
+// prepared to handle exceptions.
+typedef int (__cdecl *WinProcExceptionFilter)(EXCEPTION_POINTERS* info);
+
+// Sets the filter to deal with exceptions inside a WindowProc and returns the
+// previously installed filter (NULL if there was none).
+// This function should be called before any window is created.
+BASE_EXPORT WinProcExceptionFilter SetWinProcExceptionFilter(
+ WinProcExceptionFilter filter);
+
+// Calls the registered filter; returns EXCEPTION_CONTINUE_SEARCH if none.
+BASE_EXPORT int CallExceptionFilter(EXCEPTION_POINTERS* info);
+
+// Initializes the WNDCLASSEX structure |*class_out| to be passed to
+// RegisterClassEx(), setting hInstance to the module that contains
+// |window_proc|; cbSize is set and the other fields come from the arguments.
+BASE_EXPORT void InitializeWindowClass(
+ const char16* class_name,
+ WNDPROC window_proc,
+ UINT style,
+ int class_extra,
+ int window_extra,
+ HCURSOR cursor,
+ HBRUSH background,
+ const char16* menu_name,
+ HICON large_icon,
+ HICON small_icon,
+ WNDCLASSEX* class_out);
+
+// Wraps |proc| in an SEH frame so that exceptions raised inside it are routed
+// to the filter installed with SetWinProcExceptionFilter(). Typical usage:
+//
+//   LRESULT CALLBACK MyWinProc(HWND hwnd, UINT message,
+//                              WPARAM wparam, LPARAM lparam) {
+//     // Do Something.
+//   }
+//
+//   ...
+//
+//   WNDCLASSEX wc = {0};
+//   wc.lpfnWndProc = WrappedWindowProc<MyWinProc>;
+//   wc.lpszClassName = class_name;
+//   ...
+//   RegisterClassEx(&wc);
+//
+//   CreateWindowW(class_name, window_name, ...
+//
+template <WNDPROC proc>
+LRESULT CALLBACK WrappedWindowProc(HWND hwnd, UINT message,
+                                   WPARAM wparam, LPARAM lparam) {
+  LRESULT result = 0;  // Stays 0 if |proc| raises and the handler runs.
+  __try {
+    result = proc(hwnd, message, wparam, lparam);
+  } __except(CallExceptionFilter(GetExceptionInformation())) {
+  }
+  return result;
+}
+
+} // namespace win
+} // namespace base
+
+#endif // BASE_WIN_WRAPPED_WINDOW_PROC_H_
diff --git a/base/win/wrapped_window_proc_unittest.cc b/base/win/wrapped_window_proc_unittest.cc
new file mode 100644
index 0000000..ccf3c85
--- /dev/null
+++ b/base/win/wrapped_window_proc_unittest.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2011 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "base/win/wrapped_window_proc.h"
+#include "base/message_loop.h"
+#include "testing/gtest/include/gtest/gtest.h"
+
+namespace {
+
+const DWORD kExceptionCode = 12345;
+const UINT kCrashMsg = 98765;  // A message ID, typed like WndProc's |message|.
+
+// Window procedure that raises kExceptionCode when it receives kCrashMsg.
+LRESULT CALLBACK TestWindowProc(HWND hwnd, UINT message,
+                                WPARAM wparam, LPARAM lparam) {
+  const bool should_crash = (message == kCrashMsg);
+  if (should_crash) RaiseException(kExceptionCode, 0, 0, NULL);
+  return DefWindowProc(hwnd, message, wparam, lparam);
+}
+
+// This class implements an exception filter that can be queried about a past
+// exception. Only one instance may exist at a time: it registers itself in
+// |s_filter_| so that the static Filter() callback can reach it.
+class TestWrappedExceptionFiter {
+ public:
+  TestWrappedExceptionFiter() : called_(false) {
+    EXPECT_FALSE(s_filter_);
+    s_filter_ = this;
+  }
+
+  ~TestWrappedExceptionFiter() {
+    EXPECT_EQ(s_filter_, this);
+    s_filter_ = NULL;
+  }
+
+  bool called() const { return called_; }
+
+  // The actual exception filter: records whether kExceptionCode was seen and
+  // lets the (empty) __except handler in WrappedWindowProc run.
+  static int Filter(EXCEPTION_POINTERS* info) {
+    EXPECT_FALSE(s_filter_->called_);
+    if (info->ExceptionRecord->ExceptionCode == kExceptionCode)
+      s_filter_->called_ = true;
+    return EXCEPTION_EXECUTE_HANDLER;
+  }
+
+ private:
+  bool called_;
+  static TestWrappedExceptionFiter* s_filter_;
+};
+TestWrappedExceptionFiter* TestWrappedExceptionFiter::s_filter_ = NULL;
+
+} // namespace.
+
+TEST(WrappedWindowProc, CatchesExceptions) {
+  HINSTANCE hinst = GetModuleHandle(NULL);
+  std::wstring class_name(L"TestClass");
+  WNDCLASS wc = {0};
+  wc.lpfnWndProc = base::win::WrappedWindowProc<TestWindowProc>;
+  wc.hInstance = hinst;
+  wc.lpszClassName = class_name.c_str();
+  RegisterClass(&wc);
+  HWND window = CreateWindow(class_name.c_str(), 0, 0, 0, 0, 0, 0, HWND_MESSAGE,
+                             0, hinst, 0);
+  ASSERT_TRUE(window);
+
+  // Before generating the exception we make sure that the filter will see it.
+  TestWrappedExceptionFiter wrapper;
+  base::win::WinProcExceptionFilter old_filter =
+      base::win::SetWinProcExceptionFilter(TestWrappedExceptionFiter::Filter);
+  SendMessage(window, kCrashMsg, 0, 0);
+  EXPECT_TRUE(wrapper.called());
+
+  // Restore global state so later tests (or --gtest_repeat) start clean.
+  base::win::SetWinProcExceptionFilter(old_filter);
+  DestroyWindow(window);
+  UnregisterClass(class_name.c_str(), hinst);
+}