Merge "Late binding: add Cipher#init checks"
diff --git a/luni/src/main/java/javax/crypto/Cipher.java b/luni/src/main/java/javax/crypto/Cipher.java
index db40117..b27ea88 100644
--- a/luni/src/main/java/javax/crypto/Cipher.java
+++ b/luni/src/main/java/javax/crypto/Cipher.java
@@ -108,6 +108,58 @@
     };

     /**
+     * Used to keep track of which underlying {@code CipherSpi#engineInit(...)}
+     * variant to call when testing suitability.
+     */
+    private enum InitType {
+        /** Selects {@code engineInit(int, Key, SecureRandom)}. */
+        KEY,
+        /** Selects {@code engineInit(int, Key, AlgorithmParameters, SecureRandom)}. */
+        ALGORITHM_PARAMS,
+        /** Selects {@code engineInit(int, Key, AlgorithmParameterSpec, SecureRandom)}. */
+        ALGORITHM_PARAM_SPEC
+    }
+
+    /**
+     * Keeps track of the possible arguments to {@code Cipher#init(...)}. Only
+     * the fields used by {@code initType} are meaningful; callers pass
+     * {@code null} for whichever of {@code spec} and {@code params} is unused.
+     */
+    private static final class InitParams {
+        private final InitType initType;
+        private final int opmode;
+        private final Key key;
+        private final SecureRandom random;
+        private final AlgorithmParameterSpec spec;
+        private final AlgorithmParameters params;
+
+        private InitParams(InitType initType, int opmode, Key key, SecureRandom random,
+                AlgorithmParameterSpec spec, AlgorithmParameters params) {
+            this.initType = initType;
+            this.opmode = opmode;
+            this.key = key;
+            this.random = random;
+            this.spec = spec;
+            this.params = params;
+        }
+    }
+
+    /**
+     * Expresses the various types of transforms that may be used during
+     * initialization: the transform {@code name} to look up and which of mode
+     * and padding still has to be set on the resulting {@code CipherSpi}.
+     */
+    private static final class Transform {
+        private final String name;
+        private final NeedToSet needToSet;
+
+        private Transform(String name, NeedToSet needToSet) {
+            this.name = name;
+            this.needToSet = needToSet;
+        }
+    }
+
+    /**
      * The service name.
      */
     private static final String SERVICE = "Cipher";
@@ -298,7 +342,15 @@

         String[] transformParts = checkTransformation(transformation);

-        if (tryCombinations(null /* key */, provider, transformParts) == null) {
+        Engine.SpiAndProvider sap;
+        try {
+            sap = tryCombinations(null /* params */, provider, transformation, transformParts);
+        } catch (InvalidKeyException | InvalidAlgorithmParameterException e) {
+            // Not expected: with null params tryTransformWithProvider skips engineInit.
+            throw new ProviderException(e);
+        }
+
+        if (sap == null) {
             if (provider == null) {
                 throw new NoSuchAlgorithmException("No provider found for " + transformation);
             } else {
@@ -309,6 +361,25 @@
         return new Cipher(transformation, transformParts, provider);
     }

+    /**
+     * Checks that the provided algorithm {@code transformation} string is a
+     * valid input. The algorithm is the only mandatory field; the input can
+     * take any of the following forms:
+     * <ul>
+     * <li><code>"[cipher]"</code>
+     * <li><code>"[cipher]/[mode]/[padding]"</code>
+     * <li><code>"[cipher]//[padding]"</code>
+     * <li><code>"[cipher]/[mode]"</code>
+     * </ul>
+     * <p>
+     * Returns the specified transformation split up into three parts
+     * corresponding to their function:
+     * <p>
+     * <code>
+     * {&lt;algorithm&gt;, &lt;mode&gt;, &lt;padding&gt;}
+     * </code>
+     * @throws NoSuchAlgorithmException if {@code transformation} is not valid
+     */
     private static String[] checkTransformation(String transformation)
             throws NoSuchAlgorithmException {
         // ignore an extra prefix / characters such as in
@@ -340,40 +411,56 @@
     }

     /**
-     * Makes sure a CipherSpi that matches this type is selected.
-     *
-     * If {@code key != null} then it assumes that a suitable provider exists for this instance
-     * (used by {@link #init}.
+     * Makes sure a CipherSpi that matches this type is selected. If
+     * {@code initParams} is {@code null}, a previously selected
+     * {@code CipherSpi} is reused when available and nothing is initialized.
+     * Otherwise (as when called from {@link #init}) a {@code CipherSpi} is
+     * chosen, or the explicitly specified one is used, and it is initialized
+     * with {@code initParams} before being returned.
      *
      * @throws InvalidKeyException if the specified key cannot be used to
      *             initialize this cipher.
+     * @throws InvalidAlgorithmParameterException if the specified parameters
+     *             cannot be used to initialize this cipher.
      */
-    private CipherSpi getSpi(Key key) throws InvalidKeyException {
-        if (specifiedSpi != null) {
-            return specifiedSpi;
-        }
+    private CipherSpi getSpi(InitParams initParams) throws InvalidKeyException,
+            InvalidAlgorithmParameterException {
+        if (specifiedSpi != null) {
+            // An explicitly specified CipherSpi bypasses provider selection, but it
+            // must still be initialized when called from init(...).
+            if (initParams != null) {
+                switch (initParams.initType) {
+                    case ALGORITHM_PARAMS:
+                        specifiedSpi.engineInit(initParams.opmode, initParams.key,
+                                initParams.params, initParams.random);
+                        break;
+                    case ALGORITHM_PARAM_SPEC:
+                        specifiedSpi.engineInit(initParams.opmode, initParams.key,
+                                initParams.spec, initParams.random);
+                        break;
+                    case KEY:
+                        specifiedSpi.engineInit(initParams.opmode, initParams.key,
+                                initParams.random);
+                        break;
+                    default:
+                        throw new AssertionError("This should never be reached");
+                }
+            }
+            return specifiedSpi;
+        }

         synchronized (initLock) {
             // This is not only a matter of performance. Many methods like update, doFinal, etc.
-            // call {@code #getSpi()} (ie, {@code #getSpi(null /* key */)}) and without this
+            // call {@code #getSpi()} (ie, {@code #getSpi(null /* params */)}) and without this
             // shortcut they would override an spi that was chosen using the key.
-            if (spiImpl != null && key == null) {
+            if (spiImpl != null && initParams == null) {
                 return spiImpl;
             }

-            final Engine.SpiAndProvider sap = tryCombinations(
-                    key, specifiedProvider, transformParts);
-
+            final Engine.SpiAndProvider sap = tryCombinations(initParams, specifiedProvider,
+                    transformation, transformParts);
             if (sap == null) {
-                if (key == null) {
-                    throw new ProviderException("No provider for " + transformation);
-                }
-                // Since the key is not null, a suitable provider exists,
-                // and it is an InvalidKeyException.
-                throw new InvalidKeyException(
-                        "No provider offers " + transformation + " for " + key.getAlgorithm()
-                                + " key of class " + key.getClass().getName()
-                                + " and export format " + key.getFormat());
+                throw new ProviderException("No provider found for " + transformation);
             }

             spiImpl = (CipherSpi) sap.spi;
@@ -389,8 +453,8 @@
     private CipherSpi getSpi() {
         try {
             return getSpi(null);
-        } catch (InvalidKeyException e) {
-            throw new IllegalStateException("InvalidKeyException thrown when key == null", e);
+        } catch (InvalidAlgorithmParameterException | InvalidKeyException e) {
+            throw new ProviderException("Exception thrown when params == null", e);
         }
     }

@@ -411,72 +475,118 @@
     }

     /**
-     * Try all combinations of mode strings:
-     *
-     * <pre>
-     *   [cipher]/[mode]/[padding]
-     *   [cipher]/[mode]
-     *   [cipher]//[padding]
-     *   [cipher]
-     * </pre>
+     * Tries to find the correct {@code Cipher} transform to use. Returns a
+     * {@link Engine.SpiAndProvider}, or throws the first exception that was
+     * encountered during attempted initialization. If no suitable service is
+     * found, returns {@code null} when no key was supplied and otherwise throws
+     * {@code InvalidKeyException}.
+     * <p>
+     * If {@code provider} is non-null, only its services are consulted and an
+     * exception thrown while initializing one of them propagates immediately;
+     * a service that yields no usable {@code CipherSpi} is skipped in favor of
+     * the next, less specific, transform.
+     * <p>
+     * {@code transformParts} must be in the format returned by
+     * {@link #checkTransformation(String)}. The combinations of mode strings
+     * tried are as follows:
+     * <ul>
+     * <li><code>[cipher]/[mode]/[padding]</code>
+     * <li><code>[cipher]/[mode]</code>
+     * <li><code>[cipher]//[padding]</code>
+     * <li><code>[cipher]</code>
+     * </ul>
      */
-    private static Engine.SpiAndProvider tryCombinations(Key key, Provider provider,
-            String[] transformParts) {
-        Engine.SpiAndProvider sap = null;
-
+    private static Engine.SpiAndProvider tryCombinations(InitParams initParams, Provider provider,
+            String transformation, String[] transformParts) throws InvalidKeyException,
+            InvalidAlgorithmParameterException {
+        // Enumerate all the transforms we need to try
+        ArrayList<Transform> transforms = new ArrayList<Transform>();
         if (transformParts[1] != null && transformParts[2] != null) {
-            sap = tryTransform(key, provider, transformParts[0] + "/" + transformParts[1] + "/"
-                    + transformParts[2], transformParts, NeedToSet.NONE);
-            if (sap != null) {
-                return sap;
-            }
+            transforms.add(new Transform(transformParts[0] + "/" + transformParts[1] + "/"
+                    + transformParts[2], NeedToSet.NONE));
         }
-
         if (transformParts[1] != null) {
-            sap = tryTransform(key, provider, transformParts[0] + "/" + transformParts[1],
-                    transformParts, NeedToSet.PADDING);
-            if (sap != null) {
-                return sap;
-            }
+            transforms.add(new Transform(transformParts[0] + "/" + transformParts[1],
+                    NeedToSet.PADDING));
         }
-
         if (transformParts[2] != null) {
-            sap = tryTransform(key, provider, transformParts[0] + "//" + transformParts[2],
-                    transformParts, NeedToSet.MODE);
-            if (sap != null) {
-                return sap;
-            }
+            transforms.add(new Transform(transformParts[0] + "//" + transformParts[2],
+                    NeedToSet.MODE));
         }
+        transforms.add(new Transform(transformParts[0], NeedToSet.BOTH));

-        return tryTransform(key, provider, transformParts[0], transformParts, NeedToSet.BOTH);
-    }
-
-    private static Engine.SpiAndProvider tryTransform(Key key, Provider provider, String transform,
-            String[] transformParts, NeedToSet type) {
-        if (provider != null) {
-            Provider.Service service = provider.getService(SERVICE, transform);
-            if (service == null) {
-                return null;
+        // Try each of the transforms and keep track of the first exception
+        // encountered.
+        Exception cause = null;
+        for (Transform transform : transforms) {
+            if (provider != null) {
+                Provider.Service service = provider.getService(SERVICE, transform.name);
+                if (service == null) {
+                    continue;
+                }
+                Engine.SpiAndProvider sap = tryTransformWithProvider(initParams,
+                        transformParts, transform.needToSet, service);
+                if (sap != null) {
+                    return sap;
+                }
+                // No usable CipherSpi from this service; try the next transform.
+                continue;
             }
-            return tryTransformWithProvider(transformParts, type, service);
-        }
-        ArrayList<Provider.Service> services = ENGINE.getServices(transform);
-        if (services == null || services.isEmpty()) {
-            return null;
-        }
-        for (Provider.Service service : services) {
-            if (key == null || service.supportsParameter(key)) {
-                Engine.SpiAndProvider sap = tryTransformWithProvider(transformParts, type, service);
-                if (sap != null) {
-                    return sap;
+            ArrayList<Provider.Service> services = ENGINE.getServices(transform.name);
+            if (services == null || services.isEmpty()) {
+                continue;
+            }
+            for (Provider.Service service : services) {
+                if (initParams == null || initParams.key == null
+                        || service.supportsParameter(initParams.key)) {
+                    try {
+                        Engine.SpiAndProvider sap = tryTransformWithProvider(initParams,
+                                transformParts, transform.needToSet, service);
+                        if (sap != null) {
+                            return sap;
+                        }
+                    } catch (Exception e) {
+                        if (cause == null) {
+                            cause = e;
+                        }
+                    }
                 }
             }
         }
-        return null;
+        if (cause instanceof InvalidKeyException) {
+            throw (InvalidKeyException) cause;
+        } else if (cause instanceof InvalidAlgorithmParameterException) {
+            throw (InvalidAlgorithmParameterException) cause;
+        } else if (cause instanceof RuntimeException) {
+            throw (RuntimeException) cause;
+        } else if (cause != null) {
+            throw new InvalidKeyException("No provider can be initialized with given key", cause);
+        } else if (initParams == null || initParams.key == null) {
+            return null;
+        } else {
+            // Since the key is not null, a suitable provider exists,
+            // and it is an InvalidKeyException.
+            throw new InvalidKeyException("No provider offers " + transformation + " for "
+                    + initParams.key.getAlgorithm() + " key of class "
+                    + initParams.key.getClass().getName() + " and export format "
+                    + initParams.key.getFormat());
+        }
     }

-    private static Engine.SpiAndProvider tryTransformWithProvider(String[] transformParts,
-            NeedToSet type, Provider.Service service) {
+    /**
+     * Tries to obtain a {@code CipherSpi} from a given {@code service}, set its
+     * mode and padding as required by {@code type} and, if {@code initParams}
+     * is non-null, initialize it. Returns the resulting
+     * {@link Engine.SpiAndProvider}, or may return {@code null} if the service
+     * cannot supply a suitable {@code CipherSpi}. If the {@code service} cannot
+     * be initialized with the specified {@code initParams}, then it's expected
+     * to throw {@code InvalidKeyException} or
+     * {@code InvalidAlgorithmParameterException} as a hint to the caller that
+     * it should continue searching for a {@code Service} that will work.
+     */
+    private static Engine.SpiAndProvider tryTransformWithProvider(InitParams initParams,
+            String[] transformParts, NeedToSet type, Provider.Service service)
+            throws InvalidKeyException, InvalidAlgorithmParameterException {
         try {
             /*
              * Check to see if the Cipher even supports the attributes before
@@ -491,9 +588,6 @@
             if (sap.spi == null || sap.provider == null) {
                 return null;
             }
-            if (!(sap.spi instanceof CipherSpi)) {
-                return null;
-            }
             CipherSpi spi = (CipherSpi) sap.spi;
             if (((type == NeedToSet.MODE) || (type == NeedToSet.BOTH))
                     && (transformParts[1] != null)) {
@@ -503,6 +597,25 @@
                     && (transformParts[2] != null)) {
                 spi.engineSetPadding(transformParts[2]);
             }
+
+            if (initParams != null) {
+                // Dispatch to the engineInit overload matching the init(...) variant.
+                switch (initParams.initType) {
+                    case KEY:
+                        spi.engineInit(initParams.opmode, initParams.key, initParams.random);
+                        break;
+                    case ALGORITHM_PARAM_SPEC:
+                        spi.engineInit(initParams.opmode, initParams.key, initParams.spec,
+                                initParams.random);
+                        break;
+                    case ALGORITHM_PARAMS:
+                        spi.engineInit(initParams.opmode, initParams.key, initParams.params,
+                                initParams.random);
+                        break;
+                    default:
+                        throw new AssertionError("This should never be reached");
+                }
+            }
             return sap;
         } catch (NoSuchAlgorithmException ignored) {
         } catch (NoSuchPaddingException ignored) {
@@ -614,6 +726,10 @@

     }

+    /**
+     * Checks that {@code mode} is one of {@code ENCRYPT_MODE},
+     * {@code DECRYPT_MODE}, {@code WRAP_MODE} or {@code UNWRAP_MODE}.
+     */
     private void checkMode(int mode) {
         if (mode != ENCRYPT_MODE && mode != DECRYPT_MODE && mode != UNWRAP_MODE
                 && mode != WRAP_MODE) {
@@ -695,7 +811,12 @@
         //        FIXME InvalidKeyException
         //        if keysize exceeds the maximum allowable keysize
         //        (jurisdiction policy files)
-        getSpi(key).engineInit(opmode, key, random);
+        try {
+            getSpi(new InitParams(InitType.KEY, opmode, key, random, null, null));
+        } catch (InvalidAlgorithmParameterException e) {
+            // Not expected: InitType.KEY only uses the key-only engineInit overload.
+            throw new ProviderException("Invalid parameters when params == null", e);
+        }
         mode = opmode;
     }

@@ -785,7 +906,8 @@
         //        FIXME InvalidAlgorithmParameterException
         //        cryptographic strength exceed the legal limits
         //        (jurisdiction policy files)
-        getSpi(key).engineInit(opmode, key, params, random);
+        getSpi(new InitParams(InitType.ALGORITHM_PARAM_SPEC, opmode, key, random,
+                params /* spec */, null /* no AlgorithmParameters */));
         mode = opmode;
     }

@@ -876,7 +997,8 @@
         //        FIXME InvalidAlgorithmParameterException
         //        cryptographic strength exceed the legal limits
         //        (jurisdiction policy files)
-        getSpi(key).engineInit(opmode, key, params, random);
+        getSpi(new InitParams(InitType.ALGORITHM_PARAMS, opmode, key, random,
+                null /* no spec */, params /* AlgorithmParameters */));
         mode = opmode;
     }

@@ -1001,7 +1122,12 @@
         //        if keysize exceeds the maximum allowable keysize
         //        (jurisdiction policy files)
         final Key key = certificate.getPublicKey();
-        getSpi(key).engineInit(opmode, key, random);
+        try {
+            getSpi(new InitParams(InitType.KEY, opmode, key, random, null, null));
+        } catch (InvalidAlgorithmParameterException e) {
+            // Not expected: InitType.KEY only uses the key-only engineInit overload.
+            throw new ProviderException("Invalid parameters when params == null", e);
+        }
         mode = opmode;
     }

diff --git a/luni/src/test/java/libcore/javax/crypto/CipherTest.java b/luni/src/test/java/libcore/javax/crypto/CipherTest.java
index 5f05d0c..68545c9 100644
--- a/luni/src/test/java/libcore/javax/crypto/CipherTest.java
+++ b/luni/src/test/java/libcore/javax/crypto/CipherTest.java
@@ -989,19 +989,9 @@
         Security.addProvider(mockProviderInvalid);
         try {
             Cipher c = Cipher.getInstance("FOO");
-            if (StandardNames.IS_RI) {
-                c.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(new byte[16], "FOO"));
-            } else {
-                fail("Should not find any matching providers; found: " + c);
-            }
-        } catch (NoSuchAlgorithmException maybe) {
-            if (StandardNames.IS_RI) {
-                throw maybe;
-            }
-        } catch (ClassCastException maybe) {
-            if (!StandardNames.IS_RI) {
-                throw maybe;
-            }
+            c.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(new byte[16], "FOO"));
+            fail("Should not find any matching providers; found: " + c);
+        } catch (ClassCastException expected) { // may come from getInstance() or init()
         } finally {
             Security.removeProvider(mockProviderInvalid.getName());
         }