Rewrite/simplify tracing.

The implementation so far was prematurely optimized. It had all threads record directly into a shared vector indexed by block_id. The idea was (1) to avoid the overhead of locking or other synchronization primitives when tracing a multi-threaded execution, and (2) to avoid the overhead of growing heap buffers. The new implementation is much more straightforward, as is most evident from the fact that it doesn't use relaxed_atomic_store anymore (yet still runs free of TSan errors), and that we were able to remove the ProcessedTrace class.

The above-mentioned issues (1) and (2) that drove the earlier design are now addressed as follows in the new design: (1) Each thread now records to its own specific vector of trace entries; these thread-specific vectors are only coalesced into a global vector when dumping a trace. This removed the need for any locking or atomic operations. (2) We are less careful than before about avoiding heap allocations. We just reserve upfront a rather large buffer size, large enough to avoid most subsequent heap reallocations and small enough to still not matter in practical tracing situations.

The proximate motivation for this change is that the existing design, which requires indexing trace entries by block_id, is now inconvenient: we need to experiment with TrMul implementation changes where packing is not necessarily directly associated with a block_id anymore.

PiperOrigin-RevId: 259996147
diff --git a/trace.cc b/trace.cc
index 6704303..8e28b47 100644
--- a/trace.cc
+++ b/trace.cc
@@ -24,200 +24,153 @@
 
 #include "block_map.h"
 #include "check_macros.h"
-#include "common.h"
+#include "side_pair.h"
 #include "time.h"
 
 namespace ruy {
 
 #ifdef RUY_TRACE
 
-struct BlockTraceEntry {
-  std::uint32_t thread_id = 0;
-  TimePoint time_reserved;
-  TimePoint time_computed_coords;
-  SidePair<TimePoint> time_packed;
-  TimePoint time_finished;
+enum class TraceEvent : std::uint8_t {
+  kNone,
+  kThreadStart,
+  kThreadLoopStart,
+  kThreadEnd,
+  kBlockReserved,
+  kBlockPackedLhs,
+  kBlockPackedRhs,
+  kBlockFinished
 };
 
-struct ThreadTraceEntry {
-  TimePoint time_start;
-  TimePoint time_loop_start;
-  TimePoint time_end;
+struct TraceEntry {
+  TimePoint time_point;
+  TraceEvent event;
+  // ruy-internal thread id i.e. contiguous index into array of threads,
+  // with 0 designating the main thread.
+  std::uint16_t thread_id = 0;
+  // Additional parameters whose meaning depends on the 'event' type.
+  std::uint32_t params[1];
 };
 
 struct Trace {
-  enum class LifeStage {
-    kInitial,
-    kRecordingRootFields,
-    kRecordingBlockAndThreadFields,
-    kComplete
-  };
-  void StartRecordingBlockAndThreadFields(const BlockMap& block_map_,
-                                          int thread_count_) {
-    RUY_DCHECK(life_stage == LifeStage::kRecordingRootFields);
-    block_map = block_map_;
-    thread_count = thread_count_;
-    int num_blocks = NumBlocks(block_map);
-    if (num_blocks > block_entries.size()) {
-      block_entries.resize(NumBlocks(block_map));
-    }
-    if (thread_count > thread_entries.size()) {
-      thread_entries.resize(thread_count);
-    }
-    life_stage = LifeStage::kRecordingBlockAndThreadFields;
-  }
   BlockMap block_map;
   int thread_count = 0;
-  std::vector<BlockTraceEntry> block_entries;
-  std::vector<ThreadTraceEntry> thread_entries;
+  // During recording, to avoid having to use locks or atomics, we let
+  // each thread append to its own specific vector.
+  std::vector<std::vector<TraceEntry>> thread_specific_entries;
+  // Global vector of entries into which we coalesce thread_specific_entries
+  // after recording is finished, when dumping a trace. See
+  // AggregateThreadSpecificEntries.
+  std::vector<TraceEntry> entries;
   TimePoint time_start;
   TimePoint time_execute;
   TimePoint time_end;
-  LifeStage life_stage = LifeStage::kInitial;
 };
 
-struct ProcessedTrace {
-  enum class Event : std::uint8_t {
-    kNone,
-    kThreadStart,
-    kThreadLoopStart,
-    kThreadEnd,
-    kBlockReserved,
-    kBlockComputedCoords,
-    kBlockPackedLhs,
-    kBlockPackedRhs,
-    kBlockFinished
-  };
-  struct Entry {
-    Event event = Event::kNone;
-    std::uint32_t thread_id = 0;
-    std::uint32_t block_id = 0;
-    TimePoint time;
-  };
+namespace {
 
-  BlockMap block_map;
-  int thread_count = 0;
-  TimePoint time_start;
-  TimePoint time_execute;
-  TimePoint time_end;
-  std::vector<Entry> entries;
-  void Add(Event event, std::uint32_t thread_id, std::uint32_t block_id,
-           TimePoint time) {
-    // If the time point is still in its default-constructed state,
-    // that means we didn't record it.
-    if (!time.time_since_epoch().count()) {
-      return;
+// Coalesce Trace::thread_specific_entries into Trace::entries.
+void AggregateThreadSpecificEntries(Trace* trace) {
+  RUY_CHECK(trace->entries.empty());
+  for (auto& thread_specific_entries_vector : trace->thread_specific_entries) {
+    for (const TraceEntry& entry : thread_specific_entries_vector) {
+      trace->entries.push_back(entry);
     }
-    Entry entry;
-    entry.event = event;
-    entry.thread_id = thread_id;
-    entry.block_id = block_id;
-    entry.time = time;
-    entries.push_back(entry);
+    thread_specific_entries_vector.clear();
   }
-  void Process(const Trace& trace) {
-    thread_count = trace.thread_count;
-    block_map = trace.block_map;
-    time_start = trace.time_start;
-    time_execute = trace.time_execute;
-    time_end = trace.time_end;
-    entries.clear();
-    for (int i = 0; i < trace.thread_count; i++) {
-      const auto& entry = trace.thread_entries[i];
-      Add(Event::kThreadStart, i, 0, entry.time_start);
-      Add(Event::kThreadLoopStart, i, 0, entry.time_loop_start);
-      Add(Event::kThreadEnd, i, 0, entry.time_end);
-    }
-    std::uint32_t num_blocks = NumBlocks(block_map);
-    for (int i = 0; i < num_blocks; i++) {
-      const auto& entry = trace.block_entries[i];
-      Add(Event::kBlockReserved, entry.thread_id, i, entry.time_reserved);
-      Add(Event::kBlockComputedCoords, entry.thread_id, i,
-          entry.time_computed_coords);
-      Add(Event::kBlockPackedLhs, entry.thread_id, i,
-          entry.time_packed[Side::kLhs]);
-      Add(Event::kBlockPackedRhs, entry.thread_id, i,
-          entry.time_packed[Side::kRhs]);
-      Add(Event::kBlockFinished, entry.thread_id, i, entry.time_finished);
-    }
-    std::sort(entries.begin(), entries.end(),
-              [](const Entry& a, const Entry& b) -> bool {
-                return a.time < b.time ||
-                       (a.time == b.time &&
-                        static_cast<int>(a.event) < static_cast<int>(b.event));
-              });
-  }
-  void Dump() {
-    const char* trace_filename = getenv("RUY_TRACE_FILE");
-    FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr;
-    if (!trace_file) {
-      fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename,
-              errno);
-      RUY_CHECK(false);
-    }
-    fprintf(trace_file, "thread_count:%d\n", thread_count);
-    fprintf(trace_file, "num_blocks:%d\n", NumBlocks(block_map));
-    fprintf(trace_file, "rows:%d\n", block_map.rows);
-    fprintf(trace_file, "cols:%d\n", block_map.cols);
-    fprintf(trace_file, "Execute: %.9f\n",
-            ToSeconds(time_execute - time_start));
-    for (const Entry& entry : entries) {
-      double time = ToSeconds(entry.time - time_start);
-      switch (entry.event) {
-        case Event::kThreadStart:
-          fprintf(trace_file, "ThreadStart: %.9f, %d\n", time, entry.thread_id);
-          break;
-        case Event::kThreadLoopStart:
-          fprintf(trace_file, "ThreadLoopStart: %.9f, %d\n", time,
-                  entry.thread_id);
-          break;
-        case Event::kThreadEnd:
-          fprintf(trace_file, "ThreadEnd: %.9f, %d\n", time, entry.thread_id);
-          break;
-        case Event::kBlockReserved: {
-          std::uint16_t block_r, block_c;
-          int start_r, start_c, end_r, end_c;
-          GetBlockByIndex(block_map, entry.block_id, &block_r, &block_c);
-          GetBlockMatrixCoords(block_map, block_r, block_c, &start_r, &start_c,
-                               &end_r, &end_c);
-          fprintf(trace_file, "BlockReserved: %.9f, %d, %d, %d, %d, %d, %d\n",
-                  time, entry.thread_id, entry.block_id, start_r, start_c,
-                  end_r, end_c);
-          break;
-        }
-        case Event::kBlockComputedCoords:
-          fprintf(trace_file, "BlockComputedCoords: %.9f, %d, %d\n", time,
-                  entry.thread_id, entry.block_id);
-          break;
-        case Event::kBlockPackedLhs:
-          fprintf(trace_file, "BlockPackedLhs: %.9f, %d, %d\n", time,
-                  entry.thread_id, entry.block_id);
-          break;
-        case Event::kBlockPackedRhs:
-          fprintf(trace_file, "BlockPackedRhs: %.9f, %d, %d\n", time,
-                  entry.thread_id, entry.block_id);
-          break;
-        case Event::kBlockFinished:
-          fprintf(trace_file, "BlockFinished: %.9f, %d, %d\n", time,
-                  entry.thread_id, entry.block_id);
-          break;
-        default:
-          RUY_CHECK(false);
-      }
-    }
-    fprintf(trace_file, "End: %.9f\n", ToSeconds(time_end - time_start));
-    if (trace_filename) {
-      fclose(trace_file);
-    }
-  }
-};
-
-void DumpTrace(const Trace& trace) {
-  ProcessedTrace processed_trace;
-  processed_trace.Process(trace);
-  processed_trace.Dump();
 }
 
+// Sort Trace::entries by ascending time. In case of equal timepoints,
+// sort by some semi-arbitrary ordering of event types.
+void Sort(Trace* trace) {
+  std::sort(std::begin(trace->entries), std::end(trace->entries),
+            [](const TraceEntry& a, const TraceEntry& b) -> bool {
+              return a.time_point < b.time_point ||
+                     (a.time_point == b.time_point &&
+                      static_cast<int>(a.event) < static_cast<int>(b.event));
+            });
+}
+
+// Dump a trace. Assumes that AggregateThreadSpecificEntries and Sort have
+// already been called on it.
+void Dump(const Trace& trace) {
+  const char* trace_filename = getenv("RUY_TRACE_FILE");
+  FILE* trace_file = trace_filename ? fopen(trace_filename, "w") : stderr;
+  if (!trace_file) {
+    fprintf(stderr, "Failed to open %s for write, errno=%d\n", trace_filename,
+            errno);
+    RUY_CHECK(false);
+  }
+  fprintf(trace_file, "thread_count:%d\n", trace.thread_count);
+  fprintf(trace_file, "rows:%d\n", trace.block_map.dims[Side::kLhs]);
+  fprintf(trace_file, "cols:%d\n", trace.block_map.dims[Side::kRhs]);
+  fprintf(trace_file, "Execute: %.9f\n",
+          ToSeconds(trace.time_execute - trace.time_start));
+  for (const TraceEntry& entry : trace.entries) {
+    double time = ToSeconds(entry.time_point - trace.time_start);
+    switch (entry.event) {
+      case TraceEvent::kThreadStart:
+        fprintf(trace_file, "ThreadStart: %.9f, %d\n", time, entry.thread_id);
+        break;
+      case TraceEvent::kThreadLoopStart:
+        fprintf(trace_file, "ThreadLoopStart: %.9f, %d\n", time,
+                entry.thread_id);
+        break;
+      case TraceEvent::kThreadEnd:
+        fprintf(trace_file, "ThreadEnd: %.9f, %d\n", time, entry.thread_id);
+        break;
+      case TraceEvent::kBlockReserved: {
+        std::uint32_t block_id = entry.params[0];
+        SidePair<int> block;
+        GetBlockByIndex(trace.block_map, block_id, &block);
+        SidePair<int> start, end;
+        GetBlockMatrixCoords(trace.block_map, block, &start, &end);
+        fprintf(trace_file,
+                "BlockReserved: %.9f, %d, %d, %d, %d, %d, %d, %d, %d\n", time,
+                entry.thread_id, block_id, block[Side::kLhs], block[Side::kRhs],
+                start[Side::kLhs], start[Side::kRhs], end[Side::kLhs],
+                end[Side::kRhs]);
+        break;
+      }
+      case TraceEvent::kBlockPackedLhs: {
+        std::uint32_t block = entry.params[0];
+        int start, end;
+        GetBlockMatrixCoords(Side::kLhs, trace.block_map, block, &start, &end);
+        fprintf(trace_file, "BlockPackedLhs: %.9f, %d, %d, %d, %d\n", time,
+                entry.thread_id, block, start, end);
+        break;
+      }
+      case TraceEvent::kBlockPackedRhs: {
+        std::uint32_t block = entry.params[0];
+        int start, end;
+        GetBlockMatrixCoords(Side::kRhs, trace.block_map, block, &start, &end);
+        fprintf(trace_file, "BlockPackedRhs: %.9f, %d, %d, %d, %d\n", time,
+                entry.thread_id, block, start, end);
+        break;
+      }
+      case TraceEvent::kBlockFinished: {
+        std::uint32_t block_id = entry.params[0];
+        SidePair<int> block;
+        GetBlockByIndex(trace.block_map, block_id, &block);
+        fprintf(trace_file, "BlockFinished: %.9f, %d, %d, %d, %d\n", time,
+                entry.thread_id, block_id, block[Side::kLhs],
+                block[Side::kRhs]);
+        break;
+      }
+      default:
+        RUY_CHECK(false);
+    }
+  }
+  fprintf(trace_file, "End: %.9f\n",
+          ToSeconds(trace.time_end - trace.time_start));
+  if (trace_filename) {
+    fclose(trace_file);
+  }
+}
+
+}  // anonymous namespace
+
+// Get a Trace object to record to, or null if tracing is not enabled.
 Trace* NewTraceOrNull(TracingContext* tracing, int rows, int depth, int cols) {
   if (!tracing->initialized) {
     tracing->initialized = true;
@@ -254,122 +207,114 @@
   return tracing->trace;
 }
 
+// The trace recorded on a context is finalized and dumped by
+// this TracingContext destructor.
+//
+// The idea of dumping on context destructor is that typically one wants to
+// run many matrix multiplications, e.g. to hit a steady state in terms of
+// performance characteristics, but only trace the last repetition of the
+// workload, when that steady state was attained.
 TracingContext::~TracingContext() {
   if (trace) {
-    DumpTrace(*trace);
+    AggregateThreadSpecificEntries(trace);
+    Sort(trace);
+    Dump(*trace);
   }
   delete trace;
 }
 
+void TraceRecordStart(Trace* trace) {
+  if (trace) {
+    trace->time_start = Clock::now();
+  }
+}
+
+void TraceRecordExecute(const BlockMap& block_map, int thread_count,
+                        Trace* trace) {
+  if (trace) {
+    trace->time_execute = Clock::now();
+    trace->block_map = block_map;
+    trace->thread_count = thread_count;
+    trace->thread_specific_entries.resize(thread_count);
+    for (int thread = 0; thread < thread_count; thread++) {
+      trace->thread_specific_entries[thread].clear();
+      // Reserve some large size to avoid frequent heap allocations
+      // affecting the recorded timings.
+      trace->thread_specific_entries[thread].reserve(16384);
+    }
+  }
+}
+
+void TraceRecordEnd(Trace* trace) {
+  if (trace) {
+    trace->time_end = Clock::now();
+  }
+}
+
 void TraceRecordThreadStart(std::uint32_t thread_id, Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    relaxed_atomic_store(&trace->block_entries[thread_id].thread_id, thread_id);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->block_entries[thread_id].time_reserved, now);
-    relaxed_atomic_store(&trace->thread_entries[thread_id].time_start, now);
+    TraceEntry entry;
+    entry.event = TraceEvent::kThreadStart;
+    entry.time_point = Clock::now();
+    entry.thread_id = thread_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
 void TraceRecordThreadLoopStart(std::uint32_t thread_id, Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->thread_entries[thread_id].time_loop_start,
-                         now);
+    TraceEntry entry;
+    entry.event = TraceEvent::kThreadLoopStart;
+    entry.time_point = Clock::now();
+    entry.thread_id = thread_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
 void TraceRecordBlockReserved(std::uint32_t thread_id, std::uint32_t block_id,
                               Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    // This is typically called on the next block id just obtained by atomic
-    // increment; this may be out of range.
-    if (block_id < trace->block_entries.size()) {
-      relaxed_atomic_store(&trace->block_entries[block_id].thread_id,
-                           thread_id);
-      TimePoint now = Clock::now();
-      relaxed_atomic_store(&trace->block_entries[block_id].time_reserved, now);
-    }
+    TraceEntry entry;
+    entry.event = TraceEvent::kBlockReserved;
+    entry.time_point = Clock::now();
+    entry.thread_id = thread_id;
+    entry.params[0] = block_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
-void TraceRecordBlockCoordsComputed(std::uint32_t block_id, Trace* trace) {
+void TraceRecordBlockPacked(std::uint32_t thread_id, Side side, int block,
+                            Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->block_entries[block_id].time_computed_coords,
-                         now);
+    TraceEntry entry;
+    entry.event = side == Side::kLhs ? TraceEvent::kBlockPackedLhs
+                                     : TraceEvent::kBlockPackedRhs;
+    entry.time_point = Clock::now();
+    entry.thread_id = thread_id;
+    entry.params[0] = block;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
-void TraceRecordBlockPacked(Side side, std::uint32_t block_id, Trace* trace) {
+void TraceRecordBlockFinished(std::uint32_t thread_id, std::uint32_t block_id,
+                              Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->block_entries[block_id].time_packed[side],
-                         now);
-  }
-}
-
-void TraceRecordBlockFinished(std::uint32_t block_id, Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->block_entries[block_id].time_finished, now);
+    TraceEntry entry;
+    entry.event = TraceEvent::kBlockFinished;
+    entry.time_point = Clock::now();
+    entry.thread_id = thread_id;
+    entry.params[0] = block_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }
 
 void TraceRecordThreadEnd(std::uint32_t thread_id, Trace* trace) {
   if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->thread_entries[thread_id].time_end, now);
-  }
-}
-
-void TraceRecordStart(Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage == Trace::LifeStage::kInitial ||
-               trace->life_stage == Trace::LifeStage::kComplete);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->time_start, now);
-    trace->life_stage = Trace::LifeStage::kRecordingRootFields;
-  }
-}
-
-void TraceRecordExecute(Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage == Trace::LifeStage::kRecordingRootFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->time_execute, now);
-  }
-}
-
-void TraceRecordEnd(Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage ==
-               Trace::LifeStage::kRecordingBlockAndThreadFields);
-    TimePoint now = Clock::now();
-    relaxed_atomic_store(&trace->time_end, now);
-    trace->life_stage = Trace::LifeStage::kComplete;
-  }
-}
-
-void TraceStartRecordingBlockAndThreadFields(const BlockMap& block_map,
-                                             int thread_count, Trace* trace) {
-  if (trace) {
-    RUY_DCHECK(trace->life_stage == Trace::LifeStage::kRecordingRootFields);
-    trace->StartRecordingBlockAndThreadFields(block_map, thread_count);
-    trace->life_stage = Trace::LifeStage::kRecordingBlockAndThreadFields;
+    TraceEntry entry;
+    entry.event = TraceEvent::kThreadEnd;
+    entry.time_point = Clock::now();
+    entry.thread_id = thread_id;
+    trace->thread_specific_entries[thread_id].push_back(entry);
   }
 }