Change ClassFileLoadHook to lazily compute dex file

Creating a dex file from the quickened or compact-dex'd internal
format for calling the JVMTI ClassFileLoadHook is quite expensive.
This meant that agents could not generally listen for this event
without causing unacceptable performance problems.

Since agents will generally not touch the buffer, needing to
instrument only a handful of classes, we fix this problem by doing
the de-quickening lazily. This is done by mmaping an empty buffer with
PROT_NONE and then filling it in when the program has a SEGV in the
appropriate address range. This should improve the performance of any
agent that listens for the ClassFileLoadHook but does not commonly do
anything to the buffer.

This does have the disadvantage that the buffer size we pass down in
class_data_len might no longer be fully accurate. Some agents that
assert that class_data_len is exactly the same as the dex-file size
will need to be updated.

Bug: 72402467
Bug: 72064989
Test: ./test.py --host -j50
Test: ./test.py --host --redefine-stress -j50

Change-Id: I39354837f1417ae10a57c5b42cbb4f38f8a563dc
diff --git a/openjdkjvmti/OpenjdkJvmTi.cc b/openjdkjvmti/OpenjdkJvmTi.cc
index 027635b..a0c7f40 100644
--- a/openjdkjvmti/OpenjdkJvmTi.cc
+++ b/openjdkjvmti/OpenjdkJvmTi.cc
@@ -1535,6 +1535,7 @@
   MethodUtil::Register(&gEventHandler);
   SearchUtil::Register();
   HeapUtil::Register();
+  Transformer::Setup();
 
   {
     // Make sure we can deopt anything we need to.
diff --git a/openjdkjvmti/ti_class_definition.cc b/openjdkjvmti/ti_class_definition.cc
index 7f2f800..1b641cd 100644
--- a/openjdkjvmti/ti_class_definition.cc
+++ b/openjdkjvmti/ti_class_definition.cc
@@ -45,6 +45,31 @@
 
 namespace openjdkjvmti {
 
+void ArtClassDefinition::InitializeMemory() const {
+  DCHECK(art::MemMap::kCanReplaceMapping);
+  VLOG(signals) << "Initializing de-quickened memory for dex file of " << name_;
+  CHECK(dex_data_mmap_ != nullptr);
+  CHECK(temp_mmap_ != nullptr);
+  CHECK_EQ(dex_data_mmap_->GetProtect(), PROT_NONE);
+  CHECK_EQ(temp_mmap_->GetProtect(), PROT_READ | PROT_WRITE);
+
+  std::string desc = std::string("L") + name_ + ";";
+  std::unique_ptr<FixedUpDexFile>
+      fixed_dex_file(FixedUpDexFile::Create(*initial_dex_file_unquickened_, desc.c_str()));
+  CHECK(fixed_dex_file.get() != nullptr);
+  CHECK_LE(fixed_dex_file->Size(), temp_mmap_->Size());
+  CHECK_EQ(temp_mmap_->Size(), dex_data_mmap_->Size());
+  // Copy the data to the temp mmap.
+  memcpy(temp_mmap_->Begin(), fixed_dex_file->Begin(), fixed_dex_file->Size());
+
+  // Move the mmap atomically.
+  art::MemMap* source = temp_mmap_.release();
+  std::string error;
+  CHECK(dex_data_mmap_->ReplaceWith(&source, &error)) << "Failed to replace mmap for "
+                                                      << name_ << " because " << error;
+  CHECK(dex_data_mmap_->Protect(PROT_READ));
+}
+
 bool ArtClassDefinition::IsModified() const {
   // RedefineClasses calls always are 'modified' since they need to change the current_dex_file of
   // the class.
@@ -58,6 +83,27 @@
     return false;
   }
 
+  // The dex_data_ was never touched by the agents.
+  if (dex_data_mmap_ != nullptr && dex_data_mmap_->GetProtect() == PROT_NONE) {
+    if (current_dex_file_.data() == dex_data_mmap_->Begin()) {
+      // the dex_data_ looks like it changed (not equal to current_dex_file_) but we never
+      // initialized the dex_data_mmap_. This means the new_dex_data was filled in without looking
+      // at the initial dex_data_.
+      return true;
+    } else if (dex_data_.data() == dex_data_mmap_->Begin()) {
+      // The dex file used to have modifications but they were not added again.
+      return true;
+    } else {
+      // It's not clear what happened. It's possible that the agent got the current dex file data
+      // from some other source so we need to initialize everything to see if it is the same.
+      VLOG(signals) << "Lazy dex file for " << name_ << " was never touched but the dex_data_ is "
+                    << "changed! Need to initialize the memory to see if anything changed";
+      InitializeMemory();
+    }
+  }
+
+  // We can definitely read current_dex_file_ and dex_file_ without causing page faults.
+
   // Check if the dex file we want to set is the same as the current one.
   // Unfortunately we need to do this check even if no modifications have been done since it could
   // be that agents were removed in the mean-time so we still have a different dex file. The dex
@@ -194,6 +240,53 @@
                                      const art::DexFile* quick_dex) {
   art::Thread* self = art::Thread::Current();
   DCHECK(quick_dex != nullptr);
+  if (art::MemMap::kCanReplaceMapping && kEnableOnDemandDexDequicken) {
+    size_t dequick_size = quick_dex->GetDequickenedSize();
+    std::string mmap_name("anon-mmap-for-redefine: ");
+    mmap_name += name_;
+    std::string error;
+    dex_data_mmap_.reset(art::MemMap::MapAnonymous(mmap_name.c_str(),
+                                                   nullptr,
+                                                   dequick_size,
+                                                   PROT_NONE,
+                                                   /*low_4gb*/ false,
+                                                   /*reuse*/ false,
+                                                   &error));
+    mmap_name += "-TEMP";
+    temp_mmap_.reset(art::MemMap::MapAnonymous(mmap_name.c_str(),
+                                               nullptr,
+                                               dequick_size,
+                                               PROT_READ | PROT_WRITE,
+                                               /*low_4gb*/ false,
+                                               /*reuse*/ false,
+                                               &error));
+    if (LIKELY(dex_data_mmap_ != nullptr && temp_mmap_ != nullptr)) {
+      // Need to save the initial dexfile so we don't need to search for it in the fault-handler.
+      initial_dex_file_unquickened_ = quick_dex;
+      dex_data_ = art::ArrayRef<const unsigned char>(dex_data_mmap_->Begin(),
+                                                     dex_data_mmap_->Size());
+      if (from_class_ext_) {
+        // We got initial from class_ext so the current one must have undergone redefinition so no
+        // cdex or quickening stuff.
+        // We can only do this if it's not a first load.
+        DCHECK(klass_ != nullptr);
+        const art::DexFile& cur_dex = self->DecodeJObject(klass_)->AsClass()->GetDexFile();
+        current_dex_file_ = art::ArrayRef<const unsigned char>(cur_dex.Begin(), cur_dex.Size());
+      } else {
+        // This class hasn't been redefined before. The dequickened current data is the same as the
+        // dex_data_mmap_ when it's filled in. We don't need to copy anything because the mmap will
+        // not be cleared until after everything is done.
+        current_dex_file_ = art::ArrayRef<const unsigned char>(dex_data_mmap_->Begin(),
+                                                               dequick_size);
+      }
+      return;
+    }
+  }
+  dex_data_mmap_.reset(nullptr);
+  temp_mmap_.reset(nullptr);
+  // Failed to mmap a large enough area (or on-demand dequickening was disabled). This is
+  // unfortunate. Since currently the size is just a guess though we might as well try to do it
+  // manually.
   get_original(/*out*/&dex_data_memory_);
   dex_data_ = art::ArrayRef<const unsigned char>(dex_data_memory_);
   if (from_class_ext_) {
diff --git a/openjdkjvmti/ti_class_definition.h b/openjdkjvmti/ti_class_definition.h
index 0084739..31c3611 100644
--- a/openjdkjvmti/ti_class_definition.h
+++ b/openjdkjvmti/ti_class_definition.h
@@ -32,9 +32,14 @@
 #ifndef ART_OPENJDKJVMTI_TI_CLASS_DEFINITION_H_
 #define ART_OPENJDKJVMTI_TI_CLASS_DEFINITION_H_
 
+#include <stddef.h>
+#include <sys/mman.h>
+#include <sys/types.h>
+
 #include "art_jvmti.h"
 
 #include "base/array_ref.h"
+#include "mem_map.h"
 
 namespace openjdkjvmti {
 
@@ -43,13 +48,20 @@
 // redefinition/retransformation function that created it.
 class ArtClassDefinition {
  public:
+  // If we support doing a on-demand dex-dequickening using signal handlers.
+  static constexpr bool kEnableOnDemandDexDequicken = true;
+
   ArtClassDefinition()
       : klass_(nullptr),
         loader_(nullptr),
         name_(),
         protection_domain_(nullptr),
+        dex_data_mmap_(nullptr),
+        temp_mmap_(nullptr),
         dex_data_memory_(),
+        initial_dex_file_unquickened_(nullptr),
         dex_data_(),
+        current_dex_memory_(),
         current_dex_file_(),
         redefined_(false),
         from_class_ext_(false),
@@ -87,6 +99,12 @@
     }
   }
 
+  bool ContainsAddress(uintptr_t ptr) const {
+    return dex_data_mmap_ != nullptr &&
+        reinterpret_cast<uintptr_t>(dex_data_mmap_->Begin()) <= ptr &&
+        reinterpret_cast<uintptr_t>(dex_data_mmap_->End()) > ptr;
+  }
+
   bool IsModified() const REQUIRES_SHARED(art::Locks::mutator_lock_);
 
   bool IsInitialized() const {
@@ -108,6 +126,13 @@
     return name_;
   }
 
+  bool IsLazyDefinition() const {
+    DCHECK(IsInitialized());
+    return dex_data_mmap_ != nullptr &&
+        dex_data_.data() == dex_data_mmap_->Begin() &&
+        dex_data_mmap_->GetProtect() == PROT_NONE;
+  }
+
   jobject GetProtectionDomain() const {
     DCHECK(IsInitialized());
     return protection_domain_;
@@ -118,6 +143,8 @@
     return dex_data_;
   }
 
+  void InitializeMemory() const;
+
  private:
   jvmtiError InitCommon(art::Thread* self, jclass klass);
 
@@ -130,9 +157,17 @@
   std::string name_;
   jobject protection_domain_;
 
+  // Mmap that will be filled with the original-dex-file lazily if it needs to be de-quickened or
+  // de-compact-dex'd
+  mutable std::unique_ptr<art::MemMap> dex_data_mmap_;
+  // This is a temporary mmap we will use to be able to fill the dex file data atomically.
+  mutable std::unique_ptr<art::MemMap> temp_mmap_;
+
   // A unique_ptr to the current dex_data if it needs to be cleaned up.
   std::vector<unsigned char> dex_data_memory_;
 
+  const art::DexFile* initial_dex_file_unquickened_;
+
   // A ref to the current dex data. This is either dex_data_memory_, or current_dex_file_. This is
   // what the dex file will be turned into.
   art::ArrayRef<const unsigned char> dex_data_;
diff --git a/openjdkjvmti/transform.cc b/openjdkjvmti/transform.cc
index 8445eca..d98b385 100644
--- a/openjdkjvmti/transform.cc
+++ b/openjdkjvmti/transform.cc
@@ -29,6 +29,9 @@
  * questions.
  */
 
+#include <stddef.h>
+#include <sys/types.h>
+
 #include <unordered_map>
 #include <unordered_set>
 
@@ -40,6 +43,7 @@
 #include "dex/dex_file.h"
 #include "dex/dex_file_types.h"
 #include "events-inl.h"
+#include "fault_handler.h"
 #include "gc_root-inl.h"
 #include "globals.h"
 #include "jni_env_ext-inl.h"
@@ -63,6 +67,174 @@
 
 namespace openjdkjvmti {
 
+// A FaultHandler that will deal with initializing ClassDefinitions when they are actually needed.
+class TransformationFaultHandler FINAL : public art::FaultHandler {
+ public:
+  explicit TransformationFaultHandler(art::FaultManager* manager)
+      : art::FaultHandler(manager),
+        uninitialized_class_definitions_lock_("JVMTI Initialized class definitions lock",
+                                              art::LockLevel::kSignalHandlingLock),
+        class_definition_initialized_cond_("JVMTI Initialized class definitions condition",
+                                           uninitialized_class_definitions_lock_) {
+    manager->AddHandler(this, /* generated_code */ false);
+  }
+
+  ~TransformationFaultHandler() {
+    art::MutexLock mu(art::Thread::Current(), uninitialized_class_definitions_lock_);
+    uninitialized_class_definitions_.clear();
+  }
+
+  bool Action(int sig, siginfo_t* siginfo, void* context ATTRIBUTE_UNUSED) OVERRIDE {
+    DCHECK_EQ(sig, SIGSEGV);
+    art::Thread* self = art::Thread::Current();
+    if (UNLIKELY(uninitialized_class_definitions_lock_.IsExclusiveHeld(self))) {
+      if (self != nullptr) {
+        LOG(FATAL) << "Recursive call into Transformation fault handler!";
+        UNREACHABLE();
+      } else {
+        LOG(ERROR) << "Possible deadlock due to recursive signal delivery of segv.";
+      }
+    }
+    uintptr_t ptr = reinterpret_cast<uintptr_t>(siginfo->si_addr);
+    ArtClassDefinition* res = nullptr;
+
+    {
+      // NB Technically using a mutex and condition variables here is non-posix compliant but
+      // everything should be fine since both glibc and bionic implementations of mutexes and
+      // condition variables work fine so long as the thread was not interrupted during a
+      // lock/unlock (which it wasn't) on all architectures we care about.
+      art::MutexLock mu(self, uninitialized_class_definitions_lock_);
+      auto it = std::find_if(uninitialized_class_definitions_.begin(),
+                             uninitialized_class_definitions_.end(),
+                             [&](const auto op) { return op->ContainsAddress(ptr); });
+      if (it != uninitialized_class_definitions_.end()) {
+        res = *it;
+        // Remove the class definition.
+        uninitialized_class_definitions_.erase(it);
+        // Put it in the initializing list
+        initializing_class_definitions_.push_back(res);
+      } else {
+        // Wait for the ptr to be initialized (if it is currently initializing).
+        while (DefinitionIsInitializing(ptr)) {
+          WaitForClassInitializationToFinish();
+        }
+        // Return true (continue with user code) if we find that the definition has been
+        // initialized. Return false (continue on to next signal handler) if the definition is not
+        // initialized or found.
+        return std::find_if(initialized_class_definitions_.begin(),
+                            initialized_class_definitions_.end(),
+                            [&](const auto op) { return op->ContainsAddress(ptr); }) !=
+            initialized_class_definitions_.end();
+      }
+    }
+
+    VLOG(signals) << "Lazy initialization of dex file for transformation of " << res->GetName()
+                  << " during SEGV";
+    res->InitializeMemory();
+
+    {
+      art::MutexLock mu(self, uninitialized_class_definitions_lock_);
+      // Move to initialized state and notify waiters.
+      initializing_class_definitions_.erase(std::find(initializing_class_definitions_.begin(),
+                                                      initializing_class_definitions_.end(),
+                                                      res));
+      initialized_class_definitions_.push_back(res);
+      class_definition_initialized_cond_.Broadcast(self);
+    }
+
+    return true;
+  }
+
+  void RemoveDefinition(ArtClassDefinition* def) REQUIRES(!uninitialized_class_definitions_lock_) {
+    art::MutexLock mu(art::Thread::Current(), uninitialized_class_definitions_lock_);
+    auto it = std::find(uninitialized_class_definitions_.begin(),
+                        uninitialized_class_definitions_.end(),
+                        def);
+    if (it != uninitialized_class_definitions_.end()) {
+      uninitialized_class_definitions_.erase(it);
+      return;
+    }
+    while (std::find(initializing_class_definitions_.begin(),
+                     initializing_class_definitions_.end(),
+                     def) != initializing_class_definitions_.end()) {
+      WaitForClassInitializationToFinish();
+    }
+    it = std::find(initialized_class_definitions_.begin(),
+                   initialized_class_definitions_.end(),
+                   def);
+    CHECK(it != initialized_class_definitions_.end()) << "Could not find class definition for "
+                                                      << def->GetName();
+    initialized_class_definitions_.erase(it);
+  }
+
+  void AddArtDefinition(ArtClassDefinition* def) REQUIRES(!uninitialized_class_definitions_lock_) {
+    DCHECK(def->IsLazyDefinition());
+    art::MutexLock mu(art::Thread::Current(), uninitialized_class_definitions_lock_);
+    uninitialized_class_definitions_.push_back(def);
+  }
+
+ private:
+  bool DefinitionIsInitializing(uintptr_t ptr) REQUIRES(uninitialized_class_definitions_lock_) {
+    return std::find_if(initializing_class_definitions_.begin(),
+                        initializing_class_definitions_.end(),
+                        [&](const auto op) { return op->ContainsAddress(ptr); }) !=
+        initializing_class_definitions_.end();
+  }
+
+  void WaitForClassInitializationToFinish() REQUIRES(uninitialized_class_definitions_lock_) {
+    class_definition_initialized_cond_.Wait(art::Thread::Current());
+  }
+
+  art::Mutex uninitialized_class_definitions_lock_ ACQUIRED_BEFORE(art::Locks::abort_lock_);
+  art::ConditionVariable class_definition_initialized_cond_
+      GUARDED_BY(uninitialized_class_definitions_lock_);
+
+  // A list of the class definitions that have a non-readable map.
+  std::vector<ArtClassDefinition*> uninitialized_class_definitions_
+      GUARDED_BY(uninitialized_class_definitions_lock_);
+
+  // A list of class definitions that are currently undergoing unquickening. Threads should wait
+  // until the definition is no longer in this before returning.
+  std::vector<ArtClassDefinition*> initializing_class_definitions_
+      GUARDED_BY(uninitialized_class_definitions_lock_);
+
+  // A list of class definitions that are already unquickened. Threads should immediately return if
+  // it is here.
+  std::vector<ArtClassDefinition*> initialized_class_definitions_
+      GUARDED_BY(uninitialized_class_definitions_lock_);
+};
+
+static TransformationFaultHandler* gTransformFaultHandler = nullptr;
+
+void Transformer::Setup() {
+  // Although we create this the fault handler is actually owned by the 'art::fault_manager' which
+  // will take care of destroying it.
+  if (art::MemMap::kCanReplaceMapping && ArtClassDefinition::kEnableOnDemandDexDequicken) {
+    gTransformFaultHandler = new TransformationFaultHandler(&art::fault_manager);
+  }
+}
+
+// Simple helper to add and remove the class definition from the fault handler.
+class ScopedDefinitionHandler {
+ public:
+  explicit ScopedDefinitionHandler(ArtClassDefinition* def)
+      : def_(def), is_lazy_(def_->IsLazyDefinition()) {
+    if (is_lazy_) {
+      gTransformFaultHandler->AddArtDefinition(def_);
+    }
+  }
+
+  ~ScopedDefinitionHandler() {
+    if (is_lazy_) {
+      gTransformFaultHandler->RemoveDefinition(def_);
+    }
+  }
+
+ private:
+  ArtClassDefinition* def_;
+  bool is_lazy_;
+};
+
 // Initialize templates.
 template
 void Transformer::TransformSingleClassDirect<ArtJvmtiEvent::kClassFileLoadHookNonRetransformable>(
@@ -78,6 +250,7 @@
   static_assert(kEvent == ArtJvmtiEvent::kClassFileLoadHookNonRetransformable ||
                 kEvent == ArtJvmtiEvent::kClassFileLoadHookRetransformable,
                 "bad event type");
+  ScopedDefinitionHandler handler(def);
   jint new_len = -1;
   unsigned char* new_data = nullptr;
   art::ArrayRef<const unsigned char> dex_data = def->GetDexData();
diff --git a/openjdkjvmti/transform.h b/openjdkjvmti/transform.h
index f43af17..8bbeda4 100644
--- a/openjdkjvmti/transform.h
+++ b/openjdkjvmti/transform.h
@@ -48,6 +48,8 @@
 
 class Transformer {
  public:
+  static void Setup();
+
   template<ArtJvmtiEvent kEvent>
   static void TransformSingleClassDirect(
       EventHandler* event_handler,
diff --git a/runtime/base/mutex.h b/runtime/base/mutex.h
index d541b79..6495bc6 100644
--- a/runtime/base/mutex.h
+++ b/runtime/base/mutex.h
@@ -62,6 +62,7 @@
   kUnexpectedSignalLock,
   kThreadSuspendCountLock,
   kAbortLock,
+  kSignalHandlingLock,
   kJdwpAdbStateLock,
   kJdwpSocketLock,
   kRegionSpaceRegionLock,
diff --git a/runtime/dex/compact_dex_file.h b/runtime/dex/compact_dex_file.h
index 1ecff04..31aeb27 100644
--- a/runtime/dex/compact_dex_file.h
+++ b/runtime/dex/compact_dex_file.h
@@ -245,6 +245,12 @@
   static bool IsVersionValid(const uint8_t* magic);
   virtual bool IsVersionValid() const OVERRIDE;
 
+  // TODO This is completely a guess. We really need to do better. b/72402467
+  // We ask for 64 megabytes which should be big enough for any realistic dex file.
+  virtual size_t GetDequickenedSize() const OVERRIDE {
+    return 64 * MB;
+  }
+
   const Header& GetHeader() const {
     return down_cast<const Header&>(DexFile::GetHeader());
   }
diff --git a/runtime/dex/dex_file.h b/runtime/dex/dex_file.h
index 7e2fe98..cf8c840 100644
--- a/runtime/dex/dex_file.h
+++ b/runtime/dex/dex_file.h
@@ -456,6 +456,13 @@
   // Returns true if the dex file supports default methods.
   virtual bool SupportsDefaultMethods() const = 0;
 
+  // Returns the maximum size in bytes needed to store an equivalent dex file strictly conforming to
+  // the dex file specification. That is the size if we wanted to get rid of all the
+  // quickening/compact-dexing/etc.
+  //
+  // TODO This should really be an exact size! b/72402467
+  virtual size_t GetDequickenedSize() const = 0;
+
   // Returns the number of string identifiers in the .dex file.
   size_t NumStringIds() const {
     DCHECK(header_ != nullptr) << GetLocation();
diff --git a/runtime/dex/standard_dex_file.h b/runtime/dex/standard_dex_file.h
index 94ef1f2..e0e9f2f 100644
--- a/runtime/dex/standard_dex_file.h
+++ b/runtime/dex/standard_dex_file.h
@@ -83,6 +83,10 @@
 
   uint32_t GetCodeItemSize(const DexFile::CodeItem& item) const OVERRIDE;
 
+  virtual size_t GetDequickenedSize() const OVERRIDE {
+    return Size();
+  }
+
  private:
   StandardDexFile(const uint8_t* base,
                   size_t size,
diff --git a/runtime/mem_map.cc b/runtime/mem_map.cc
index 55e9c39..26acef0 100644
--- a/runtime/mem_map.cc
+++ b/runtime/mem_map.cc
@@ -396,6 +396,91 @@
   return new MemMap(name, addr, byte_count, addr, page_aligned_byte_count, 0, true /* reuse */);
 }
 
+template<typename A, typename B>
+static ptrdiff_t PointerDiff(A* a, B* b) {
+  return static_cast<ptrdiff_t>(reinterpret_cast<intptr_t>(a) - reinterpret_cast<intptr_t>(b));
+}
+
+bool MemMap::ReplaceWith(MemMap** source_ptr, /*out*/std::string* error) {
+#if !HAVE_MREMAP_SYSCALL
+  UNUSED(source_ptr);
+  *error = "Cannot perform atomic replace because we are missing the required mremap syscall";
+  return false;
+#else  // !HAVE_MREMAP_SYSCALL
+  CHECK(source_ptr != nullptr);
+  CHECK(*source_ptr != nullptr);
+  if (!MemMap::kCanReplaceMapping) {
+    *error = "Unable to perform atomic replace due to runtime environment!";
+    return false;
+  }
+  MemMap* source = *source_ptr;
+  // neither can be reuse.
+  if (source->reuse_ || reuse_) {
+    *error = "One or both mappings is not a real mmap!";
+    return false;
+  }
+  // TODO Support redzones.
+  if (source->redzone_size_ != 0 || redzone_size_ != 0) {
+    *error = "source and dest have different redzone sizes";
+    return false;
+  }
+  // Make sure they have the same offset from the actual mmap'd address
+  if (PointerDiff(BaseBegin(), Begin()) != PointerDiff(source->BaseBegin(), source->Begin())) {
+    *error =
+        "source starts at a different offset from the mmap. Cannot atomically replace mappings";
+    return false;
+  }
+  // mremap doesn't allow the final [start, end] to overlap with the initial [start, end] (it's like
+  // memcpy but the check is explicit and actually done).
+  if (source->BaseBegin() > BaseBegin() &&
+      reinterpret_cast<uint8_t*>(BaseBegin()) + source->BaseSize() >
+      reinterpret_cast<uint8_t*>(source->BaseBegin())) {
+    *error = "destination memory pages overlap with source memory pages";
+    return false;
+  }
+  // Change the protection to match the new location.
+  int old_prot = source->GetProtect();
+  if (!source->Protect(GetProtect())) {
+    *error = "Could not change protections for source to those required for dest.";
+    return false;
+  }
+
+  // Do the mremap.
+  void* res = mremap(/*old_address*/source->BaseBegin(),
+                     /*old_size*/source->BaseSize(),
+                     /*new_size*/source->BaseSize(),
+                     /*flags*/MREMAP_MAYMOVE | MREMAP_FIXED,
+                     /*new_address*/BaseBegin());
+  if (res == MAP_FAILED) {
+    int saved_errno = errno;
+    // Wasn't able to move mapping. Change the protection of source back to the original one and
+    // return.
+    source->Protect(old_prot);
+    *error = std::string("Failed to mremap source to dest. Error was ") + strerror(saved_errno);
+    return false;
+  }
+  CHECK(res == BaseBegin());
+
+  // The new base_size is all the pages of the 'source' plus any remaining dest pages. We will unmap
+  // them later.
+  size_t new_base_size = std::max(source->base_size_, base_size_);
+
+  // Delete the old source, don't unmap it though (set reuse) since it is already gone.
+  *source_ptr = nullptr;
+  size_t source_size = source->size_;
+  source->already_unmapped_ = true;
+  delete source;
+  source = nullptr;
+
+  size_ = source_size;
+  base_size_ = new_base_size;
+  // Reduce base_size if needed (this will unmap the extra pages).
+  SetSize(source_size);
+
+  return true;
+#endif  // !HAVE_MREMAP_SYSCALL
+}
+
 MemMap* MemMap::MapFileAtAddress(uint8_t* expected_ptr,
                                  size_t byte_count,
                                  int prot,
@@ -499,9 +584,11 @@
 
   if (!reuse_) {
     MEMORY_TOOL_MAKE_UNDEFINED(base_begin_, base_size_);
-    int result = munmap(base_begin_, base_size_);
-    if (result == -1) {
-      PLOG(FATAL) << "munmap failed";
+    if (!already_unmapped_) {
+      int result = munmap(base_begin_, base_size_);
+      if (result == -1) {
+        PLOG(FATAL) << "munmap failed";
+      }
     }
   }
 
@@ -523,7 +610,7 @@
 MemMap::MemMap(const std::string& name, uint8_t* begin, size_t size, void* base_begin,
                size_t base_size, int prot, bool reuse, size_t redzone_size)
     : name_(name), begin_(begin), size_(size), base_begin_(base_begin), base_size_(base_size),
-      prot_(prot), reuse_(reuse), redzone_size_(redzone_size) {
+      prot_(prot), reuse_(reuse), already_unmapped_(false), redzone_size_(redzone_size) {
   if (size_ == 0) {
     CHECK(begin_ == nullptr);
     CHECK(base_begin_ == nullptr);
@@ -794,19 +881,21 @@
 }
 
 void MemMap::SetSize(size_t new_size) {
-  if (new_size == base_size_) {
+  CHECK_LE(new_size, size_);
+  size_t new_base_size = RoundUp(new_size + static_cast<size_t>(PointerDiff(Begin(), BaseBegin())),
+                                 kPageSize);
+  if (new_base_size == base_size_) {
+    size_ = new_size;
     return;
   }
-  CHECK_ALIGNED(new_size, kPageSize);
-  CHECK_EQ(base_size_, size_) << "Unsupported";
-  CHECK_LE(new_size, base_size_);
+  CHECK_LT(new_base_size, base_size_);
   MEMORY_TOOL_MAKE_UNDEFINED(
       reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(BaseBegin()) +
-                              new_size),
-      base_size_ - new_size);
-  CHECK_EQ(munmap(reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(BaseBegin()) + new_size),
-                  base_size_ - new_size), 0) << new_size << " " << base_size_;
-  base_size_ = new_size;
+                              new_base_size),
+      base_size_ - new_base_size);
+  CHECK_EQ(munmap(reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(BaseBegin()) + new_base_size),
+                  base_size_ - new_base_size), 0) << new_base_size << " " << base_size_;
+  base_size_ = new_base_size;
   size_ = new_size;
 }
 
diff --git a/runtime/mem_map.h b/runtime/mem_map.h
index 5603963..0ecb414 100644
--- a/runtime/mem_map.h
+++ b/runtime/mem_map.h
@@ -39,8 +39,12 @@
 
 #ifdef __linux__
 static constexpr bool kMadviseZeroes = true;
+#define HAVE_MREMAP_SYSCALL true
 #else
 static constexpr bool kMadviseZeroes = false;
+// We cannot ever perform MemMap::ReplaceWith on non-linux hosts since the syscall is not
+// present.
+#define HAVE_MREMAP_SYSCALL false
 #endif
 
 // Used to keep track of mmap segments.
@@ -52,6 +56,32 @@
 // Otherwise, calls might see uninitialized values.
 class MemMap {
  public:
+  static constexpr bool kCanReplaceMapping = HAVE_MREMAP_SYSCALL;
+
+  // Replace the data in this memmmap with the data in the memmap pointed to by source. The caller
+  // relinquishes ownership of the source mmap.
+  //
+  // For the call to be successful:
+  //   * The range [dest->Begin, dest->Begin() + source->Size()] must not overlap with
+  //     [source->Begin(), source->End()].
+  //   * Neither source nor dest may be 'reused' mappings (they must own all the pages associated
+  //     with them).
+  //   * kCanReplaceMapping must be true.
+  //   * Neither source nor dest may use manual redzones.
+  //   * Both source and dest must have the same offset from the nearest page boundary.
+  //   * mremap must succeed when called on the mappings.
+  //
+  // If this call succeeds it will return true and:
+  //   * Deallocate *source
+  //   * Sets *source to nullptr
+  //   * The protection of this will remain the same.
+  //   * The size of this will be the size of the source
+  //   * The data in this will be the data from source.
+  //
+  // If this call fails it will return false and make no changes to *source or this. The ownership
+  // of the source mmap is returned to the caller.
+  bool ReplaceWith(/*in-out*/MemMap** source, /*out*/std::string* error);
+
   // Request an anonymous region of length 'byte_count' and a requested base address.
   // Use null as the requested base address if you don't care.
   // "reuse" allows re-mapping an address range from an existing mapping.
@@ -246,6 +276,9 @@
   // unmapping.
   const bool reuse_;
 
+  // When already_unmapped_ is true the destructor will not call munmap.
+  bool already_unmapped_;
+
   const size_t redzone_size_;
 
 #if USE_ART_LOW_4G_ALLOCATOR
diff --git a/runtime/mem_map_test.cc b/runtime/mem_map_test.cc
index a4ebb16..3adbf18 100644
--- a/runtime/mem_map_test.cc
+++ b/runtime/mem_map_test.cc
@@ -19,6 +19,7 @@
 #include <sys/mman.h>
 
 #include <memory>
+#include <random>
 
 #include "base/memory_tool.h"
 #include "base/unix_file/fd_file.h"
@@ -36,6 +37,25 @@
     return mem_map->base_size_;
   }
 
+  static bool IsAddressMapped(void* addr) {
+    bool res = msync(addr, 1, MS_SYNC) == 0;
+    if (!res && errno != ENOMEM) {
+      PLOG(FATAL) << "Unexpected error occurred on msync";
+    }
+    return res;
+  }
+
+  static std::vector<uint8_t> RandomData(size_t size) {
+    std::random_device rd;
+    std::uniform_int_distribution<uint8_t> dist;
+    std::vector<uint8_t> res;
+    res.resize(size);
+    for (size_t i = 0; i < size; i++) {
+      res[i] = dist(rd);
+    }
+    return res;
+  }
+
   static uint8_t* GetValidMapAddress(size_t size, bool low_4gb) {
     // Find a valid map address and unmap it before returning.
     std::string error_msg;
@@ -143,6 +163,186 @@
 }
 #endif
 
+// We need mremap to be able to test ReplaceMapping at all
+#if HAVE_MREMAP_SYSCALL
+TEST_F(MemMapTest, ReplaceMapping_SameSize) {
+  std::string error_msg;
+  std::unique_ptr<MemMap> dest(MemMap::MapAnonymous("MapAnonymousEmpty-atomic-replace-dest",
+                                                    nullptr,
+                                                    kPageSize,
+                                                    PROT_READ,
+                                                    false,
+                                                    false,
+                                                    &error_msg));
+  ASSERT_TRUE(dest != nullptr);
+  MemMap* source = MemMap::MapAnonymous("MapAnonymous-atomic-replace-source",
+                                        nullptr,
+                                        kPageSize,
+                                        PROT_WRITE | PROT_READ,
+                                        false,
+                                        false,
+                                        &error_msg);
+  ASSERT_TRUE(source != nullptr);
+  void* source_addr = source->Begin();
+  void* dest_addr = dest->Begin();
+  ASSERT_TRUE(IsAddressMapped(source_addr));
+  ASSERT_TRUE(IsAddressMapped(dest_addr));
+
+  std::vector<uint8_t> data = RandomData(kPageSize);
+  memcpy(source->Begin(), data.data(), data.size());
+
+  ASSERT_TRUE(dest->ReplaceWith(&source, &error_msg)) << error_msg;
+
+  ASSERT_FALSE(IsAddressMapped(source_addr));
+  ASSERT_TRUE(IsAddressMapped(dest_addr));
+  ASSERT_TRUE(source == nullptr);
+
+  ASSERT_EQ(dest->Size(), static_cast<size_t>(kPageSize));
+
+  ASSERT_EQ(memcmp(dest->Begin(), data.data(), dest->Size()), 0);
+}
+
+TEST_F(MemMapTest, ReplaceMapping_MakeLarger) {
+  std::string error_msg;
+  std::unique_ptr<MemMap> dest(MemMap::MapAnonymous("MapAnonymousEmpty-atomic-replace-dest",
+                                                    nullptr,
+                                                    5 * kPageSize,  // Need to make it larger
+                                                                    // initially so we know
+                                                                    // there won't be mappings
+                                                                    // in the way when we move
+                                                                    // source.
+                                                    PROT_READ,
+                                                    false,
+                                                    false,
+                                                    &error_msg));
+  ASSERT_TRUE(dest != nullptr);
+  MemMap* source = MemMap::MapAnonymous("MapAnonymous-atomic-replace-source",
+                                        nullptr,
+                                        3 * kPageSize,
+                                        PROT_WRITE | PROT_READ,
+                                        false,
+                                        false,
+                                        &error_msg);
+  ASSERT_TRUE(source != nullptr);
+  uint8_t* source_addr = source->Begin();
+  uint8_t* dest_addr = dest->Begin();
+  ASSERT_TRUE(IsAddressMapped(source_addr));
+
+  // Fill the source with random data.
+  std::vector<uint8_t> data = RandomData(3 * kPageSize);
+  memcpy(source->Begin(), data.data(), data.size());
+
+  // Make the dest smaller so that we know we'll have space.
+  dest->SetSize(kPageSize);
+
+  ASSERT_TRUE(IsAddressMapped(dest_addr));
+  ASSERT_FALSE(IsAddressMapped(dest_addr + 2 * kPageSize));
+  ASSERT_EQ(dest->Size(), static_cast<size_t>(kPageSize));
+
+  ASSERT_TRUE(dest->ReplaceWith(&source, &error_msg)) << error_msg;
+
+  ASSERT_FALSE(IsAddressMapped(source_addr));
+  ASSERT_EQ(dest->Size(), static_cast<size_t>(3 * kPageSize));
+  ASSERT_TRUE(IsAddressMapped(dest_addr));
+  ASSERT_TRUE(IsAddressMapped(dest_addr + 2 * kPageSize));
+  ASSERT_TRUE(source == nullptr);
+
+  ASSERT_EQ(memcmp(dest->Begin(), data.data(), dest->Size()), 0);
+}
+
+TEST_F(MemMapTest, ReplaceMapping_MakeSmaller) {
+  std::string error_msg;
+  std::unique_ptr<MemMap> dest(MemMap::MapAnonymous("MapAnonymousEmpty-atomic-replace-dest",
+                                                    nullptr,
+                                                    3 * kPageSize,
+                                                    PROT_READ,
+                                                    false,
+                                                    false,
+                                                    &error_msg));
+  ASSERT_TRUE(dest != nullptr);
+  MemMap* source = MemMap::MapAnonymous("MapAnonymous-atomic-replace-source",
+                                        nullptr,
+                                        kPageSize,
+                                        PROT_WRITE | PROT_READ,
+                                        false,
+                                        false,
+                                        &error_msg);
+  ASSERT_TRUE(source != nullptr);
+  uint8_t* source_addr = source->Begin();
+  uint8_t* dest_addr = dest->Begin();
+  ASSERT_TRUE(IsAddressMapped(source_addr));
+  ASSERT_TRUE(IsAddressMapped(dest_addr));
+  ASSERT_TRUE(IsAddressMapped(dest_addr + 2 * kPageSize));
+  ASSERT_EQ(dest->Size(), static_cast<size_t>(3 * kPageSize));
+
+  std::vector<uint8_t> data = RandomData(kPageSize);
+  memcpy(source->Begin(), data.data(), kPageSize);
+
+  ASSERT_TRUE(dest->ReplaceWith(&source, &error_msg)) << error_msg;
+
+  ASSERT_FALSE(IsAddressMapped(source_addr));
+  ASSERT_EQ(dest->Size(), static_cast<size_t>(kPageSize));
+  ASSERT_TRUE(IsAddressMapped(dest_addr));
+  ASSERT_FALSE(IsAddressMapped(dest_addr + 2 * kPageSize));
+  ASSERT_TRUE(source == nullptr);
+
+  ASSERT_EQ(memcmp(dest->Begin(), data.data(), dest->Size()), 0);
+}
+
+TEST_F(MemMapTest, ReplaceMapping_FailureOverlap) {
+  std::string error_msg;
+  std::unique_ptr<MemMap> dest(
+      MemMap::MapAnonymous(
+          "MapAnonymousEmpty-atomic-replace-dest",
+          nullptr,
+          3 * kPageSize,  // Need to make it larger initially so we know there won't be mappings in
+                          // the way when we move source.
+          PROT_READ | PROT_WRITE,
+          false,
+          false,
+          &error_msg));
+  ASSERT_TRUE(dest != nullptr);
+  // Resize down to 1 page so we can remap the rest.
+  dest->SetSize(kPageSize);
+  // Create source from the last 2 pages
+  MemMap* source = MemMap::MapAnonymous("MapAnonymous-atomic-replace-source",
+                                        dest->Begin() + kPageSize,
+                                        2 * kPageSize,
+                                        PROT_WRITE | PROT_READ,
+                                        false,
+                                        false,
+                                        &error_msg);
+  ASSERT_TRUE(source != nullptr);
+  MemMap* orig_source = source;
+  ASSERT_EQ(dest->Begin() + kPageSize, source->Begin());
+  uint8_t* source_addr = source->Begin();
+  uint8_t* dest_addr = dest->Begin();
+  ASSERT_TRUE(IsAddressMapped(source_addr));
+
+  // Fill the source and dest with random data.
+  std::vector<uint8_t> data = RandomData(2 * kPageSize);
+  memcpy(source->Begin(), data.data(), data.size());
+  std::vector<uint8_t> dest_data = RandomData(kPageSize);
+  memcpy(dest->Begin(), dest_data.data(), dest_data.size());
+
+  ASSERT_TRUE(IsAddressMapped(dest_addr));
+  ASSERT_EQ(dest->Size(), static_cast<size_t>(kPageSize));
+
+  ASSERT_FALSE(dest->ReplaceWith(&source, &error_msg)) << error_msg;
+
+  ASSERT_TRUE(source == orig_source);
+  ASSERT_TRUE(IsAddressMapped(source_addr));
+  ASSERT_TRUE(IsAddressMapped(dest_addr));
+  ASSERT_EQ(source->Size(), data.size());
+  ASSERT_EQ(dest->Size(), dest_data.size());
+
+  ASSERT_EQ(memcmp(source->Begin(), data.data(), data.size()), 0);
+  ASSERT_EQ(memcmp(dest->Begin(), dest_data.data(), dest_data.size()), 0);
+
+  delete source;
+}
+#endif  // HAVE_MREMAP_SYSCALL
+
 TEST_F(MemMapTest, MapAnonymousEmpty) {
   CommonInit();
   std::string error_msg;
diff --git a/test/983-source-transform-verify/source_transform.cc b/test/983-source-transform-verify/source_transform.cc
index c076d15..dfefce2 100644
--- a/test/983-source-transform-verify/source_transform.cc
+++ b/test/983-source-transform-verify/source_transform.cc
@@ -67,6 +67,14 @@
   if (IsJVM()) {
     return;
   }
+
+  // Due to b/72402467 the class_data_len might just be an estimate.
+  CHECK_GE(static_cast<size_t>(class_data_len), sizeof(DexFile::Header));
+  const DexFile::Header* header = reinterpret_cast<const DexFile::Header*>(class_data);
+  uint32_t header_file_size = header->file_size_;
+  CHECK_LE(static_cast<jint>(header_file_size), class_data_len);
+  class_data_len = static_cast<jint>(header_file_size);
+
   const ArtDexFileLoader dex_file_loader;
   std::string error;
   std::unique_ptr<const DexFile> dex(dex_file_loader.Open(class_data,