[libFuzzer] properly intercept memmem

llvm-svn: 276006
diff --git a/llvm/lib/Fuzzer/FuzzerTraceState.cpp b/llvm/lib/Fuzzer/FuzzerTraceState.cpp
index d6e1f79..6f87fca 100644
--- a/llvm/lib/Fuzzer/FuzzerTraceState.cpp
+++ b/llvm/lib/Fuzzer/FuzzerTraceState.cpp
@@ -173,6 +173,12 @@
 static bool RecordingTraces = false;
 static bool RecordingMemcmp = false;
 static bool RecordingMemmem = false;
+static bool DoingMyOwnMemmem = false;  // When set, the memmem hook bails out.
+
+struct ScopedDoingMyOwnMemmem {  // Flag is true for the object's lifetime.
+  ScopedDoingMyOwnMemmem() { DoingMyOwnMemmem = true; }
+  ~ScopedDoingMyOwnMemmem() { DoingMyOwnMemmem = false; }  // Not nest-safe.
+};
 
 class TraceState {
 public:
@@ -400,6 +406,7 @@
 int TraceState::TryToAddDesiredData(uint64_t PresentData, uint64_t DesiredData,
                                     size_t DataSize) {
   if (NumMutations >= kMaxMutations || !WantToHandleOneMoreMutation()) return 0;
+  ScopedDoingMyOwnMemmem scoped_doing_my_own_memmem;
   const uint8_t *UnitData;
   auto UnitSize = F->GetCurrentUnitInFuzzingThead(&UnitData);
   int Res = 0;
@@ -423,6 +430,7 @@
                                     const uint8_t *DesiredData,
                                     size_t DataSize) {
   if (NumMutations >= kMaxMutations || !WantToHandleOneMoreMutation()) return 0;
+  ScopedDoingMyOwnMemmem scoped_doing_my_own_memmem;
   const uint8_t *UnitData;
   auto UnitSize = F->GetCurrentUnitInFuzzingThead(&UnitData);
   int Res = 0;
@@ -639,7 +647,8 @@
 }
 void __sanitizer_weak_hook_memmem(void *called_pc, const void *s1, size_t len1,
                                   const void *s2, size_t len2, void *result) {
-  // TODO: can't hook memmem since memmem is used by libFuzzer.
+  if (fuzzer::DoingMyOwnMemmem) return;  // Set during TryToAddDesiredData.
+  TS->AddInterestingWord(reinterpret_cast<const uint8_t *>(s2), len2);
 }
 
 #endif  // LLVM_FUZZER_DEFINES_SANITIZER_WEAK_HOOOKS
diff --git a/llvm/lib/Fuzzer/test/StrstrTest.cpp b/llvm/lib/Fuzzer/test/StrstrTest.cpp
index 90d539b6..dd83953 100644
--- a/llvm/lib/Fuzzer/test/StrstrTest.cpp
+++ b/llvm/lib/Fuzzer/test/StrstrTest.cpp
@@ -9,8 +9,12 @@
 #include <cstdlib>
 
 extern "C" int LLVMFuzzerTestOneInput(const uint8_t *Data, size_t Size) {
+  if (Size < 4) return 0;
   std::string s(reinterpret_cast<const char*>(Data), Size);
-  if (strstr(s.c_str(), "FUZZ") && strcasestr(s.c_str(), "aBcD")) {
+  if (strstr(s.c_str(), "FUZZ") &&
+      strcasestr(s.c_str(), "aBcD") &&
+      memmem(s.data(), s.size(), "kuku", 4)
+      ) {
     fprintf(stderr, "BINGO\n");
     exit(1);
   }