Fix checksum verification when opening DexFiles from OatFiles

Change-Id: Ic3d13f3d591c34f159bf0739536a1751c3e7dc75
diff --git a/src/class_linker.cc b/src/class_linker.cc
index 87976a3..f414c4b 100644
--- a/src/class_linker.cc
+++ b/src/class_linker.cc
@@ -40,6 +40,7 @@
 #include "oat_file.h"
 #include "object.h"
 #include "object_utils.h"
+#include "os.h"
 #include "runtime.h"
 #include "runtime_support.h"
 #include "ScopedLocalRef.h"
@@ -678,10 +679,23 @@
 }
 
+// Returns an already-opened oat file for |dex_file|, matched on both its
+// location and its location checksum.
 const OatFile* ClassLinker::FindOpenedOatFileForDexFile(const DexFile& dex_file) {
+  return FindOpenedOatFileFromDexLocation(dex_file.GetLocation(),
+                                          dex_file.GetLocationChecksum());
+}
+
+// Scans oat_files_ for an oat file with an entry for |dex_location| whose
+// recorded location checksum equals |dex_location_checksum| and returns the
+// first match. Entries with a different recorded checksum are skipped.
+const OatFile* ClassLinker::FindOpenedOatFileFromDexLocation(const std::string& dex_location,
+                                                             uint32_t dex_location_checksum) {
   for (size_t i = 0; i < oat_files_.size(); i++) {
     const OatFile* oat_file = oat_files_[i];
     DCHECK(oat_file != NULL);
-    if (oat_file->GetOatDexFile(dex_file.GetLocation(), false)) {
+    const OatFile::OatDexFile* oat_dex_file = oat_file->GetOatDexFile(dex_location, false);
+    if (oat_dex_file != NULL
+        && oat_dex_file->GetDexFileLocationChecksum() == dex_location_checksum) {
       return oat_file;
     }
   }
@@ -734,77 +748,137 @@
   int fd_;
 };
 
-const OatFile* ClassLinker::FindOatFileForDexFile(const DexFile& dex_file) {
+// Opens the oat file at |oat_location| and looks up its entry for
+// |dex_location|. If the entry exists and its recorded location checksum
+// equals |dex_location_checksum|, registers the oat file with the runtime's
+// ClassLinker and returns the DexFile opened from that entry. Otherwise
+// returns NULL and the UniquePtr deletes the opened oat file.
+static const DexFile* FindDexFileInOatLocation(const std::string& dex_location,
+                                               uint32_t dex_location_checksum,
+                                               const std::string& oat_location) {
+  UniquePtr<OatFile> oat_file(OatFile::Open(oat_location, "", NULL));
+  if (oat_file.get() == NULL) {
+    return NULL;
+  }
+  const OatFile::OatDexFile* oat_dex_file = oat_file->GetOatDexFile(dex_location);
+  if (oat_dex_file == NULL) {
+    return NULL;
+  }
+  if (oat_dex_file->GetDexFileLocationChecksum() != dex_location_checksum) {
+    return NULL;
+  }
+  Runtime::Current()->GetClassLinker()->RegisterOatFile(*oat_file.release());
+  return oat_dex_file->OpenDexFile();
+}
+
+// Returns the DexFile for |dex_location| from the oat file at |oat_location|
+// when that file exists, has an entry for |dex_location|, and records the
+// current checksum of |dex_location|. Otherwise generates a new oat file at
+// |oat_location| with GenerateOatFile, registers it with the runtime's
+// ClassLinker, and returns the DexFile opened from it. Returns NULL on
+// failure; most failure paths log an error.
+const DexFile* ClassLinker::FindOrCreateOatFileForDexLocation(const std::string& dex_location,
+                                                              const std::string& oat_location) {
+  uint32_t dex_location_checksum;
+  if (!DexFile::GetChecksum(dex_location, dex_location_checksum)) {
+    LOG(ERROR) << "Failed to compute checksum '" << dex_location << "'";
+    return NULL;
+  }
+
+  // Check if we already have an up-to-date output file
+  const DexFile* dex_file = FindDexFileInOatLocation(dex_location,
+                                                     dex_location_checksum,
+                                                     oat_location);
+  if (dex_file != NULL) {
+    return dex_file;
+  }
+
+  // Generate the output oat file for the dex file.
+  // NOTE(review): unlike the LockedFd-based code this replaces, the output
+  // file is not locked, so another runtime generating the same file may race
+  // with us. Also confirm OS::OpenFile(..., true) truncates any stale file
+  // already at |oat_location|.
+  ClassLinker* class_linker = Runtime::Current()->GetClassLinker();
+  UniquePtr<File> file(OS::OpenFile(oat_location.c_str(), true));
+  if (file.get() == NULL) {
+    LOG(ERROR) << "Failed to create oat file: " << oat_location;
+    return NULL;
+  }
+  if (!class_linker->GenerateOatFile(dex_location, file->Fd(), oat_location)) {
+    LOG(ERROR) << "Failed to generate oat file: " << oat_location;
+    return NULL;
+  }
+  // Open the oat from file descriptor we passed to GenerateOatFile
+  if (lseek(file->Fd(), 0, SEEK_SET) != 0) {
+    LOG(ERROR) << "Failed to seek to start of generated oat file: " << oat_location;
+    return NULL;
+  }
+  const OatFile* oat_file = OatFile::Open(*file.get(), oat_location, NULL);
+  if (oat_file == NULL) {
+    LOG(ERROR) << "Failed to open generated oat file: " << oat_location;
+    return NULL;
+  }
+  class_linker->RegisterOatFile(*oat_file);
+  const OatFile::OatDexFile* oat_dex_file = oat_file->GetOatDexFile(dex_location);
+  if (oat_dex_file == NULL) {
+    LOG(ERROR) << "Failed to find dex file in generated oat file: " << oat_location;
+    return NULL;
+  }
+  return oat_dex_file->OpenDexFile();
+}
+
+// Returns a DexFile for |dex_location| opened from an oat file; dex_lock_ is
+// held throughout. Candidates, in order:
+//   1. an oat file already in oat_files_ whose entry matches the current
+//      checksum of |dex_location|;
+//   2. the oat file returned by FindOatFileFromOatLocation for the oat name
+//      derived from |dex_location|, if its entry matches;
+//   3. the art-cache location, via FindOrCreateOatFileForDexLocation, which
+//      may generate a new oat file.
+// A candidate-2 oat file that is stale or lacks an entry for |dex_location|
+// is unlinked first; failing to unlink it is fatal. Returns NULL on failure.
+const DexFile* ClassLinker::FindDexFileInOatFileFromDexLocation(const std::string& dex_location) {
   MutexLock mu(dex_lock_);
-  const OatFile* open_oat_file = FindOpenedOatFileForDexFile(dex_file);
+
+  uint32_t dex_location_checksum;
+  if (!DexFile::GetChecksum(dex_location, dex_location_checksum)) {
+    LOG(WARNING) << "Failed to compute checksum: " << dex_location;
+    return NULL;
+  }
+
+  const OatFile* open_oat_file =
+      FindOpenedOatFileFromDexLocation(dex_location, dex_location_checksum);
   if (open_oat_file != NULL) {
-    return open_oat_file;
+    return open_oat_file->GetOatDexFile(dex_location)->OpenDexFile();
   }
 
-  std::string oat_filename(OatFile::DexFilenameToOatFilename(dex_file.GetLocation()));
-  open_oat_file = FindOpenedOatFileFromOatLocation(oat_filename);
-  if (open_oat_file != NULL) {
-    return open_oat_file;
-  }
-
-  while (true) {
-    UniquePtr<const OatFile> oat_file(FindOatFileFromOatLocation(oat_filename));
-    if (oat_file.get() != NULL) {
-      const OatFile::OatDexFile* oat_dex_file = oat_file->GetOatDexFile(dex_file.GetLocation());
-      if (dex_file.GetHeader().checksum_ == oat_dex_file->GetDexFileChecksum()) {
-        return oat_file.release();
-      }
-      LOG(WARNING) << ".oat file " << oat_file->GetLocation()
-                   << " checksum mismatch with " << dex_file.GetLocation() << " --- regenerating";
-      if (TEMP_FAILURE_RETRY(unlink(oat_file->GetLocation().c_str())) != 0) {
-        PLOG(FATAL) << "Couldn't remove obsolete .oat file " << oat_file->GetLocation();
-      }
-      // Fall through...
+  // Look for an existing file first next to dex and in art-cache
+  std::string oat_filename(OatFile::DexFilenameToOatFilename(dex_location));
+  const OatFile* oat_file(FindOatFileFromOatLocation(oat_filename));
+  if (oat_file != NULL) {
+    // The oat file may have no entry for dex_location at all, so check for
+    // NULL before comparing checksums.
+    const OatFile::OatDexFile* oat_dex_file = oat_file->GetOatDexFile(dex_location);
+    if (oat_dex_file != NULL
+        && dex_location_checksum == oat_dex_file->GetDexFileLocationChecksum()) {
+      return oat_dex_file->OpenDexFile();
     }
-    // Try to generate oat file if it wasn't found or was obsolete.
-    // Note we can be racing with another runtime to do this.
-    std::string oat_cache_filename(GetArtCacheFilenameOrDie(oat_filename));
-    UniquePtr<LockedFd> locked_fd(LockedFd::CreateAndLock(oat_cache_filename, 0644));
-    if (locked_fd.get() == NULL) {
-      LOG(ERROR) << "Failed to create and lock oat file " << oat_cache_filename;
-      return NULL;
-    }
-    // Check to see if the fd we opened and locked matches the file in
-    // the filesystem.  If they don't, then somebody else unlinked ours
-    // and created a new file, and we need to use that one instead.  (If
-    // we caught them between the unlink and the create, we'll get an
-    // ENOENT from the file stat.)
-    struct stat fd_stat;
-    int fd_stat_result = fstat(locked_fd->GetFd(), &fd_stat);
-    if (fd_stat_result != 0) {
-      PLOG(ERROR) << "Failed to fstat file descriptor of oat file " << oat_cache_filename;
-      return NULL;
-    }
-    struct stat file_stat;
-    int file_stat_result = stat(oat_cache_filename.c_str(), &file_stat);
-    if (file_stat_result != 0
-        || fd_stat.st_dev != file_stat.st_dev
-        || fd_stat.st_ino != file_stat.st_ino) {
-      LOG(INFO) << "Opened oat file " << oat_cache_filename << " is stale; sleeping and retrying";
-      usleep(250 * 1000);  // if something is hosed, don't peg machine
-      continue;
-    }
-
-    // We have the correct file open and locked.  If the file size is
-    // zero, then it was just created by us and we can generate its
-    // contents. If not, someone else created it. Either way, we'll
-    // loop to retry opening the file.
-    if (fd_stat.st_size == 0) {
-      bool success = GenerateOatFile(dex_file.GetLocation(),
-                                     locked_fd->GetFd(),
-                                     oat_cache_filename);
-      if (!success) {
-        LOG(ERROR) << "Failed to generate oat file " << oat_cache_filename;
-        return NULL;
-      }
+    if (oat_dex_file == NULL) {
+      LOG(WARNING) << ".oat file " << oat_file->GetLocation()
+                   << " has no entry for " << dex_location << " --- regenerating";
+    } else {
+      LOG(WARNING) << ".oat file " << oat_file->GetLocation()
+                   << " checksum ( " << std::hex << oat_dex_file->GetDexFileLocationChecksum()
+                   << ") mismatch with " << dex_location
+                   << " (" << std::hex << dex_location_checksum << ")--- regenerating";
+    }
+    if (TEMP_FAILURE_RETRY(unlink(oat_file->GetLocation().c_str())) != 0) {
+      PLOG(FATAL) << "Couldn't remove obsolete .oat file " << oat_file->GetLocation();
     }
   }
-  // Not reached
+  // Try to generate oat file if it wasn't found or was obsolete.
+  std::string oat_cache_filename(GetArtCacheFilenameOrDie(oat_filename));
+  return FindOrCreateOatFileForDexLocation(dex_location, oat_cache_filename);
 }
 
 const OatFile* ClassLinker::FindOpenedOatFileFromOatLocation(const std::string& oat_location) {
@@ -850,26 +885,6 @@
   return oat_file;
 }
 
-const DexFile* ClassLinker::FindDexFileFromDexLocation(const std::string& location) {
-  std::string oat_location(OatFile::DexFilenameToOatFilename(location));
-  const OatFile* oat_file = FindOatFileFromOatLocation(oat_location);
-  if (oat_file == NULL) {
-    return NULL;
-  }
-  const OatFile::OatDexFile* oat_dex_file = oat_file->GetOatDexFile(location);
-  if (oat_dex_file == NULL) {
-    return NULL;
-  }
-  const DexFile* dex_file = oat_dex_file->OpenDexFile();
-  if (dex_file == NULL) {
-    return NULL;
-  }
-  if (oat_dex_file->GetDexFileChecksum() != dex_file->GetHeader().checksum_) {
-    return NULL;
-  }
-  return dex_file;
-}
-
 void ClassLinker::InitFromImage() {
   VLOG(startup) << "ClassLinker::InitFromImage entering";
   CHECK(!init_done_);
@@ -903,7 +918,9 @@
                      << " from within oat file " << oat_file->GetLocation();
         }
 
-        CHECK_EQ(dex_file->GetHeader().checksum_, oat_dex_file->GetDexFileChecksum());
+        // The dex file opened from this oat file must have the same location
+        // checksum that the oat file recorded for it.
+        CHECK_EQ(dex_file->GetLocationChecksum(), oat_dex_file->GetDexFileLocationChecksum());
 
         AppendToBootClassPath(*dex_file, dex_cache);
       }
@@ -1402,17 +1419,17 @@
 
   UniquePtr<const OatFile::OatClass> oat_class;
   if (Runtime::Current()->IsStarted() && !ClassLoader::UseCompileTimeClassPath()) {
-    const OatFile* oat_file = FindOatFileForDexFile(dex_file);
-    if (oat_file != NULL) {
-      const OatFile::OatDexFile* oat_dex_file = oat_file->GetOatDexFile(dex_file.GetLocation());
-      if (oat_dex_file != NULL) {
-        uint32_t class_def_index;
-        bool found = dex_file.FindClassDefIndex(descriptor, class_def_index);
-        CHECK(found) << descriptor;
-        oat_class.reset(oat_dex_file->GetOatClass(class_def_index));
-        CHECK(oat_class.get() != NULL) << descriptor;
-      }
-    }
+    // Only already-opened oat files (matched by location and checksum) are
+    // searched; any missing piece below is a fatal CHECK failure.
+    const OatFile* oat_file = FindOpenedOatFileForDexFile(dex_file);
+    CHECK(oat_file != NULL) << dex_file.GetLocation() << " " << descriptor;
+    const OatFile::OatDexFile* oat_dex_file = oat_file->GetOatDexFile(dex_file.GetLocation());
+    CHECK(oat_dex_file != NULL) << dex_file.GetLocation() << " " << descriptor;
+    uint32_t class_def_index;
+    bool found = dex_file.FindClassDefIndex(descriptor, class_def_index);
+    CHECK(found) << dex_file.GetLocation() << " " << descriptor;
+    oat_class.reset(oat_dex_file->GetOatClass(class_def_index));
+    CHECK(oat_class.get() != NULL) << dex_file.GetLocation() << " " << descriptor;
   }
   // Load methods.
   if (it.NumDirectMethods() != 0) {
@@ -1956,18 +1973,19 @@
   if (ClassLoader::UseCompileTimeClassPath()) {
     return false;
   }
-  const OatFile* oat_file = FindOatFileForDexFile(dex_file);
-  if (oat_file == NULL) {
-    return false;
-  }
+  // Only already-opened oat files (matched by location and checksum) are
+  // searched; a missing oat file or entry is a fatal CHECK failure.
+  const OatFile* oat_file = FindOpenedOatFileForDexFile(dex_file);
+  CHECK(oat_file != NULL) << dex_file.GetLocation() << " " << PrettyClass(klass);
   const OatFile::OatDexFile* oat_dex_file = oat_file->GetOatDexFile(dex_file.GetLocation());
-  CHECK(oat_dex_file != NULL) << PrettyClass(klass);
+  CHECK(oat_dex_file != NULL) << dex_file.GetLocation() << " " << PrettyClass(klass);
   const char* descriptor = ClassHelper(klass).GetDescriptor();
   uint32_t class_def_index;
   bool found = dex_file.FindClassDefIndex(descriptor, class_def_index);
-  CHECK(found) << descriptor;
+  CHECK(found) << dex_file.GetLocation() << " " << PrettyClass(klass) << " " << descriptor;
   UniquePtr<const OatFile::OatClass> oat_class(oat_dex_file->GetOatClass(class_def_index));
-  CHECK(oat_class.get() != NULL) << descriptor;
+  CHECK(oat_class.get() != NULL)
+          << dex_file.GetLocation() << " " << PrettyClass(klass) << " " << descriptor;
   Class::Status status = oat_class->GetStatus();
   if (status == Class::kStatusVerified || status == Class::kStatusInitialized) {
     return true;
@@ -1999,7 +2011,9 @@
     // isn't a problem and this case shouldn't occur
     return false;
   }
-  LOG(FATAL) << "Unexpected class status: " << status;
+  LOG(FATAL) << "Unexpected class status: " << status
+             << " " << dex_file.GetLocation() << " " << PrettyClass(klass) << " " << descriptor;
+
   return false;
 }