[llvm-profdata] Add overlap command to compute similarity between two profile files
Add overlap functionality to llvm-profdata tool to compute the similarity
between two profile files.
Differential Revision: https://reviews.llvm.org/D60977
llvm-svn: 359612
diff --git a/llvm/lib/ProfileData/InstrProf.cpp b/llvm/lib/ProfileData/InstrProf.cpp
index 91a3d17..560e390 100644
--- a/llvm/lib/ProfileData/InstrProf.cpp
+++ b/llvm/lib/ProfileData/InstrProf.cpp
@@ -29,6 +29,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
+#include "llvm/ProfileData/InstrProfReader.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Compiler.h"
@@ -478,6 +479,127 @@
   return Error::success();
 }
 
+void InstrProfRecord::accumuateCounts(CountSumOrPercent &Sum) const {
+  // Edge counters: add their number and their total to the running sums.
+  Sum.NumEntries += Counts.size();
+  uint64_t EdgeTotal = 0;
+  for (uint64_t EdgeCount : Counts)
+    EdgeTotal += EdgeCount;
+  Sum.CountSum += EdgeTotal;
+  // Value profiles: total the counts of every site, bucketed by value kind.
+  for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) {
+    uint64_t KindTotal = 0;
+    const uint32_t SiteCount = getNumValueSites(Kind);
+    for (size_t Site = 0; Site != SiteCount; ++Site) {
+      std::unique_ptr<InstrProfValueData[]> Data = getValueForSite(Kind, Site);
+      for (uint32_t V = 0, N = getNumValueDataForSite(Kind, Site); V != N; ++V)
+        KindTotal += Data[V].Count;
+    }
+    Sum.ValueCounts[Kind] += KindTotal;
+  }
+}
+
+void InstrProfValueSiteRecord::overlap(InstrProfValueSiteRecord &Input,
+                                       uint32_t ValueKind,
+                                       OverlapStats &Overlap,
+                                       OverlapStats &FuncLevelOverlap) {
+  // Sort both sides by target value so they can be walked in lock step.
+  this->sortByTargetValues();
+  Input.sortByTargetValues();
+  double ProgScore = 0.0, FuncScore = 0.0;
+  auto Lhs = ValueData.begin(), LhsEnd = ValueData.end();
+  auto Rhs = Input.ValueData.begin(), RhsEnd = Input.ValueData.end();
+  while (Lhs != LhsEnd && Rhs != RhsEnd) {
+    if (Lhs->Value < Rhs->Value) {
+      ++Lhs; // Smaller value: advance this side only.
+      continue;
+    }
+    if (Lhs->Value == Rhs->Value) {
+      ProgScore += OverlapStats::score(Lhs->Count, Rhs->Count,
+                                       Overlap.Base.ValueCounts[ValueKind],
+                                       Overlap.Test.ValueCounts[ValueKind]);
+      FuncScore += OverlapStats::score(
+          Lhs->Count, Rhs->Count, FuncLevelOverlap.Base.ValueCounts[ValueKind],
+          FuncLevelOverlap.Test.ValueCounts[ValueKind]);
+      ++Lhs;
+    }
+    ++Rhs;
+  }
+  Overlap.Overlap.ValueCounts[ValueKind] += ProgScore;
+  FuncLevelOverlap.Overlap.ValueCounts[ValueKind] += FuncScore;
+}
+
+// Accumulate the overlap of each ValueKind site of this record with the site
+// at the same index in Other. Both records must have the same number of sites.
+void InstrProfRecord::overlapValueProfData(uint32_t ValueKind,
+                                           InstrProfRecord &Other,
+                                           OverlapStats &Overlap,
+                                           OverlapStats &FuncLevelOverlap) {
+  uint32_t ThisNumValueSites = getNumValueSites(ValueKind);
+  assert(ThisNumValueSites == Other.getNumValueSites(ValueKind));
+  if (!ThisNumValueSites)
+    return;
+
+  std::vector<InstrProfValueSiteRecord> &ThisSiteRecords =
+      getOrCreateValueSitesForKind(ValueKind);
+  MutableArrayRef<InstrProfValueSiteRecord> OtherSiteRecords =
+      Other.getValueSitesForKind(ValueKind);
+  for (uint32_t I = 0; I < ThisNumValueSites; I++)
+    ThisSiteRecords[I].overlap(OtherSiteRecords[I], ValueKind, Overlap,
+                               FuncLevelOverlap);
+}
+
+void InstrProfRecord::overlap(InstrProfRecord &Other, OverlapStats &Overlap,
+                              OverlapStats &FuncLevelOverlap,
+                              uint64_t ValueCutoff) {
+  // Accumulate into both stats the overlap between this record and Other.
+  assert(FuncLevelOverlap.Test.CountSum >= 1.0f); // Other's precomputed sum.
+  accumuateCounts(FuncLevelOverlap.Base);
+  bool Mismatch = (Counts.size() != Other.Counts.size());
+  // A different number of edge counters already makes the records a mismatch.
+  // Otherwise check whether any value kind has a different number of sites.
+  if (!Mismatch) {
+    for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind) {
+      uint32_t ThisNumValueSites = getNumValueSites(Kind);
+      uint32_t OtherNumValueSites = Other.getNumValueSites(Kind);
+      if (ThisNumValueSites != OtherNumValueSites) {
+        Mismatch = true;
+        break;
+      }
+    }
+  }
+  if (Mismatch) {
+    Overlap.addOneMismatch(FuncLevelOverlap.Test);
+    return;
+  }
+
+  // Compute overlap for value counts, one value kind at a time.
+  for (uint32_t Kind = IPVK_First; Kind <= IPVK_Last; ++Kind)
+    overlapValueProfData(Kind, Other, Overlap, FuncLevelOverlap);
+
+  double Score = 0.0;
+  uint64_t MaxCount = 0;
+  // Compute overlap for edge counts; MaxCount tracks Other's largest count.
+  for (size_t I = 0, E = Other.Counts.size(); I < E; ++I) {
+    Score += OverlapStats::score(Counts[I], Other.Counts[I],
+                                 Overlap.Base.CountSum, Overlap.Test.CountSum);
+    MaxCount = std::max(Other.Counts[I], MaxCount);
+  }
+  Overlap.Overlap.CountSum += Score;
+  Overlap.Overlap.NumEntries += 1;
+  // Function-level results are produced only if MaxCount reaches ValueCutoff.
+  if (MaxCount >= ValueCutoff) {
+    double FuncScore = 0.0;
+    for (size_t I = 0, E = Other.Counts.size(); I < E; ++I)
+      FuncScore += OverlapStats::score(Counts[I], Other.Counts[I],
+                                       FuncLevelOverlap.Base.CountSum,
+                                       FuncLevelOverlap.Test.CountSum);
+    FuncLevelOverlap.Overlap.CountSum = FuncScore;
+    FuncLevelOverlap.Overlap.NumEntries = Other.Counts.size();
+    FuncLevelOverlap.Valid = true;
+  }
+}
+
 void InstrProfValueSiteRecord::merge(InstrProfValueSiteRecord &Input,
                                      uint64_t Weight,
                                      function_ref<void(instrprof_error)> Warn) {
@@ -1046,4 +1168,117 @@
   }
 }
 
+Error OverlapStats::accumuateCounts(const std::string &BaseFilename,
+                                    const std::string &TestFilename,
+                                    bool IsCS) {
+  // Create a reader for Filename and accumulate its counts into Sum.
+  auto getProfileSum = [IsCS](const std::string &Filename,
+                              CountSumOrPercent &Sum) -> Error {
+    auto ReaderOrErr = InstrProfReader::create(Filename);
+    if (Error E = ReaderOrErr.takeError())
+      return E;
+    auto Reader = std::move(ReaderOrErr.get());
+    Reader->accumuateCounts(Sum, IsCS);
+    return Error::success();
+  };
+  // Bail out on the first profile that cannot be read, leaving Valid untouched.
+  if (Error E = getProfileSum(BaseFilename, Base))
+    return E;
+  if (Error E = getProfileSum(TestFilename, Test))
+    return E;
+  // Only the addresses are stored, so both strings must outlive their use here.
+  this->BaseFilename = &BaseFilename;
+  this->TestFilename = &TestFilename;
+  Valid = true;
+  return Error::success();
+}
+
+void OverlapStats::addOneMismatch(const CountSumOrPercent &MismatchFunc) {
+  // One more mismatched function; add its share of each test-profile total.
+  ++Mismatch.NumEntries;
+  Mismatch.CountSum += MismatchFunc.CountSum / Test.CountSum;
+  for (unsigned K = 0; K < IPVK_Last - IPVK_First + 1; ++K)
+    if (Test.ValueCounts[K] >= 1.0f) // Skip kinds whose test sum is < 1.
+      Mismatch.ValueCounts[K] +=
+          MismatchFunc.ValueCounts[K] / Test.ValueCounts[K];
+}
+
+void OverlapStats::addOneUnique(const CountSumOrPercent &UniqueFunc) {
+  // One more function found only in the test profile; add its share of totals.
+  ++Unique.NumEntries;
+  Unique.CountSum += UniqueFunc.CountSum / Test.CountSum;
+  for (unsigned K = 0; K < IPVK_Last - IPVK_First + 1; ++K)
+    if (Test.ValueCounts[K] >= 1.0f) // Skip kinds whose test sum is < 1.
+      Unique.ValueCounts[K] += UniqueFunc.ValueCounts[K] / Test.ValueCounts[K];
+}
+
+void OverlapStats::dump(raw_fd_ostream &OS) const {
+  // Print this overlap report to OS; nothing is printed unless Valid is set.
+  if (!Valid)
+    return;
+
+  const char *EntryName = Level == ProgramLevel ? "functions" : "edge counters";
+  if (Level == ProgramLevel) {
+    OS << "Profile overlap infomation for base_profile: " << *BaseFilename
+       << " and test_profile: " << *TestFilename << "\nProgram level:\n";
+  } else {
+    OS << "Function level:\n"
+       << "  Function: " << FuncName << " (Hash=" << FuncHash << ")\n";
+  }
+  // Entry tallies; the mismatch and test-only lines appear only when nonzero.
+  OS << "  # of " << EntryName << " overlap: " << Overlap.NumEntries << "\n";
+  if (Mismatch.NumEntries)
+    OS << "  # of " << EntryName << " mismatch: " << Mismatch.NumEntries
+       << "\n";
+  if (Unique.NumEntries)
+    OS << "  # of " << EntryName
+       << " only in test_profile: " << Unique.NumEntries << "\n";
+
+  OS << "  Edge profile overlap: " << format("%.3f%%", Overlap.CountSum * 100)
+     << "\n";
+  if (Mismatch.NumEntries)
+    OS << "  Mismatched count percentage (Edge): "
+       << format("%.3f%%", Mismatch.CountSum * 100) << "\n";
+  if (Unique.NumEntries)
+    OS << "  Percentage of Edge profile only in test_profile: "
+       << format("%.3f%%", Unique.CountSum * 100) << "\n";
+  OS << "  Edge profile base count sum: " << format("%.0f", Base.CountSum)
+     << "\n"
+     << "  Edge profile test count sum: " << format("%.0f", Test.CountSum)
+     << "\n";
+  // One section per value kind with a count sum of at least 1 in either profile.
+  for (unsigned I = 0; I < IPVK_Last - IPVK_First + 1; I++) {
+    if (Base.ValueCounts[I] < 1.0f && Test.ValueCounts[I] < 1.0f)
+      continue;
+    char ProfileKindName[20];
+    switch (I) {
+    case IPVK_IndirectCallTarget:
+      strncpy(ProfileKindName, "IndirectCall", 19);
+      break;
+    case IPVK_MemOPSize:
+      strncpy(ProfileKindName, "MemOP", 19);
+      break;
+    default:
+      snprintf(ProfileKindName, 19, "VP[%u]", I);
+      break;
+    }
+    OS << "  " << ProfileKindName
+       << " profile overlap: " << format("%.3f%%", Overlap.ValueCounts[I] * 100)
+       << "\n";
+    if (Mismatch.NumEntries)
+      OS << "  Mismatched count percentage (" << ProfileKindName
+         << "): " << format("%.3f%%", Mismatch.ValueCounts[I] * 100) << "\n";
+    if (Unique.NumEntries)
+      OS << "  Percentage of " << ProfileKindName
+         << " profile only in test_profile: "
+         << format("%.3f%%", Unique.ValueCounts[I] * 100) << "\n";
+    OS << "  " << ProfileKindName
+       << " profile base count sum: " << format("%.0f", Base.ValueCounts[I])
+       << "\n"
+       << "  " << ProfileKindName
+       << " profile test count sum: " << format("%.0f", Test.ValueCounts[I])
+       << "\n";
+  }
+}
+
 } // end namespace llvm
diff --git a/llvm/lib/ProfileData/InstrProfReader.cpp b/llvm/lib/ProfileData/InstrProfReader.cpp
index 9c2d364..074319b 100644
--- a/llvm/lib/ProfileData/InstrProfReader.cpp
+++ b/llvm/lib/ProfileData/InstrProfReader.cpp
@@ -900,3 +900,17 @@
   }
   return success();
 }
+
+void InstrProfReader::accumuateCounts(CountSumOrPercent &Sum, bool IsCS) {
+  // Sum the counts of all selected functions; NumEntries becomes their number.
+  uint64_t Accepted = 0;
+  for (const auto &Record : *this) {
+    // For IR-level profiles, keep only records whose CS flag matches IsCS.
+    if (isIRLevelProfile() &&
+        NamedInstrProfRecord::hasCSFlagInHash(Record.Hash) != IsCS)
+      continue;
+    Record.accumuateCounts(Sum);
+    ++Accepted;
+  }
+  Sum.NumEntries = Accepted;
+}
diff --git a/llvm/lib/ProfileData/InstrProfWriter.cpp b/llvm/lib/ProfileData/InstrProfWriter.cpp
index b9a0060..4ca2def 100644
--- a/llvm/lib/ProfileData/InstrProfWriter.cpp
+++ b/llvm/lib/ProfileData/InstrProfWriter.cpp
@@ -187,6 +187,40 @@
   addRecord(Name, Hash, std::move(I), Weight, Warn);
 }
 
+void InstrProfWriter::overlapRecord(NamedInstrProfRecord &&Other,
+                                    OverlapStats &Overlap,
+                                    OverlapStats &FuncLevelOverlap,
+                                    const OverlapFuncFilters &FuncFilter) {
+  // Compare Other with the writer's record stored under the same name and
+  // hash, updating the program-level and function-level overlap stats.
+  auto Name = Other.Name;
+  auto Hash = Other.Hash;
+  Other.accumuateCounts(FuncLevelOverlap.Test);
+  auto FuncIt = FunctionData.find(Name);
+  if (FuncIt == FunctionData.end()) {
+    Overlap.addOneUnique(FuncLevelOverlap.Test);
+    return;
+  }
+  if (FuncLevelOverlap.Test.CountSum < 1.0f) {
+    Overlap.Overlap.NumEntries += 1;
+    return;
+  }
+  // Look the hash up without inserting, so comparing never alters the stored
+  // profile; a known name with an unknown hash is counted as a mismatch.
+  auto &ProfileDataMap = FuncIt->second;
+  ProfilingData::iterator Where = ProfileDataMap.find(Hash);
+  if (Where == ProfileDataMap.end()) {
+    Overlap.addOneMismatch(FuncLevelOverlap.Test);
+    return;
+  }
+  // A name-filter match drops the cutoff to 0 so that any count passes it.
+  uint64_t ValueCutoff = FuncFilter.ValueCutoff;
+  if (!FuncFilter.NameFilter.empty() &&
+      Name.find(FuncFilter.NameFilter) != Name.npos)
+    ValueCutoff = 0;
+  Where->second.overlap(Other, Overlap, FuncLevelOverlap, ValueCutoff);
+}
+
 void InstrProfWriter::addRecord(StringRef Name, uint64_t Hash,
                                 InstrProfRecord &&I, uint64_t Weight,
                                 function_ref<void(Error)> Warn) {