[Executor] [NFC] Restructure `ExecutorState::Process()` to avoid unneeded work.

This change improves the performance and readability of the `ExecutorState::Process()` method in several ways:

* (Readability) Split the dispatch logic into four helper methods (`ProcessSync()`, `ProcessAsync()`, `ProcessNoop()`, and `ProcessConstTensor()`), to reduce the length of the `Process()` method.
* (Performance) Avoid preparing inputs, initializing per-op params, or creating an `OpKernelContext` for the `ProcessNoop()` and `ProcessConstTensor()` cases.

PiperOrigin-RevId: 300364786
Change-Id: Id1ab900aa3a5b4a1300c3ed188d7620cea75db98
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index c985927..995c88c 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1442,6 +1442,18 @@
   // Process a ready node in current thread.
   void Process(TaggedNode node, int64 scheduled_nsec);
 
+  Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
+                     EntryVector* outputs,
+                     TensorReferenceVector* accessed_tensors,
+                     DeviceContext** device_context,
+                     NodeExecStatsInterface* stats);
+  void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
+                    const TaggedNode& tagged_node, Entry* first_input,
+                    NodeExecStatsInterface* stats);
+  void ProcessNoop(NodeExecStatsInterface* stats);
+  void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
+                          NodeExecStatsInterface* stats);
+
   // Before invoking item->kernel, fills in its "inputs".
   Status PrepareInputs(const NodeItem& item, Entry* first_input,
                        TensorValueVec* inputs,
@@ -1736,6 +1748,134 @@
       profiler::GetTFTraceMeLevel(item.kernel->IsExpensive()));
 }
 
+Status ExecutorState::ProcessSync(const NodeItem& item,
+                                  OpKernelContext::Params* params,
+                                  EntryVector* outputs,
+                                  TensorReferenceVector* accessed_tensors,
+                                  DeviceContext** device_context,
+                                  NodeExecStatsInterface* stats) {
+  Status s;
+  OpKernelContext ctx(params, item.num_outputs);
+  nodestats::SetOpStart(stats);
+
+  OpKernel* op_kernel = item.kernel;
+  Device* device = impl_->params_.device;
+
+  if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
+    tracing::ScopedRegion region(tracing::EventCategory::kCompute,
+                                 op_kernel->name_view());
+    profiler::AnnotatedTraceMe activity(
+        [&] {
+          return op_kernel->TraceString(
+              &ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
+        },
+        profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
+    device->Compute(op_kernel, &ctx);
+    nodestats::SetOpEnd(stats);
+    s = ProcessOutputs(item, &ctx, outputs, stats);
+  } else {
+    // In the common case, avoid creating any tracing objects.
+    if (op_kernel->IsExpensive()) {
+      KernelTimer timer;
+      device->Compute(op_kernel, &ctx);
+      op_kernel->UpdateCostEstimate(timer.ElapsedCycles());
+    } else {
+      device->Compute(op_kernel, &ctx);
+    }
+    nodestats::SetOpEnd(stats);
+    s = ProcessOutputs(item, &ctx, outputs, stats);
+  }
+  if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
+    // Get the list of all tensors accessed during the execution
+    ctx.retrieve_accessed_tensors(accessed_tensors);
+    *device_context = ctx.op_device_context();
+  }
+  nodestats::SetMemory(stats, &ctx);
+  return s;
+}
+
+void ExecutorState::ProcessAsync(const NodeItem& item,
+                                 const OpKernelContext::Params& params,
+                                 const TaggedNode& tagged_node,
+                                 Entry* first_input,
+                                 NodeExecStatsInterface* stats) {
+  AsyncOpKernel* async_kernel = item.kernel->AsAsync();
+  DCHECK(async_kernel != nullptr);
+  AsyncState* state =
+      new AsyncState(params, tagged_node, &item, first_input, stats);
+
+  auto done = [this, state]() {
+    Device* device = impl_->params_.device;
+    NodeExecStatsInterface* stats = state->stats;  // Shorthand
+    Entry* first_input = state->first_input;       // Shorthand
+
+    nodestats::SetOpEnd(stats);
+    EntryVector outputs;
+    Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
+    nodestats::SetMemory(stats, &state->ctx);
+    if (vlog_) {
+      VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
+              << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
+              << (state->tagged_node.is_dead ? " is dead" : "")
+              << " device: " << device->name();
+    }
+
+    // Clears inputs.
+    const int num_inputs = state->item->num_inputs;
+    for (int i = 0; i < num_inputs; ++i) {
+      (first_input + i)->ClearVal();
+    }
+    FrameState* input_frame = state->tagged_node.input_frame;
+    const int64 input_iter = state->tagged_node.input_iter;
+    MaybeMarkCompleted(input_frame, input_iter, *state->item);
+    TaggedNodeSeq ready;
+    if (s.ok()) {
+      PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
+    }
+    outputs.clear();
+    if (TF_PREDICT_FALSE(impl_->device_record_tensor_accesses_) && s.ok()) {
+      // Get the list of all tensors accessed during the execution
+      TensorReferenceVector accessed;
+      state->ctx.retrieve_accessed_tensors(&accessed);
+      nodestats::SetReferencedTensors(stats, accessed);
+      // callee takes ownership of the vector
+      device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
+                                           accessed);
+    }
+    const bool completed = NodeDone(s, &ready, stats, nullptr);
+    delete state;
+    if (completed) ScheduleFinish();
+  };
+  nodestats::SetOpStart(stats);
+  {
+    profiler::AnnotatedTraceMe activity(
+        [&] {
+          return async_kernel->TraceString(
+              &state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
+        },
+        profiler::GetTFTraceMeLevel(async_kernel->IsExpensive()));
+    impl_->params_.device->ComputeAsync(async_kernel, &state->ctx,
+                                        std::move(done));
+  }
+}
+
+void ExecutorState::ProcessNoop(NodeExecStatsInterface* stats) {
+  nodestats::SetOpStart(stats);
+  nodestats::SetOpEnd(stats);
+}
+
+void ExecutorState::ProcessConstTensor(const NodeItem& item,
+                                       EntryVector* outputs,
+                                       NodeExecStatsInterface* stats) {
+  nodestats::SetOpStart(stats);
+  nodestats::SetOpEnd(stats);
+  outputs->resize(1);
+  Entry& output = (*outputs)[0];
+  output.state = Entry::State::HAS_CONST_TENSOR;
+  output.const_tensor = item.const_tensor;
+  output.alloc_attr = item.output_attrs()[0];
+}
+
 void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
   profiler::TraceMe activity(
       [&] {
@@ -1849,13 +1989,16 @@
     outputs.clear();
 
     TensorReferenceVector accessed_tensors;
-    DeviceContext* device_context = nullptr;
     // Only execute this node if it is not dead or it is a send/recv
     // transfer node. For transfer nodes, we need to propagate the "dead"
     // bit even when the node is dead.
     bool launched_asynchronously = false;
     if (tagged_node.is_dead && !item.is_transfer_node) {
       outputs.resize(item.num_outputs);
+    } else if (TF_PREDICT_FALSE(item.is_noop)) {
+      ProcessNoop(stats);
+    } else if (item.const_tensor != nullptr && !params.track_allocations) {
+      ProcessConstTensor(item, &outputs, stats);
     } else {
       // Prepares inputs.
       bool is_input_dead = false;
@@ -1863,7 +2006,7 @@
                         &is_input_dead);
       if (!s.ok()) {
         // Clear inputs.
-        int num_inputs = item.num_inputs;
+        const int num_inputs = item.num_inputs;
         for (int i = 0; i < num_inputs; ++i) {
           (first_input + i)->ClearVal();
         }
@@ -1874,8 +2017,7 @@
       }
 
       // Set up compute params.
-      OpKernel* op_kernel = item.kernel;
-      params.op_kernel = op_kernel;
+      params.op_kernel = item.kernel;
       params.frame_iter = FrameAndIter(input_frame->frame_id, input_iter);
       params.is_input_dead = is_input_dead;
       params.output_attr_array = item.output_attrs();
@@ -1883,110 +2025,18 @@
       params.outputs_required_array = item.outputs_required.get();
 
       if (item.kernel_is_async) {
-        // Asynchronous computes.
-        AsyncOpKernel* async = item.kernel->AsAsync();
-        DCHECK(async != nullptr);
+        ProcessAsync(item, params, tagged_node, first_input, stats);
         launched_asynchronously = true;
-        AsyncState* state =
-            new AsyncState(params, tagged_node, &item, first_input, stats);
-
-        auto done = [this, state]() {
-          Device* device = impl_->params_.device;
-          NodeExecStatsInterface* stats = state->stats;  // Shorthand
-          Entry* first_input = state->first_input;       // Shorthand
-
-          nodestats::SetOpEnd(stats);
-          EntryVector outputs;
-          Status s = ProcessOutputs(*state->item, &state->ctx, &outputs, stats);
-          nodestats::SetMemory(stats, &state->ctx);
-          if (vlog_) {
-            VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
-                    << step_id_ << " "
-                    << SummarizeNodeDef(state->item->kernel->def())
-                    << (state->tagged_node.is_dead ? " is dead" : "")
-                    << " device: " << device->name();
-          }
-
-          // Clears inputs.
-          const int num_inputs = state->item->num_inputs;
-          for (int i = 0; i < num_inputs; ++i) {
-            (first_input + i)->ClearVal();
-          }
-          FrameState* input_frame = state->tagged_node.input_frame;
-          const int64 input_iter = state->tagged_node.input_iter;
-          MaybeMarkCompleted(input_frame, input_iter, *state->item);
-          TaggedNodeSeq ready;
-          if (s.ok()) {
-            PropagateOutputs(state->tagged_node, state->item, &outputs, &ready);
-          }
-          outputs.clear();
-          if (s.ok() && impl_->device_record_tensor_accesses_) {
-            // Get the list of all tensors accessed during the execution
-            TensorReferenceVector accessed;
-            state->ctx.retrieve_accessed_tensors(&accessed);
-            nodestats::SetReferencedTensors(stats, accessed);
-            // callee takes ownership of the vector
-            device->ConsumeListOfAccessedTensors(state->ctx.op_device_context(),
-                                                 accessed);
-          }
-          const bool completed = NodeDone(s, &ready, stats, nullptr);
-          delete state;
-          if (completed) ScheduleFinish();
-        };
-        nodestats::SetOpStart(stats);
-        {
-          profiler::AnnotatedTraceMe activity(
-              [&] {
-                return op_kernel->TraceString(
-                    &state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
-              },
-              profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
-          device->ComputeAsync(async, &state->ctx, done);
-        }
       } else {
-        // Synchronous computes.
-        OpKernelContext ctx(&params, item.num_outputs);
-        nodestats::SetOpStart(stats);
-
-        if (TF_PREDICT_FALSE(item.is_noop)) {
-          nodestats::SetOpEnd(stats);
-        } else if (TF_PREDICT_FALSE(MightTrace(item, event_collector_))) {
-          tracing::ScopedRegion region(tracing::EventCategory::kCompute,
-                                       op_kernel->name_view());
-          profiler::AnnotatedTraceMe activity(
-              [&] {
-                return op_kernel->TraceString(
-                    &ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
-              },
-              profiler::GetTFTraceMeLevel(op_kernel->IsExpensive()));
-          device->Compute(op_kernel, &ctx);
-          nodestats::SetOpEnd(stats);
-          s = ProcessOutputs(item, &ctx, &outputs, stats);
-        } else if (item.const_tensor != nullptr && !ctx.track_allocations()) {
-          // Special case for ConstantOp, which is very common.
-          nodestats::SetOpEnd(stats);
-          outputs.resize(1);
-          outputs[0].state = Entry::State::HAS_CONST_TENSOR;
-          outputs[0].const_tensor = item.const_tensor;
-          outputs[0].alloc_attr = ctx.output_alloc_attr(0);
-        } else {
-          // In the common case, avoid creating any tracing objects.
-          if (op_kernel->IsExpensive()) {
-            KernelTimer timer;
-            device->Compute(op_kernel, &ctx);
-            op_kernel->UpdateCostEstimate(timer.ElapsedCycles());
-          } else {
-            device->Compute(op_kernel, &ctx);
-          }
-          nodestats::SetOpEnd(stats);
-          s = ProcessOutputs(item, &ctx, &outputs, stats);
+        DeviceContext* device_context = nullptr;
+        s = ProcessSync(item, &params, &outputs, &accessed_tensors,
+                        &device_context, stats);
+        if (!accessed_tensors.empty()) {
+          nodestats::SetReferencedTensors(stats, accessed_tensors);
+          // device_context is set above in `ProcessSync()`.
+          device->ConsumeListOfAccessedTensors(device_context,
+                                               accessed_tensors);
         }
-        if (s.ok() && impl_->device_record_tensor_accesses_) {
-          // Get the list of all tensors accessed during the execution
-          ctx.retrieve_accessed_tensors(&accessed_tensors);
-          device_context = ctx.op_device_context();
-        }
-        nodestats::SetMemory(stats, &ctx);
       }
     }
 
@@ -2009,11 +2059,6 @@
         PropagateOutputs(tagged_node, &item, &outputs, &ready);
       }
       outputs.clear();
-      if (!accessed_tensors.empty()) {
-        nodestats::SetReferencedTensors(stats, accessed_tensors);
-        // device_context is set above in synchronous computes
-        device->ConsumeListOfAccessedTensors(device_context, accessed_tensors);
-      }
       if (stats) {
         scheduled_nsec = nodestats::NowInNsec();
       }