Additional CTS tests for Android Keystore AES GCM.

Bug: 21936191
Change-Id: I52ea4849782c6a92cf256eff534fefc53c919b37
diff --git a/tests/tests/keystore/src/android/keystore/cts/AES128GCMNoPaddingCipherTest.java b/tests/tests/keystore/src/android/keystore/cts/AES128GCMNoPaddingCipherTest.java
index 6ae13ce..ed9c2b1 100644
--- a/tests/tests/keystore/src/android/keystore/cts/AES128GCMNoPaddingCipherTest.java
+++ b/tests/tests/keystore/src/android/keystore/cts/AES128GCMNoPaddingCipherTest.java
@@ -23,9 +23,14 @@
     private static final byte[] KAT_PLAINTEXT = HexEncoding.decode(
             "6d7596a8fd56ceaec61de7940984b7736fec44f572afc3c8952e4dc6541e2bc6a702c440a37610989543f6"
             + "3fedb047ca2173bc18581944");
-    private static final byte[] KAT_CIPHERTEXT = HexEncoding.decode(
+    private static final byte[] KAT_CIPHERTEXT_WITHOUT_AAD = HexEncoding.decode(
             "b3f6799e8f9326f2df1e80fcd2cb16d78c9dc7cc14bb677862dc6c639b3a6338d24b312d3989e5920b5dbf"
             + "c976765efbfe57bb385940a7a43bdf05bddae3c9d6a2fbbdfcc0cba0");
+    private static final byte[] KAT_AAD = HexEncoding.decode(
+            "d3bc7458914f45d56d5fcfbb2eeff2dcc0e620c1229d90904e98930ea71aa43b6898f846f3244d");
+    private static final byte[] KAT_CIPHERTEXT_WITH_AAD = HexEncoding.decode(
+            "b3f6799e8f9326f2df1e80fcd2cb16d78c9dc7cc14bb677862dc6c639b3a6338d24b312d3989e5920b5dbf"
+            + "c976765efbfe57bb385940a70c106264d81506f8daf9cd6a1c70988c");
 
     @Override
     protected byte[] getKatKey() {
@@ -44,6 +49,16 @@
 
     @Override
     protected byte[] getKatCiphertext() {
-        return KAT_CIPHERTEXT.clone();
+        return KAT_CIPHERTEXT_WITHOUT_AAD.clone();
+    }
+
+    @Override
+    protected byte[] getKatAad() {
+        return KAT_AAD.clone();
+    }
+
+    @Override
+    protected byte[] getKatCiphertextWhenKatAadPresent() {
+        return KAT_CIPHERTEXT_WITH_AAD.clone();
     }
 }
diff --git a/tests/tests/keystore/src/android/keystore/cts/AESECBCipherTestBase.java b/tests/tests/keystore/src/android/keystore/cts/AESECBCipherTestBase.java
index 5ecf22f..b2752dc 100644
--- a/tests/tests/keystore/src/android/keystore/cts/AESECBCipherTestBase.java
+++ b/tests/tests/keystore/src/android/keystore/cts/AESECBCipherTestBase.java
@@ -54,4 +54,8 @@
         }
         return null;
     }
+
+    public void testInitRejectsIvParameterSpec() throws Exception {
+        assertInitRejectsIvParameterSpec(new byte[getBlockSize()]);
+    }
 }
diff --git a/tests/tests/keystore/src/android/keystore/cts/AESGCMCipherTestBase.java b/tests/tests/keystore/src/android/keystore/cts/AESGCMCipherTestBase.java
index d901674..1c3404a 100644
--- a/tests/tests/keystore/src/android/keystore/cts/AESGCMCipherTestBase.java
+++ b/tests/tests/keystore/src/android/keystore/cts/AESGCMCipherTestBase.java
@@ -16,14 +16,21 @@
 
 package android.keystore.cts;
 
+import java.nio.ByteBuffer;
 import java.security.AlgorithmParameters;
+import java.security.Key;
 import java.security.spec.AlgorithmParameterSpec;
 import java.security.spec.InvalidParameterSpecException;
 
+import javax.crypto.AEADBadTagException;
+import javax.crypto.Cipher;
 import javax.crypto.spec.GCMParameterSpec;
 
 abstract class AESGCMCipherTestBase extends BlockCipherTestBase {
 
+    protected abstract byte[] getKatAad();
+    protected abstract byte[] getKatCiphertextWhenKatAadPresent();
+
     @Override
     protected boolean isStreamCipher() {
         return true;
@@ -54,4 +61,148 @@
         GCMParameterSpec spec = params.getParameterSpec(GCMParameterSpec.class);
         return spec.getIV();
     }
+
+    public void testKatEncryptWithAadProvidedInOneGo() throws Exception {
+        createCipher();
+        assertKatTransformWithAadProvidedInOneGo(
+                Cipher.ENCRYPT_MODE,
+                getKatAad(),
+                getKatPlaintext(),
+                getKatCiphertextWhenKatAadPresent());
+    }
+
+    public void testKatDecryptWithAadProvidedInOneGo() throws Exception {
+        createCipher();
+        assertKatTransformWithAadProvidedInOneGo(
+                Cipher.DECRYPT_MODE,
+                getKatAad(),
+                getKatCiphertextWhenKatAadPresent(),
+                getKatPlaintext());
+    }
+
+    public void testKatEncryptWithAadProvidedInChunks() throws Exception {
+        createCipher();
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.ENCRYPT_MODE,
+                getKatAad(),
+                getKatPlaintext(),
+                getKatCiphertextWhenKatAadPresent(),
+                1);
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.ENCRYPT_MODE,
+                getKatAad(),
+                getKatPlaintext(),
+                getKatCiphertextWhenKatAadPresent(),
+                8);
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.ENCRYPT_MODE,
+                getKatAad(),
+                getKatPlaintext(),
+                getKatCiphertextWhenKatAadPresent(),
+                3);
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.ENCRYPT_MODE,
+                getKatAad(),
+                getKatPlaintext(),
+                getKatCiphertextWhenKatAadPresent(),
+                7);
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.ENCRYPT_MODE,
+                getKatAad(),
+                getKatPlaintext(),
+                getKatCiphertextWhenKatAadPresent(),
+                23);
+    }
+
+    public void testKatDecryptWithAadProvidedInChunks() throws Exception {
+        createCipher();
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.DECRYPT_MODE,
+                getKatAad(),
+                getKatCiphertextWhenKatAadPresent(),
+                getKatPlaintext(),
+                1);
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.DECRYPT_MODE,
+                getKatAad(),
+                getKatCiphertextWhenKatAadPresent(),
+                getKatPlaintext(),
+                8);
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.DECRYPT_MODE,
+                getKatAad(),
+                getKatCiphertextWhenKatAadPresent(),
+                getKatPlaintext(),
+                3);
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.DECRYPT_MODE,
+                getKatAad(),
+                getKatCiphertextWhenKatAadPresent(),
+                getKatPlaintext(),
+                7);
+        assertKatTransformWithAadProvidedInChunks(
+                Cipher.DECRYPT_MODE,
+                getKatAad(),
+                getKatCiphertextWhenKatAadPresent(),
+                getKatPlaintext(),
+                23);
+    }
+
+    private void assertKatTransformWithAadProvidedInOneGo(int opmode,
+            byte[] aad, byte[] input, byte[] expectedOutput) throws Exception {
+        initKat(opmode);
+        updateAAD(aad);
+        assertEquals(expectedOutput, doFinal(input));
+
+        initKat(opmode);
+        updateAAD(aad, 0, aad.length);
+        assertEquals(expectedOutput, doFinal(input));
+
+        initKat(opmode);
+        updateAAD(ByteBuffer.wrap(aad));
+        assertEquals(expectedOutput, doFinal(input));
+    }
+
+    private void assertKatTransformWithAadProvidedInChunks(int opmode,
+            byte[] aad, byte[] input, byte[] expectedOutput, int maxChunkSize) throws Exception {
+        createCipher();
+        initKat(opmode);
+        int aadOffset = 0;
+        while (aadOffset < aad.length) {
+            int chunkSize = Math.min(aad.length - aadOffset, maxChunkSize);
+            updateAAD(aad, aadOffset, chunkSize);
+            aadOffset += chunkSize;
+        }
+        assertEquals(expectedOutput, doFinal(input));
+    }
+
+    public void testCiphertextBitflipDetectedWhenDecrypting() throws Exception {
+        createCipher();
+        Key key = importKey(getKatKey());
+        byte[] ciphertext = getKatCiphertext();
+        ciphertext[ciphertext.length / 2] ^= 0x40;
+        init(Cipher.DECRYPT_MODE, key, getKatAlgorithmParameterSpec());
+        try {
+            doFinal(ciphertext);
+            fail();
+        } catch (AEADBadTagException expected) {}
+    }
+
+    public void testAadBitflipDetectedWhenDecrypting() throws Exception {
+        createCipher();
+        Key key = importKey(getKatKey());
+        byte[] ciphertext = getKatCiphertextWhenKatAadPresent();
+        byte[] aad = getKatCiphertext();
+        aad[aad.length / 3] ^= 0x2;
+        init(Cipher.DECRYPT_MODE, key, getKatAlgorithmParameterSpec());
+        updateAAD(aad);
+        try {
+            doFinal(ciphertext);
+            fail();
+        } catch (AEADBadTagException expected) {}
+    }
+
+    public void testInitRejectsIvParameterSpec() throws Exception {
+        assertInitRejectsIvParameterSpec(getKatIv());
+    }
 }
diff --git a/tests/tests/keystore/src/android/keystore/cts/BlockCipherTestBase.java b/tests/tests/keystore/src/android/keystore/cts/BlockCipherTestBase.java
index 398d373..f583c51 100644
--- a/tests/tests/keystore/src/android/keystore/cts/BlockCipherTestBase.java
+++ b/tests/tests/keystore/src/android/keystore/cts/BlockCipherTestBase.java
@@ -45,6 +45,7 @@
 import javax.crypto.NoSuchPaddingException;
 import javax.crypto.SecretKey;
 import javax.crypto.ShortBufferException;
+import javax.crypto.spec.IvParameterSpec;
 import javax.crypto.spec.SecretKeySpec;
 
 abstract class BlockCipherTestBase extends AndroidTestCase {
@@ -828,21 +829,33 @@
     }
 
     public void testUpdateAADNotSupported() throws Exception {
-        createCipher();
-        initKat(Cipher.ENCRYPT_MODE);
         if (isAuthenticatedCipher()) {
-            assertUpdateAADSupported();
-        } else {
-            assertUpdateAADNotSupported();
+            // Not applicable to authenticated ciphers where updateAAD is supported.
+            return;
         }
 
         createCipher();
+        initKat(Cipher.ENCRYPT_MODE);
+        assertUpdateAADNotSupported();
+
+        createCipher();
         initKat(Cipher.DECRYPT_MODE);
-        if (isAuthenticatedCipher()) {
-            assertUpdateAADSupported();
-        } else {
-            assertUpdateAADNotSupported();
+        assertUpdateAADNotSupported();
+    }
+
+    public void testUpdateAADSupported() throws Exception {
+        if (!isAuthenticatedCipher()) {
+            // Not applicable to unauthenticated ciphers where updateAAD is not supported.
+            return;
         }
+
+        createCipher();
+        initKat(Cipher.ENCRYPT_MODE);
+        assertUpdateAADSupported();
+
+        createCipher();
+        initKat(Cipher.DECRYPT_MODE);
+        assertUpdateAADSupported();
     }
 
     private void assertUpdateAADNotSupported() throws Exception {
@@ -1235,7 +1248,7 @@
                 subarray(buffer, outputOffsetInBuffer, outputEndIndexInBuffer));
     }
 
-    private void createCipher() throws NoSuchAlgorithmException,
+    protected void createCipher() throws NoSuchAlgorithmException,
             NoSuchPaddingException  {
         mCipher = Cipher.getInstance(getTransformation());
     }
@@ -1279,7 +1292,7 @@
         return importKey(getKatKey());
     }
 
-    private SecretKey importKey(byte[] keyMaterial) {
+    protected SecretKey importKey(byte[] keyMaterial) {
         try {
             int keyId = mNextKeyId++;
             String keyAlias = "key" + keyId;
@@ -1318,75 +1331,75 @@
         }
     }
 
-    private void initKat(int opmode)
+    protected void initKat(int opmode)
             throws InvalidKeyException, InvalidAlgorithmParameterException {
         init(opmode, getKey(), getKatAlgorithmParameterSpec());
     }
 
-    private void init(int opmode, Key key, AlgorithmParameters spec)
+    protected void init(int opmode, Key key, AlgorithmParameters spec)
             throws InvalidKeyException, InvalidAlgorithmParameterException {
         mCipher.init(opmode, key, spec);
         mOpmode = opmode;
     }
 
-    private void init(int opmode, Key key, AlgorithmParameters spec, SecureRandom random)
+    protected void init(int opmode, Key key, AlgorithmParameters spec, SecureRandom random)
             throws InvalidKeyException, InvalidAlgorithmParameterException {
         mCipher.init(opmode, key, spec, random);
         mOpmode = opmode;
     }
 
-    private void init(int opmode, Key key, AlgorithmParameterSpec spec)
+    protected void init(int opmode, Key key, AlgorithmParameterSpec spec)
             throws InvalidKeyException, InvalidAlgorithmParameterException {
         mCipher.init(opmode, key, spec);
         mOpmode = opmode;
     }
 
-    private void init(int opmode, Key key, AlgorithmParameterSpec spec, SecureRandom random)
+    protected void init(int opmode, Key key, AlgorithmParameterSpec spec, SecureRandom random)
             throws InvalidKeyException, InvalidAlgorithmParameterException {
         mCipher.init(opmode, key, spec, random);
         mOpmode = opmode;
     }
 
-    private void init(int opmode, Key key) throws InvalidKeyException {
+    protected void init(int opmode, Key key) throws InvalidKeyException {
         mCipher.init(opmode, key);
         mOpmode = opmode;
     }
 
-    private void init(int opmode, Key key, SecureRandom random) throws InvalidKeyException {
+    protected void init(int opmode, Key key, SecureRandom random) throws InvalidKeyException {
         mCipher.init(opmode, key, random);
         mOpmode = opmode;
     }
 
-    private byte[] doFinal() throws IllegalBlockSizeException, BadPaddingException {
+    protected byte[] doFinal() throws IllegalBlockSizeException, BadPaddingException {
         return mCipher.doFinal();
     }
 
-    private byte[] doFinal(byte[] input) throws IllegalBlockSizeException, BadPaddingException {
+    protected byte[] doFinal(byte[] input) throws IllegalBlockSizeException, BadPaddingException {
         return mCipher.doFinal(input);
     }
 
-    private byte[] doFinal(byte[] input, int inputOffset, int inputLen)
+    protected byte[] doFinal(byte[] input, int inputOffset, int inputLen)
             throws IllegalBlockSizeException, BadPaddingException {
         return mCipher.doFinal(input, inputOffset, inputLen);
     }
 
-    private int doFinal(byte[] input, int inputOffset, int inputLen, byte[] output)
+    protected int doFinal(byte[] input, int inputOffset, int inputLen, byte[] output)
             throws ShortBufferException, IllegalBlockSizeException, BadPaddingException {
         return mCipher.doFinal(input, inputOffset, inputLen, output);
     }
 
-    private int doFinal(byte[] input, int inputOffset, int inputLen, byte[] output,
+    protected int doFinal(byte[] input, int inputOffset, int inputLen, byte[] output,
             int outputOffset) throws ShortBufferException, IllegalBlockSizeException,
             BadPaddingException {
         return mCipher.doFinal(input, inputOffset, inputLen, output, outputOffset);
     }
 
-    private int doFinal(byte[] output, int outputOffset) throws IllegalBlockSizeException,
+    protected int doFinal(byte[] output, int outputOffset) throws IllegalBlockSizeException,
             ShortBufferException, BadPaddingException {
         return mCipher.doFinal(output, outputOffset);
     }
 
-    private int doFinal(ByteBuffer input, ByteBuffer output) throws ShortBufferException,
+    protected int doFinal(ByteBuffer input, ByteBuffer output) throws ShortBufferException,
             IllegalBlockSizeException, BadPaddingException {
         return mCipher.doFinal(input, output);
     }
@@ -1415,21 +1428,21 @@
         }
     }
 
-    private byte[] update(byte[] input) {
+    protected byte[] update(byte[] input) {
         byte[] output = mCipher.update(input);
         assertUpdateOutputSize(
                 (input != null) ? input.length : 0, (output != null) ? output.length : 0);
         return output;
     }
 
-    private byte[] update(byte[] input, int offset, int len) {
+    protected byte[] update(byte[] input, int offset, int len) {
         byte[] output = mCipher.update(input, offset, len);
         assertUpdateOutputSize(len, (output != null) ? output.length : 0);
 
         return output;
     }
 
-    private int update(byte[] input, int offset, int len, byte[] output)
+    protected int update(byte[] input, int offset, int len, byte[] output)
             throws ShortBufferException {
         int outputLen = mCipher.update(input, offset, len, output);
         assertUpdateOutputSize(len, outputLen);
@@ -1437,7 +1450,7 @@
         return outputLen;
     }
 
-    private int update(byte[] input, int offset, int len, byte[] output, int outputOffset)
+    protected int update(byte[] input, int offset, int len, byte[] output, int outputOffset)
             throws ShortBufferException {
         int outputLen = mCipher.update(input, offset, len, output, outputOffset);
         assertUpdateOutputSize(len, outputLen);
@@ -1445,7 +1458,7 @@
         return outputLen;
     }
 
-    private int update(ByteBuffer input, ByteBuffer output) throws ShortBufferException {
+    protected int update(ByteBuffer input, ByteBuffer output) throws ShortBufferException {
         int inputLimitBefore = input.limit();
         int outputLimitBefore = output.limit();
         int inputLen = input.remaining();
@@ -1463,8 +1476,20 @@
         return outputLen;
     }
 
+    protected void updateAAD(byte[] input) {
+        mCipher.updateAAD(input);
+    }
+
+    protected void updateAAD(byte[] input, int offset, int len) {
+        mCipher.updateAAD(input, offset, len);
+    }
+
+    protected void updateAAD(ByteBuffer input) {
+        mCipher.updateAAD(input);
+    }
+
     @SuppressWarnings("unused")
-    private static void assertEquals(Buffer expected, Buffer actual) {
+    protected static void assertEquals(Buffer expected, Buffer actual) {
         throw new RuntimeException(
                 "Comparing ByteBuffers using their .equals is probably not what you want"
                 + " -- use assertByteBufferEquals instead.");
@@ -1474,7 +1499,7 @@
      * Asserts that the position, limit, and capacity of the provided buffers are the same, and that
      * their contents (from position {@code 0} to capacity) are the same.
      */
-    private static void assertByteBufferEquals(ByteBuffer expected, ByteBuffer actual) {
+    protected static void assertByteBufferEquals(ByteBuffer expected, ByteBuffer actual) {
         if (expected == null) {
             if (actual == null) {
                 return;
@@ -1504,7 +1529,7 @@
                         buffer.array(), buffer.arrayOffset(), buffer.capacity()) + "]";
     }
 
-    private static boolean equals(byte[] arr1, int offset1, int len1, byte[] arr2, int offset2,
+    protected static boolean equals(byte[] arr1, int offset1, int len1, byte[] arr2, int offset2,
             int len2) {
         if (arr1 == null) {
             return (arr2 == null);
@@ -1523,13 +1548,13 @@
         }
     }
 
-    private static byte[] subarray(byte[] array, int beginIndex, int endIndex) {
+    protected static byte[] subarray(byte[] array, int beginIndex, int endIndex) {
         byte[] result = new byte[endIndex - beginIndex];
         System.arraycopy(array, beginIndex, result, 0, result.length);
         return result;
     }
 
-    private static byte[] concat(byte[]... arrays) {
+    protected static byte[] concat(byte[]... arrays) {
         int resultLength = 0;
         for (byte[] array : arrays) {
             resultLength += (array != null) ? array.length : 0;
@@ -1546,11 +1571,11 @@
         return result;
     }
 
-    private static void assertEquals(byte[] expected, byte[] actual) {
+    protected static void assertEquals(byte[] expected, byte[] actual) {
         assertEquals(null, expected, actual);
     }
 
-    private static void assertEquals(String message, byte[] expected, byte[] actual) {
+    protected static void assertEquals(String message, byte[] expected, byte[] actual) {
         if (!Arrays.equals(expected, actual)) {
             StringBuilder detail = new StringBuilder();
             if (expected != null) {
@@ -1572,4 +1597,51 @@
             }
         }
     }
+
+    protected final void assertInitRejectsIvParameterSpec(byte[] iv) throws Exception {
+        Key key = importKey(getKatKey());
+        createCipher();
+        IvParameterSpec spec = new IvParameterSpec(iv);
+        try {
+            init(Cipher.ENCRYPT_MODE, key, spec);
+            fail();
+        } catch (InvalidAlgorithmParameterException expected) {}
+
+        try {
+            init(Cipher.WRAP_MODE, key, spec);
+            fail();
+        } catch (InvalidAlgorithmParameterException expected) {}
+
+        try {
+            init(Cipher.DECRYPT_MODE, key, spec);
+            fail();
+        } catch (InvalidAlgorithmParameterException expected) {}
+
+        try {
+            init(Cipher.UNWRAP_MODE, key, spec);
+            fail();
+        } catch (InvalidAlgorithmParameterException expected) {}
+
+        AlgorithmParameters param = AlgorithmParameters.getInstance("AES");
+        param.init(new IvParameterSpec(iv));
+        try {
+            init(Cipher.ENCRYPT_MODE, key, param);
+            fail();
+        } catch (InvalidAlgorithmParameterException expected) {}
+
+        try {
+            init(Cipher.WRAP_MODE, key, param);
+            fail();
+        } catch (InvalidAlgorithmParameterException expected) {}
+
+        try {
+            init(Cipher.DECRYPT_MODE, key, param);
+            fail();
+        } catch (InvalidAlgorithmParameterException expected) {}
+
+        try {
+            init(Cipher.UNWRAP_MODE, key, param);
+            fail();
+        } catch (InvalidAlgorithmParameterException expected) {}
+    }
 }