| // Copyright 2014 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "mojo/public/cpp/bindings/lib/router.h" |
| |
| #include <stdint.h> |
| |
| #include <utility> |
| |
| #include "base/bind.h" |
| #include "base/location.h" |
| #include "base/logging.h" |
| #include "base/memory/ptr_util.h" |
| #include "base/stl_util.h" |
| #include "mojo/public/cpp/bindings/sync_call_restrictions.h" |
| |
| namespace mojo { |
| namespace internal { |
| |
| // ---------------------------------------------------------------------------- |
| |
| namespace { |
| |
| void DCheckIfInvalid(const base::WeakPtr<Router>& router, |
| const std::string& message) { |
| bool is_valid = router && !router->encountered_error() && router->is_valid(); |
| DCHECK(!is_valid) << message; |
| } |
| |
| class ResponderThunk : public MessageReceiverWithStatus { |
| public: |
| explicit ResponderThunk(const base::WeakPtr<Router>& router, |
| scoped_refptr<base::SingleThreadTaskRunner> runner) |
| : router_(router), |
| accept_was_invoked_(false), |
| task_runner_(std::move(runner)) {} |
| ~ResponderThunk() override { |
| if (!accept_was_invoked_) { |
| // The Mojo application handled a message that was expecting a response |
| // but did not send a response. |
| // We raise an error to signal the calling application that an error |
| // condition occurred. Without this the calling application would have no |
| // way of knowing it should stop waiting for a response. |
| if (task_runner_->RunsTasksOnCurrentThread()) { |
| // Please note that even if this code is run from a different task |
| // runner on the same thread as |task_runner_|, it is okay to directly |
| // call Router::RaiseError(), because it will raise error from the |
| // correct task runner asynchronously. |
| if (router_) |
| router_->RaiseError(); |
| } else { |
| task_runner_->PostTask(FROM_HERE, |
| base::Bind(&Router::RaiseError, router_)); |
| } |
| } |
| } |
| |
| // MessageReceiver implementation: |
| bool Accept(Message* message) override { |
| DCHECK(task_runner_->RunsTasksOnCurrentThread()); |
| accept_was_invoked_ = true; |
| DCHECK(message->has_flag(Message::kFlagIsResponse)); |
| |
| bool result = false; |
| |
| if (router_) |
| result = router_->Accept(message); |
| |
| return result; |
| } |
| |
| // MessageReceiverWithStatus implementation: |
| bool IsValid() override { |
| DCHECK(task_runner_->RunsTasksOnCurrentThread()); |
| return router_ && !router_->encountered_error() && router_->is_valid(); |
| } |
| |
| void DCheckInvalid(const std::string& message) override { |
| if (task_runner_->RunsTasksOnCurrentThread()) { |
| DCheckIfInvalid(router_, message); |
| } else { |
| task_runner_->PostTask(FROM_HERE, |
| base::Bind(&DCheckIfInvalid, router_, message)); |
| } |
| } |
| |
| private: |
| base::WeakPtr<Router> router_; |
| bool accept_was_invoked_; |
| scoped_refptr<base::SingleThreadTaskRunner> task_runner_; |
| }; |
| |
| } // namespace |
| |
| // ---------------------------------------------------------------------------- |
| |
| Router::SyncResponseInfo::SyncResponseInfo(bool* in_response_received) |
| : response_received(in_response_received) {} |
| |
| Router::SyncResponseInfo::~SyncResponseInfo() {} |
| |
| // ---------------------------------------------------------------------------- |
| |
| Router::HandleIncomingMessageThunk::HandleIncomingMessageThunk(Router* router) |
| : router_(router) { |
| } |
| |
| Router::HandleIncomingMessageThunk::~HandleIncomingMessageThunk() { |
| } |
| |
| bool Router::HandleIncomingMessageThunk::Accept(Message* message) { |
| return router_->HandleIncomingMessage(message); |
| } |
| |
| // ---------------------------------------------------------------------------- |
| |
| Router::Router(ScopedMessagePipeHandle message_pipe, |
| FilterChain filters, |
| bool expects_sync_requests, |
| scoped_refptr<base::SingleThreadTaskRunner> runner) |
| : thunk_(this), |
| filters_(std::move(filters)), |
| connector_(std::move(message_pipe), |
| Connector::SINGLE_THREADED_SEND, |
| std::move(runner)), |
| incoming_receiver_(nullptr), |
| next_request_id_(0), |
| testing_mode_(false), |
| pending_task_for_messages_(false), |
| encountered_error_(false), |
| weak_factory_(this) { |
| filters_.SetSink(&thunk_); |
| if (expects_sync_requests) |
| connector_.AllowWokenUpBySyncWatchOnSameThread(); |
| connector_.set_incoming_receiver(filters_.GetHead()); |
| connector_.set_connection_error_handler( |
| base::Bind(&Router::OnConnectionError, base::Unretained(this))); |
| } |
| |
| Router::~Router() {} |
| |
| bool Router::Accept(Message* message) { |
| DCHECK(thread_checker_.CalledOnValidThread()); |
| DCHECK(!message->has_flag(Message::kFlagExpectsResponse)); |
| return connector_.Accept(message); |
| } |
| |
| bool Router::AcceptWithResponder(Message* message, MessageReceiver* responder) { |
| DCHECK(thread_checker_.CalledOnValidThread()); |
| DCHECK(message->has_flag(Message::kFlagExpectsResponse)); |
| |
| // Reserve 0 in case we want it to convey special meaning in the future. |
| uint64_t request_id = next_request_id_++; |
| if (request_id == 0) |
| request_id = next_request_id_++; |
| |
| bool is_sync = message->has_flag(Message::kFlagIsSync); |
| message->set_request_id(request_id); |
| if (!connector_.Accept(message)) |
| return false; |
| |
| if (!is_sync) { |
| // We assume ownership of |responder|. |
| async_responders_[request_id] = base::WrapUnique(responder); |
| return true; |
| } |
| |
| SyncCallRestrictions::AssertSyncCallAllowed(); |
| |
| bool response_received = false; |
| std::unique_ptr<MessageReceiver> sync_responder(responder); |
| sync_responses_.insert(std::make_pair( |
| request_id, base::WrapUnique(new SyncResponseInfo(&response_received)))); |
| |
| base::WeakPtr<Router> weak_self = weak_factory_.GetWeakPtr(); |
| connector_.SyncWatch(&response_received); |
| // Make sure that this instance hasn't been destroyed. |
| if (weak_self) { |
| DCHECK(ContainsKey(sync_responses_, request_id)); |
| auto iter = sync_responses_.find(request_id); |
| DCHECK_EQ(&response_received, iter->second->response_received); |
| if (response_received) { |
| std::unique_ptr<Message> response = std::move(iter->second->response); |
| ignore_result(sync_responder->Accept(response.get())); |
| } |
| sync_responses_.erase(iter); |
| } |
| |
| // Return true means that we take ownership of |responder|. |
| return true; |
| } |
| |
| void Router::EnableTestingMode() { |
| DCHECK(thread_checker_.CalledOnValidThread()); |
| testing_mode_ = true; |
| connector_.set_enforce_errors_from_incoming_receiver(false); |
| } |
| |
| bool Router::HandleIncomingMessage(Message* message) { |
| DCHECK(thread_checker_.CalledOnValidThread()); |
| |
| const bool during_sync_call = |
| connector_.during_sync_handle_watcher_callback(); |
| if (!message->has_flag(Message::kFlagIsSync) && |
| (during_sync_call || !pending_messages_.empty())) { |
| std::unique_ptr<Message> pending_message(new Message); |
| message->MoveTo(pending_message.get()); |
| pending_messages_.push(std::move(pending_message)); |
| |
| if (!pending_task_for_messages_) { |
| pending_task_for_messages_ = true; |
| connector_.task_runner()->PostTask( |
| FROM_HERE, base::Bind(&Router::HandleQueuedMessages, |
| weak_factory_.GetWeakPtr())); |
| } |
| |
| return true; |
| } |
| |
| return HandleMessageInternal(message); |
| } |
| |
| void Router::HandleQueuedMessages() { |
| DCHECK(thread_checker_.CalledOnValidThread()); |
| DCHECK(pending_task_for_messages_); |
| |
| base::WeakPtr<Router> weak_self = weak_factory_.GetWeakPtr(); |
| while (!pending_messages_.empty()) { |
| std::unique_ptr<Message> message(std::move(pending_messages_.front())); |
| pending_messages_.pop(); |
| |
| bool result = HandleMessageInternal(message.get()); |
| if (!weak_self) |
| return; |
| |
| if (!result && !testing_mode_) { |
| connector_.RaiseError(); |
| break; |
| } |
| } |
| |
| pending_task_for_messages_ = false; |
| |
| // We may have already seen a connection error from the connector, but |
| // haven't notified the user because we want to process all the queued |
| // messages first. We should do it now. |
| if (connector_.encountered_error() && !encountered_error_) |
| OnConnectionError(); |
| } |
| |
| bool Router::HandleMessageInternal(Message* message) { |
| if (message->has_flag(Message::kFlagExpectsResponse)) { |
| if (!incoming_receiver_) |
| return false; |
| |
| MessageReceiverWithStatus* responder = new ResponderThunk( |
| weak_factory_.GetWeakPtr(), connector_.task_runner()); |
| bool ok = incoming_receiver_->AcceptWithResponder(message, responder); |
| if (!ok) |
| delete responder; |
| return ok; |
| |
| } else if (message->has_flag(Message::kFlagIsResponse)) { |
| uint64_t request_id = message->request_id(); |
| |
| if (message->has_flag(Message::kFlagIsSync)) { |
| auto it = sync_responses_.find(request_id); |
| if (it == sync_responses_.end()) { |
| DCHECK(testing_mode_); |
| return false; |
| } |
| it->second->response.reset(new Message()); |
| message->MoveTo(it->second->response.get()); |
| *it->second->response_received = true; |
| return true; |
| } |
| |
| auto it = async_responders_.find(request_id); |
| if (it == async_responders_.end()) { |
| DCHECK(testing_mode_); |
| return false; |
| } |
| std::unique_ptr<MessageReceiver> responder = std::move(it->second); |
| async_responders_.erase(it); |
| return responder->Accept(message); |
| } else { |
| if (!incoming_receiver_) |
| return false; |
| |
| return incoming_receiver_->Accept(message); |
| } |
| } |
| |
| void Router::OnConnectionError() { |
| if (encountered_error_) |
| return; |
| |
| if (!pending_messages_.empty()) { |
| // After all the pending messages are processed, we will check whether an |
| // error has been encountered and run the user's connection error handler |
| // if necessary. |
| DCHECK(pending_task_for_messages_); |
| return; |
| } |
| |
| if (connector_.during_sync_handle_watcher_callback()) { |
| // We don't want the error handler to reenter an ongoing sync call. |
| connector_.task_runner()->PostTask( |
| FROM_HERE, |
| base::Bind(&Router::OnConnectionError, weak_factory_.GetWeakPtr())); |
| return; |
| } |
| |
| encountered_error_ = true; |
| if (!error_handler_.is_null()) |
| error_handler_.Run(); |
| } |
| |
| // ---------------------------------------------------------------------------- |
| |
| } // namespace internal |
| } // namespace mojo |