Merge "secdiscard.cpp: Use getmntent and my newfound C++ knowledge."
diff --git a/AutoCloseFD.h b/AutoCloseFD.h
new file mode 100644
index 0000000..f9d7c86
--- /dev/null
+++ b/AutoCloseFD.h
@@ -0,0 +1,62 @@
+/*
+ * Copyright (C) 2015 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <string>
+
+#include <errno.h>
+#include <string.h>
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <unistd.h>
+
+// File descriptor which is automatically closed when this object is destroyed.
+// Cannot be copied, since that would cause double-closes.
+//
+// Logs through SLOGE but does not pull in the logging header itself: the
+// including file is expected to define LOG_TAG and include it first.
+class AutoCloseFD {
+public:
+    // Opens "path" with open(2), retrying on EINTR; O_CLOEXEC is always added
+    // to "flags". On failure the object holds -1 and errno is left as set by
+    // open(2), so callers should test the object before using get().
+    AutoCloseFD(const char *path, int flags = O_RDONLY, int mode = 0):
+        fd{TEMP_FAILURE_RETRY(open(path, flags | O_CLOEXEC, mode))} {}
+    AutoCloseFD(const std::string &path, int flags = O_RDONLY, int mode = 0):
+        AutoCloseFD(path.c_str(), flags, mode) {}
+    // Closes the descriptor if one was opened. errno is saved and restored so
+    // that cleanup never disturbs the error the caller is about to report.
+    ~AutoCloseFD() {
+        if (fd != -1) {
+            int preserve_errno = errno;
+            if (close(fd) == -1) {
+                SLOGE("close(2) failed: %s", strerror(errno));
+            }
+            errno = preserve_errno;
+        }
+    }
+    AutoCloseFD(const AutoCloseFD&) = delete;
+    AutoCloseFD& operator=(const AutoCloseFD&) = delete;
+    // True when open(2) succeeded. const so it works on const objects too.
+    explicit operator bool() const {return fd != -1;}
+    // Returns the raw descriptor (-1 if the open failed); ownership stays here.
+    int get() const {return fd;}
+private:
+    const int fd;
+};
+
diff --git a/secdiscard.cpp b/secdiscard.cpp
index 3f4ab2e..a750043 100644
--- a/secdiscard.cpp
+++ b/secdiscard.cpp
@@ -24,69 +24,73 @@
 #include <fcntl.h>
 #include <linux/fs.h>
 #include <linux/fiemap.h>
+#include <mntent.h>
 
 #define LOG_TAG "secdiscard"
 #include "cutils/log.h"
 
+#include <AutoCloseFD.h>
+
+namespace {
 // Deliberately limit ourselves to wiping small files.
-#define MAX_WIPE_LENGTH 4096
-#define INIT_BUFFER_SIZE 2048
+constexpr uint64_t max_wipe_length = 4096;
 
-static void usage(char *progname);
-static void destroy_key(const std::string &path);
-static int file_device_range(const std::string &path, uint64_t range[2]);
-static int open_block_device_for_path(const std::string &path);
-static int read_file_as_string_atomically(const std::string &path, std::string &contents);
-static int find_block_device_for_path(
-    const std::string &mounts,
-    const std::string &path,
-    std::string &block_device);
+void usage(const char *progname);
+int secdiscard_path(const std::string &path);
+int path_device_range(const std::string &path, uint64_t range[2]);
+std::string block_device_for_path(const std::string &path);
+}
 
-int main(int argc, char **argv) {
+int main(int argc, const char * const argv[]) {
     if (argc != 2 || argv[1][0] != '/') {
         usage(argv[0]);
         return -1;
     }
     SLOGD("Running: %s %s", argv[0], argv[1]);
-    std::string target(argv[1]);
-    destroy_key(target);
+    const bool discarded = secdiscard_path(argv[1]) == 0;  // best effort; unlink anyway
     if (unlink(argv[1]) != 0 && errno != ENOENT) {
         SLOGE("Unable to delete %s: %s",
             argv[1], strerror(errno));
         return -1;
     }
+    if (discarded) { SLOGD("Discarded %s", argv[1]); }  // don't claim a wipe that failed
     return 0;
 }
 
-static void usage(char *progname) {
+namespace {
+
+void usage(const char *progname) {
     fprintf(stderr, "Usage: %s <absolute path>\n", progname);
 }
 
 // BLKSECDISCARD all content in "path", if it's small enough.
-static void destroy_key(const std::string &path) {
+int secdiscard_path(const std::string &path) {
     uint64_t range[2];
-    if (file_device_range(path, range) < 0) {
-        return;
+    if (path_device_range(path, range) == -1) {
+        return -1;
     }
-    int fs_fd = open_block_device_for_path(path);
-    if (fs_fd < 0) {
-        return;
+    auto block_device = block_device_for_path(path);
+    if (block_device.empty()) {
+        return -1;
     }
-    if (ioctl(fs_fd, BLKSECDISCARD, range) != 0) {
+    AutoCloseFD fs_fd(block_device, O_RDWR | O_LARGEFILE);
+    if (!fs_fd) {
+        SLOGE("Failed to open device %s: %s", block_device.c_str(), strerror(errno));
+        return -1;
+    }
+    if (ioctl(fs_fd.get(), BLKSECDISCARD, range) == -1) {
         SLOGE("Unable to BLKSECDISCARD %s: %s", path.c_str(), strerror(errno));
-        close(fs_fd);
-        return;
+        return -1;
     }
-    close(fs_fd);
-    SLOGD("Discarded %s", path.c_str());
+    return 0;
 }
 
 // Find a short range that completely covers the file.
 // If there isn't one, return -1, otherwise 0.
-static int file_device_range(const std::string &path, uint64_t range[2])
+int path_device_range(const std::string &path, uint64_t range[2])
 {
-    int fd = open(path.c_str(), O_RDONLY | O_CLOEXEC);
-    if (fd < 0) {
+    AutoCloseFD fd(path);
+    if (!fd) {
         if (errno == ENOENT) {
             SLOGD("Unable to open %s: %s", path.c_str(), strerror(errno));
         } else {
@@ -102,12 +106,10 @@
     fiemap->fm_flags = 0;
     fiemap->fm_extent_count = 1;
     fiemap->fm_mapped_extents = 0;
-    if (ioctl(fd, FS_IOC_FIEMAP, fiemap) != 0) {
+    if (ioctl(fd.get(), FS_IOC_FIEMAP, fiemap) != 0) {
         SLOGE("Unable to FIEMAP %s: %s", path.c_str(), strerror(errno));
-        close(fd);
         return -1;
     }
-    close(fd);
     if (fiemap->fm_mapped_extents != 1) {
         SLOGE("Expecting one extent, got %d in %s", fiemap->fm_mapped_extents, path.c_str());
         return -1;
@@ -122,7 +124,7 @@
         SLOGE("Extent has unexpected flags %ulx: %s", extent->fe_flags, path.c_str());
         return -1;
     }
-    if (extent->fe_length > MAX_WIPE_LENGTH) {
+    if (extent->fe_length > max_wipe_length) {
         SLOGE("Extent too big, %llu bytes in %s", extent->fe_length, path.c_str());
         return -1;
     }
@@ -131,106 +133,33 @@
     return 0;
 }
 
-// Given a file path, look for the corresponding
-// block device in /proc/mounts and open it.
-static int open_block_device_for_path(const std::string &path)
+// Return the device (mnt_fsname) of the best /proc/mounts match for "path", or "" on failure.
+std::string block_device_for_path(const std::string &path)
 {
-    std::string mountsfile("/proc/mounts");
-    std::string mounts;
-    if (read_file_as_string_atomically(mountsfile, mounts) < 0) {
-        return -1;
+    std::unique_ptr<FILE, int(*)(FILE*)> mnts(setmntent("/proc/mounts", "re"), endmntent);
+    if (!mnts) {
+        SLOGE("Unable to open /proc/mounts: %s", strerror(errno));
+        return "";
     }
-    std::string block_device;
-    if (find_block_device_for_path(mounts, path, block_device) < 0) {
-        return -1;
+    std::string result;
+    size_t best_length = 0;
+    struct mntent *mnt; // getmntent returns a thread local, so it's safe.
+    while ((mnt = getmntent(mnts.get())) != nullptr) {
+        const size_t l = strlen(mnt->mnt_dir);
+        // Longest matching mountpoint wins; on a tie prefer the later entry,
+        // because a later mount on the same directory shadows the earlier one.
+        if (l > 0 && l >= best_length && path.size() > l && path[l] == '/' &&
+            path.compare(0, l, mnt->mnt_dir) == 0) {
+            result = mnt->mnt_fsname;
+            best_length = l;
+        }
     }
-    SLOGD("For path %s block device is %s", path.c_str(), block_device.c_str());
-    int res = open(block_device.c_str(), O_RDWR | O_LARGEFILE | O_CLOEXEC);
-    if (res < 0) {
-        SLOGE("Failed to open device %s: %s", block_device.c_str(), strerror(errno));
-        return -1;
+    if (result.empty()) {
+        SLOGE("Didn't find a mountpoint to match path %s", path.c_str());
+        return "";
     }
-    return res;
+    SLOGD("For path %s block device is %s", path.c_str(), result.c_str());
+    return result;
 }
 
-// Read a file into a buffer in a single gulp, for atomicity.
-// Null-terminate the buffer.
-// Retry until the buffer is big enough.
-static int read_file_as_string_atomically(const std::string &path, std::string &contents)
-{
-    ssize_t buffer_size = INIT_BUFFER_SIZE;
-    while (true) {
-        int fd = open(path.c_str(), O_RDONLY | O_CLOEXEC);
-        if (fd < 0) {
-            SLOGE("Failed to open %s: %s", path.c_str(), strerror(errno));
-            return -1;
-        }
-        contents.resize(buffer_size);
-        ssize_t read_size = read(fd, &contents[0], buffer_size);
-        if (read_size < 0) {
-            SLOGE("Failed to read from %s: %s", path.c_str(), strerror(errno));
-            close(fd);
-            return -1;
-        }
-        close(fd);
-        if (read_size < buffer_size) {
-            contents.resize(read_size);
-            return 0;
-        }
-        SLOGD("%s too big for buffer of size %zu", path.c_str(), buffer_size);
-        buffer_size <<= 1;
-    }
-}
-
-// Search a string representing the contents of /proc/mounts
-// for the mount point of a particular file by prefix matching
-// and return the corresponding block device.
-static int find_block_device_for_path(
-    const std::string &mounts,
-    const std::string &path,
-    std::string &block_device)
-{
-    auto line_begin = mounts.begin();
-    size_t best_prefix = 0;
-    std::string::const_iterator line_end;
-    while (line_begin != mounts.end()) {
-        line_end = std::find(line_begin, mounts.end(), '\n');
-        if (line_end == mounts.end()) {
-            break;
-        }
-        auto device_end = std::find(line_begin, line_end, ' ');
-        if (device_end == line_end) {
-            break;
-        }
-        auto mountpoint_begin = device_end + 1;
-        auto mountpoint_end = std::find(mountpoint_begin, line_end, ' ');
-        if (mountpoint_end == line_end) {
-            break;
-        }
-        if (std::find(line_begin, mountpoint_end, '\\') != mountpoint_end) {
-            // We don't correctly handle escape sequences, and we don't expect
-            // to encounter any, so fail if we do.
-            break;
-        }
-        size_t mountpoint_len = mountpoint_end - mountpoint_begin;
-        if (mountpoint_len > best_prefix &&
-                mountpoint_len < path.length() &&
-                path[mountpoint_len] == '/' &&
-                std::equal(mountpoint_begin, mountpoint_end, path.begin())) {
-            block_device = std::string(line_begin, device_end);
-            best_prefix = mountpoint_len;
-        }
-        line_begin = line_end + 1;
-    }
-    // All of the "break"s above are fatal parse errors.
-    if (line_begin != mounts.end()) {
-        auto bad_line = std::string(line_begin, line_end);
-        SLOGE("Unable to parse line in %s: %s", path.c_str(), bad_line.c_str());
-        return -1;
-    }
-    if (best_prefix == 0) {
-        SLOGE("No prefix found for path: %s", path.c_str());
-        return -1;
-    }
-    return 0;
 }