[Executor] [NFC] Restructure `ExecutorState::Process()` to avoid unneeded work.
This change restructures the `ExecutorState::Process()` method to improve its performance and readability:
* (Readability) Split the dispatch logic into four helper methods (`ProcessSync()`, `ProcessAsync()`, `ProcessNoop()`, and `ProcessConstTensor()`), to reduce the length of the `Process()` method.
* (Performance) Avoid preparing inputs, initializing per-op params, or creating an `OpKernelContext` for the `ProcessNoop()` and `ProcessConstTensor()` cases.
PiperOrigin-RevId: 300364786
Change-Id: Id1ab900aa3a5b4a1300c3ed188d7620cea75db98
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index c985927..995c88c 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1442,6 +1442,18 @@
// Process a ready node in current thread.
void Process(TaggedNode node, int64 scheduled_nsec);
+ Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
+ EntryVector* outputs,
+ TensorReferenceVector* accessed_tensors,
+ DeviceContext** device_context,
+ NodeExecStatsInterface* stats);
+ void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
+ const TaggedNode& tagged_node, Entry* first_input,
+ NodeExecStatsInterface* stats);
+ void ProcessNoop(NodeExecStatsInterface* stats);
+ void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
+ NodeExecStatsInterface* stats);
+
// Before invoking item->kernel, fills in its "inputs".
Status PrepareInputs(const NodeItem& item, Entry* first_input,
TensorValueVec* inputs,
@@ -1736,6 +1748,134 @@
profiler::GetTFTraceMeLevel(item.kernel->IsExpensive()));
}
+Status ExecutorState::ProcessSync(const NodeItem& item,
+ OpKernelContext::Params* params,
+ EntryVector* outputs,
+ TensorReferenceVector* accessed_tensors,
+ DeviceContext** device_context,
+ NodeExecStatsInterface* stats) {
+ Status s;
+ OpKernelContext ctx(params, item.num_outputs);
+ nodestats::SetOpStart(stats);
+
+ OpKernel* op_kernel = item.kernel;
+ Device* device = impl_->params_.device;
+
+ if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
+ tracing::ScopedRegion region(tracing::EventCategory::kCompute,
+ op_kernel->name_view());
+ profiler::AnnotatedTraceMe activity(
+ [&] {
+ return op_kernel->TraceString(
+ &ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
+ },
+ profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
+ device->Compute(op_kernel, &ctx);
+ nodestats::SetOpEnd(stats);
+ s = ProcessOutputs(item, &ctx, outputs, stats);
+ } else {
+ // In the common case, avoid creating any tracing objects.
+ if (op_kernel->IsExpensive()) {
+ KernelTimer timer;
+ device->Compute(op_kernel, &ctx);
+ op_kernel->UpdateCostEstimate(timer.ElapsedCycles());
+ } else {
+ device->Compute(op_kernel, &ctx);
+ }
+ nodestats::SetOpEnd(stats);
+ s = ProcessOutputs(item, &ctx, outputs, stats);
+ }
+ if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
+ // Get the list of all tensors accessed during the execution
+ ctx.retrieve_accessed_tensors(accessed_tensors);
+ *device_context = ctx.op_device_context();
+ }
+ nodestats::SetMemory(stats, &ctx);
+ return s;
+}
+
+void ExecutorState::ProcessAsync(const NodeItem& item,
+ const OpKernelContext::Params& params,
+ const TaggedNode& tagged_node,
+ Entry* first_input,
+ NodeExecStatsInterface* stats) {
+ AsyncOpKernel* async_kernel = item.kernel->AsAsync();
+ DCHECK(async_kernel != nullptr);
+ AsyncState* state =
+ new AsyncState(params, tagged_node, &item, first_input, stats);
+
+ auto done = [this, state]() {
+ Device* device = impl_->params_.device;
+ NodeExecStatsInterface* stats = state->stats; // Shorthand
+ Entry* first_input = state->first_input; // Shorthand
+
+ nodestats::SetOpEnd(stats);
+ EntryVector outputs;
+ Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
+ nodestats::SetMemory(stats, &state->ctx);
+ if (vlog_) {
+ VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
+ << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
+ << (state->tagged_node.is_dead ? " is dead" : "")
+ << " device: " << device->name();
+ }
+
+ // Clears inputs.
+ const int num_inputs = state->item->num_inputs;
+ for (int i = 0; i < num_inputs; ++i) {
+ (first_input + i)->ClearVal();
+ }
+ FrameState* input_frame = state->tagged_node.input_frame;
+ const int64 input_iter = state->tagged_node.input_iter;
+ MaybeMarkCompleted(input_frame, input_iter, *state->item);
+ TaggedNodeSeq ready;
+ if (s.ok()) {
+ PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
+ }
+ outputs.clear();
+ if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
+ // Get the list of all tensors accessed during the execution
+ TensorReferenceVector accessed;
+ state->ctx.retrieve_accessed_tensors(&accessed);
+ nodestats::SetReferencedTensors(stats, accessed);
+ // callee takes ownership of the vector
+ device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
+ accessed);
+ }
+ const bool completed = NodeDone(s, &ready, stats, nullptr);
+ delete state;
+ if (completed) ScheduleFinish();
+ };
+ nodestats::SetOpStart(stats);
+ {
+ profiler::AnnotatedTraceMe activity(
+ [&] {
+ return async_kernel->TraceString(
+ &state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
+ },
+ profiler::GetTFTraceMeLevel(async_kernel->IsExpensive()));
+ impl_->params_.device->ComputeAsync(async_kernel, &state->ctx,
+ std::move(done));
+ }
+}
+
+void ExecutorState::ProcessNoop(NodeExecStatsInterface* stats) {
+ nodestats::SetOpStart(stats);
+ nodestats::SetOpEnd(stats);
+}
+
+void ExecutorState::ProcessConstTensor(const NodeItem& item,
+ EntryVector* outputs,
+ NodeExecStatsInterface* stats) {
+ nodestats::SetOpStart(stats);
+ nodestats::SetOpEnd(stats);
+ outputs->resize(1);
+ Entry& output = (*outputs)[0];
+ output.state = Entry::State::HAS_CONST_TENSOR;
+ output.const_tensor = item.const_tensor;
+ output.alloc_attr = item.output_attrs()[0];
+}
+
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
profiler::TraceMe activity(
[&] {
@@ -1849,13 +1989,16 @@
outputs.clear();
TensorReferenceVector accessed_tensors;
- DeviceContext* device_context = nullptr;
// Only execute this node if it is not dead or it is a send/recv
// transfer node. For transfer nodes, we need to propagate the "dead"
// bit even when the node is dead.
bool launched_asynchronously = false;
if (tagged_node.is_dead && !item.is_transfer_node) {
outputs.resize(item.num_outputs);
+ } else if (TF_PREDICT_FALSE(item.is_noop)) {
+ ProcessNoop(stats);
+ } else if (item.const_tensor != nullptr && !params.track_allocations) {
+ ProcessConstTensor(item, &outputs, stats);
} else {
// Prepares inputs.
bool is_input_dead = false;
@@ -1863,7 +2006,7 @@
&is_input_dead);
if (!s.ok()) {
// Clear inputs.
- int num_inputs = item.num_inputs;
+ const int num_inputs = item.num_inputs;
for (int i = 0; i < num_inputs; ++i) {
(first_input + i)->ClearVal();
}
@@ -1874,8 +2017,7 @@
}
// Set up compute params.
- OpKernel* op_kernel = item.kernel;
- params.op_kernel = op_kernel;
+ params.op_kernel = item.kernel;
params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
params.is_input_dead = is_input_dead;
params.output_attr_array = item.output_attrs();
@@ -1883,110 +2025,18 @@
params.outputs_required_array = item.outputs_required.get();
if (item.kernel_is_async) {
- // Asynchronous computes.
- AsyncOpKernel* async = item.kernel->AsAsync();
- DCHECK(async != nullptr);
+ ProcessAsync(item, params, tagged_node, first_input, stats);
launched_asynchronously = true;
- AsyncState* state =
- new AsyncState(params, tagged_node, &item, first_input, stats);
-
- auto done = [this, state]() {
- Device* device = impl_->params_.device;
- NodeExecStatsInterface* stats = state->stats; // Shorthand
- Entry* first_input = state->first_input; // Shorthand
-
- nodestats::SetOpEnd(stats);
- EntryVector outputs;
- Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
- nodestats::SetMemory(stats, &state->ctx);
- if (vlog_) {
- VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
- << step_id_ << " "
- << SummarizeNodeDef(state->item->kernel->def())
- << (state->tagged_node.is_dead ? " is dead" : "")
- << " device: " << device->name();
- }
-
- // Clears inputs.
- const int num_inputs = state->item->num_inputs;
- for (int i = 0; i < num_inputs; ++i) {
- (first_input + i)->ClearVal();
- }
- FrameState* input_frame = state->tagged_node.input_frame;
- const int64 input_iter = state->tagged_node.input_iter;
- MaybeMarkCompleted(input_frame, input_iter, *state->item);
- TaggedNodeSeq ready;
- if (s.ok()) {
- PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
- }
- outputs.clear();
- if (s.ok() && impl_->device_record_tensor_accesses_) {
- // Get the list of all tensors accessed during the execution
- TensorReferenceVector accessed;
- state->ctx.retrieve_accessed_tensors(&accessed);
- nodestats::SetReferencedTensors(stats, accessed);
- // callee takes ownership of the vector
- device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
- accessed);
- }
- const bool completed = NodeDone(s, &ready, stats, nullptr);
- delete state;
- if (completed) ScheduleFinish();
- };
- nodestats::SetOpStart(stats);
- {
- profiler::AnnotatedTraceMe activity(
- [&] {
- return op_kernel->TraceString(
- &state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
- },
- profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
- device->ComputeAsync(async, &state->ctx, done);
- }
} else {
- // Synchronous computes.
- OpKernelContext ctx(¶ms, item.num_outputs);
- nodestats::SetOpStart(stats);
-
- if (TF_PREDICT_FALSE(item.is_noop)) {
- nodestats::SetOpEnd(stats);
- } else if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
- tracing::ScopedRegion region(tracing::EventCategory::kCompute,
- op_kernel->name_view());
- profiler::AnnotatedTraceMe activity(
- [&] {
- return op_kernel->TraceString(
- &ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
- },
- profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
- device->Compute(op_kernel, &ctx);
- nodestats::SetOpEnd(stats);
- s = ProcessOutputs(item, &ctx, &outputs, stats);
- } else if (item.const_tensor != nullptr && !ctx.track_allocations()) {
- // Special case for ConstantOp, which is very common.
- nodestats::SetOpEnd(stats);
- outputs.resize(1);
- outputs[0].state = Entry::State::HAS_CONST_TENSOR;
- outputs[0].const_tensor = item.const_tensor;
- outputs[0].alloc_attr = ctx.output_alloc_attr(0);
- } else {
- // In the common case, avoid creating any tracing objects.
- if (op_kernel->IsExpensive()) {
- KernelTimer timer;
- device->Compute(op_kernel, &ctx);
- op_kernel->UpdateCostEstimate(timer.ElapsedCycles());
- } else {
- device->Compute(op_kernel, &ctx);
- }
- nodestats::SetOpEnd(stats);
- s = ProcessOutputs(item, &ctx, &outputs, stats);
+ DeviceContext* device_context = nullptr;
+ s = ProcessSync(item, ¶ms, &outputs, &accessed_tensors,
+ &device_context, stats);
+ if (!accessed_tensors.empty()) {
+ nodestats::SetReferencedTensors(stats, accessed_tensors);
+ // device_context is set above in `ProcessSync()`.
+ device->ConsumeListOfAccessedTensors(device_context,
+ accessed_tensors);
}
- if (s.ok() && impl_->device_record_tensor_accesses_) {
- // Get the list of all tensors accessed during the execution
- ctx.retrieve_accessed_tensors(&accessed_tensors);
- device_context = ctx.op_device_context();
- }
- nodestats::SetMemory(stats, &ctx);
}
}
@@ -2009,11 +2059,6 @@
PropagateOutputs(tagged_node, &item, &outputs, &ready);
}
outputs.clear();
- if (!accessed_tensors.empty()) {
- nodestats::SetReferencedTensors(stats, accessed_tensors);
- // device_context is set above in synchronous computes
- device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
- }
if (stats) {
scheduled_nsec = nodestats::NowInNsec();
}