Refactor key creation to use a registry of key factories.

Change-Id: I6ebab7b44e4a5dbea282397ab8aca437e71bdca0
diff --git a/aes_key.cpp b/aes_key.cpp
index 5eff739..e692711 100644
--- a/aes_key.cpp
+++ b/aes_key.cpp
@@ -26,6 +26,22 @@
 
 namespace keymaster {
 
+class AesKeyFactory : public SymmetricKeyFactory {
+  public:
+    keymaster_algorithm_t registry_key() const { return KM_ALGORITHM_AES; }
+
+    virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+                         keymaster_error_t* error) {
+        return new AesKey(blob, logger, error);
+    }
+
+    virtual SymmetricKey* CreateKey(const AuthorizationSet& auths, const Logger& logger) {
+        return new AesKey(auths, logger);
+    }
+};
+
+static KeyFactoryRegistry::Registration<AesKeyFactory> registration;
+
 Operation* AesKey::CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error) {
     keymaster_block_mode_t block_mode;
     if (!authorizations().GetTagValue(TAG_BLOCK_MODE, &block_mode)) {
diff --git a/asymmetric_key.cpp b/asymmetric_key.cpp
index 064cd7e..65249fe 100644
--- a/asymmetric_key.cpp
+++ b/asymmetric_key.cpp
@@ -14,16 +14,79 @@
  * limitations under the License.
  */
 
+#include "asymmetric_key.h"
+
 #include <openssl/x509.h>
 
 #include <hardware/keymaster_defs.h>
 
-#include "asymmetric_key.h"
+#include "ecdsa_key.h"
 #include "openssl_utils.h"
+#include "rsa_key.h"
 #include "unencrypted_key_blob.h"
 
 namespace keymaster {
 
+struct PKCS8_PRIV_KEY_INFO_Delete {
+    void operator()(PKCS8_PRIV_KEY_INFO* p) const { PKCS8_PRIV_KEY_INFO_free(p); }
+};
+
+EVP_PKEY* AsymmetricKeyFactory::ExtractEvpKey(keymaster_key_format_t key_format,
+                                              keymaster_algorithm_t expected_algorithm,
+                                              const uint8_t* key_data, size_t key_data_length,
+                                              keymaster_error_t* error) {
+    *error = KM_ERROR_OK;
+
+    if (key_data == NULL || key_data_length <= 0) {
+        *error = KM_ERROR_INVALID_KEY_BLOB;
+        return NULL;
+    }
+
+    if (key_format != KM_KEY_FORMAT_PKCS8) {
+        *error = KM_ERROR_UNSUPPORTED_KEY_FORMAT;
+        return NULL;
+    }
+
+    UniquePtr<PKCS8_PRIV_KEY_INFO, PKCS8_PRIV_KEY_INFO_Delete> pkcs8(
+        d2i_PKCS8_PRIV_KEY_INFO(NULL, &key_data, key_data_length));
+    if (pkcs8.get() == NULL) {
+        *error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
+        return NULL;
+    }
+
+    UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(EVP_PKCS82PKEY(pkcs8.get()));
+    if (pkey.get() == NULL || EVP_PKEY_type(pkey->type) != convert_to_evp(expected_algorithm)) {
+        *error = KM_ERROR_INVALID_KEY_BLOB;
+        return NULL;
+    }
+
+    return pkey.release();
+}
+
+static const keymaster_key_format_t supported_import_formats[] = {KM_KEY_FORMAT_PKCS8};
+const keymaster_key_format_t* AsymmetricKeyFactory::SupportedImportFormats(size_t* format_count) {
+    *format_count = array_length(supported_import_formats);
+    return supported_import_formats;
+}
+
+static const keymaster_key_format_t supported_export_formats[] = {KM_KEY_FORMAT_X509};
+const keymaster_key_format_t* AsymmetricKeyFactory::SupportedExportFormats(size_t* format_count) {
+    *format_count = array_length(supported_export_formats);
+    return supported_export_formats;
+}
+
+/* static */
+int AsymmetricKeyFactory::convert_to_evp(keymaster_algorithm_t algorithm) {
+    switch (algorithm) {
+    case KM_ALGORITHM_RSA:
+        return EVP_PKEY_RSA;
+    case KM_ALGORITHM_ECDSA:
+        return EVP_PKEY_EC;
+    default:
+        return -1;
+    };
+}
+
 keymaster_error_t AsymmetricKey::LoadKey(const UnencryptedKeyBlob& blob) {
     UniquePtr<EVP_PKEY, EVP_PKEY_Delete> evp_key(EVP_PKEY_new());
     if (evp_key.get() == NULL)
diff --git a/asymmetric_key.h b/asymmetric_key.h
index ead8189..b715077 100644
--- a/asymmetric_key.h
+++ b/asymmetric_key.h
@@ -23,6 +23,19 @@
 
 namespace keymaster {
 
+class AsymmetricKeyFactory : public KeyFactory {
+  protected:
+    EVP_PKEY* ExtractEvpKey(keymaster_key_format_t key_format,
+                            keymaster_algorithm_t expected_algorithm, const uint8_t* key_data,
+                            size_t key_data_length, keymaster_error_t* error);
+
+    virtual const keymaster_key_format_t* SupportedImportFormats(size_t* format_count);
+    virtual const keymaster_key_format_t* SupportedExportFormats(size_t* format_count);
+
+  private:
+    static int convert_to_evp(keymaster_algorithm_t algorithm);
+};
+
 class AsymmetricKey : public Key {
   public:
   protected:
diff --git a/ecdsa_key.cpp b/ecdsa_key.cpp
index 84d9649..b499fdb 100644
--- a/ecdsa_key.cpp
+++ b/ecdsa_key.cpp
@@ -23,9 +23,32 @@
 
 const uint32_t ECDSA_DEFAULT_KEY_SIZE = 224;
 
-/* static */
-EcdsaKey* EcdsaKey::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
-                                keymaster_error_t* error) {
+class EcdsaKeyFactory : public AsymmetricKeyFactory {
+  public:
+    virtual keymaster_algorithm_t registry_key() const { return KM_ALGORITHM_ECDSA; }
+
+    virtual Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+                             keymaster_error_t* error);
+    virtual Key* ImportKey(const AuthorizationSet& key_description,
+                           keymaster_key_format_t key_format, const uint8_t* key_data,
+                           size_t key_data_length, const Logger& logger, keymaster_error_t* error);
+    virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+                         keymaster_error_t* error) {
+        return new EcdsaKey(blob, logger, error);
+    }
+
+  private:
+    static EC_GROUP* choose_group(size_t key_size_bits);
+    static keymaster_error_t get_group_size(const EC_GROUP& group, size_t* key_size_bits);
+
+    struct EC_GROUP_Delete {
+        void operator()(EC_GROUP* p) { EC_GROUP_free(p); }
+    };
+};
+static KeyFactoryRegistry::Registration<EcdsaKeyFactory> registration;
+
+Key* EcdsaKeyFactory::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+                                  keymaster_error_t* error) {
     if (!error)
         return NULL;
 
@@ -35,7 +58,7 @@
     if (!authorizations.GetTagValue(TAG_KEY_SIZE, &key_size))
         authorizations.push_back(Authorization(TAG_KEY_SIZE, key_size));
 
-    UniquePtr<EC_KEY, ECDSA_Delete> ecdsa_key(EC_KEY_new());
+    UniquePtr<EC_KEY, EcdsaKey::ECDSA_Delete> ecdsa_key(EC_KEY_new());
     UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(EVP_PKEY_new());
     if (ecdsa_key.get() == NULL || pkey.get() == NULL) {
         *error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
@@ -65,24 +88,32 @@
     return new_key;
 }
 
-/* static */
-EcdsaKey* EcdsaKey::ImportKey(const AuthorizationSet& key_description, EVP_PKEY* pkey,
-                              const Logger& logger, keymaster_error_t* error) {
+Key* EcdsaKeyFactory::ImportKey(const AuthorizationSet& key_description,
+                                keymaster_key_format_t key_format, const uint8_t* key_data,
+                                size_t key_data_length, const Logger& logger,
+                                keymaster_error_t* error) {
     if (!error)
         return NULL;
-    *error = KM_ERROR_UNKNOWN_ERROR;
 
-    UniquePtr<EC_KEY, ECDSA_Delete> ecdsa_key(EVP_PKEY_get1_EC_KEY(pkey));
-    if (!ecdsa_key.get())
+    UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(
+        ExtractEvpKey(key_format, KM_ALGORITHM_ECDSA, key_data, key_data_length, error));
+    if (*error != KM_ERROR_OK)
         return NULL;
+    assert(pkey.get());
 
-    AuthorizationSet authorizations(key_description);
+    UniquePtr<EC_KEY, EcdsaKey::ECDSA_Delete> ecdsa_key(EVP_PKEY_get1_EC_KEY(pkey.get()));
+    if (!ecdsa_key.get()) {
+        *error = KM_ERROR_UNKNOWN_ERROR;
+        return NULL;
+    }
 
     size_t extracted_key_size_bits;
     *error = get_group_size(*EC_KEY_get0_group(ecdsa_key.get()), &extracted_key_size_bits);
     if (*error != KM_ERROR_OK)
         return NULL;
 
+    AuthorizationSet authorizations(key_description);
+
     uint32_t key_size_bits;
     if (authorizations.GetTagValue(TAG_KEY_SIZE, &key_size_bits)) {
         // key_size_bits specified, make sure it matches the key.
@@ -113,7 +144,7 @@
 }
 
 /* static */
-EC_GROUP* EcdsaKey::choose_group(size_t key_size_bits) {
+EC_GROUP* EcdsaKeyFactory::choose_group(size_t key_size_bits) {
     switch (key_size_bits) {
     case 224:
         return EC_GROUP_new_by_curve_name(NID_secp224r1);
@@ -134,7 +165,7 @@
 }
 
 /* static */
-keymaster_error_t EcdsaKey::get_group_size(const EC_GROUP& group, size_t* key_size_bits) {
+keymaster_error_t EcdsaKeyFactory::get_group_size(const EC_GROUP& group, size_t* key_size_bits) {
     switch (EC_GROUP_get_curve_name(&group)) {
     case NID_secp224r1:
         *key_size_bits = 224;
diff --git a/ecdsa_key.h b/ecdsa_key.h
index b657a65..19df1ed 100644
--- a/ecdsa_key.h
+++ b/ecdsa_key.h
@@ -23,23 +23,19 @@
 
 namespace keymaster {
 
+class EcdsaKeyFactory;
+
 class EcdsaKey : public AsymmetricKey {
   public:
-    static EcdsaKey* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
-                                 keymaster_error_t* error);
-    static EcdsaKey* ImportKey(const AuthorizationSet& key_description, EVP_PKEY* pkey,
-                               const Logger& logger, keymaster_error_t* error);
-    EcdsaKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
-
     virtual Operation* CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error);
 
   private:
+    friend EcdsaKeyFactory;
+
+    EcdsaKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
     EcdsaKey(EC_KEY* ecdsa_key, const AuthorizationSet auths, const Logger& logger)
         : AsymmetricKey(auths, logger), ecdsa_key_(ecdsa_key) {}
 
-    static EC_GROUP* choose_group(size_t key_size_bits);
-    static keymaster_error_t get_group_size(const EC_GROUP& group, size_t* key_size_bits);
-
     virtual int evp_key_type() { return EVP_PKEY_EC; }
     virtual bool InternalToEvp(EVP_PKEY* pkey) const;
     virtual bool EvpToInternal(const EVP_PKEY* pkey);
@@ -48,10 +44,6 @@
         void operator()(EC_KEY* p) { EC_KEY_free(p); }
     };
 
-    struct EC_GROUP_Delete {
-        void operator()(EC_GROUP* p) { EC_GROUP_free(p); }
-    };
-
     UniquePtr<EC_KEY, ECDSA_Delete> ecdsa_key_;
 };
 
diff --git a/google_keymaster.cpp b/google_keymaster.cpp
index f89b6f7..2024c90 100644
--- a/google_keymaster.cpp
+++ b/google_keymaster.cpp
@@ -44,6 +44,7 @@
     if (operation_table_.get() == NULL)
         operation_table_size_ = 0;
 }
+
 GoogleKeymaster::~GoogleKeymaster() {
     for (size_t i = 0; i < operation_table_size_; ++i)
         if (operation_table_[i].operation != NULL)
@@ -60,14 +61,13 @@
 // methods that return the same information.  They'll get out of sync.  Best to put the knowledge in
 // the keytypes and provide some mechanism for GoogleKeymaster to query the keytypes for the
 // information.
-
-keymaster_algorithm_t supported_algorithms[] = {
-    KM_ALGORITHM_RSA, KM_ALGORITHM_ECDSA, KM_ALGORITHM_AES,
-};
+//
+// UPDATE: This TODO has been completed for supported algorithms.  It still needs to be done for
+// modes, padding, etc.  This will be done with a registry of operation factories.
 
 template <typename T>
 bool check_supported(keymaster_algorithm_t algorithm, SupportedResponse<T>* response) {
-    if (!array_contains(supported_algorithms, algorithm)) {
+    if (KeyFactoryRegistry::Get(algorithm) == NULL) {
         response->error = KM_ERROR_UNSUPPORTED_ALGORITHM;
         return false;
     }
@@ -88,7 +88,24 @@
     SupportedResponse<keymaster_algorithm_t>* response) const {
     if (response == NULL)
         return;
-    response->SetResults(supported_algorithms);
+
+    size_t factory_count = 0;
+    const KeyFactory** factories = KeyFactoryRegistry::GetAll(&factory_count);
+    assert(factories != NULL);
+    assert(factory_count > 0);
+
+    UniquePtr<keymaster_algorithm_t[]> algorithms(new keymaster_algorithm_t[factory_count]);
+    if (!algorithms.get()) {
+        response->error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
+        return;
+    }
+
+    for (size_t i = 0; i < factory_count; ++i)
+        algorithms[i] = factories[i]->registry_key();
+
+    response->results = algorithms.release();
+    response->results_length = factory_count;
+    response->error = KM_ERROR_OK;
 }
 
 void GoogleKeymaster::SupportedBlockModes(
@@ -153,42 +170,26 @@
     }
 }
 
-keymaster_key_format_t supported_import_formats[] = {KM_KEY_FORMAT_PKCS8};
 void GoogleKeymaster::SupportedImportFormats(
     keymaster_algorithm_t algorithm, SupportedResponse<keymaster_key_format_t>* response) const {
     if (response == NULL || !check_supported(algorithm, response))
         return;
 
-    response->error = KM_ERROR_OK;
-    switch (algorithm) {
-    case KM_ALGORITHM_RSA:
-    case KM_ALGORITHM_DSA:
-    case KM_ALGORITHM_ECDSA:
-        response->SetResults(supported_import_formats);
-        break;
-    default:
-        response->results_length = 0;
-        break;
-    }
+    size_t count;
+    const keymaster_key_format_t* formats =
+        KeyFactoryRegistry::Get(algorithm)->SupportedImportFormats(&count);
+    response->SetResults(formats, count);
 }
 
-keymaster_key_format_t supported_export_formats[] = {KM_KEY_FORMAT_X509};
 void GoogleKeymaster::SupportedExportFormats(
     keymaster_algorithm_t algorithm, SupportedResponse<keymaster_key_format_t>* response) const {
     if (response == NULL || !check_supported(algorithm, response))
         return;
 
-    response->error = KM_ERROR_OK;
-    switch (algorithm) {
-    case KM_ALGORITHM_RSA:
-    case KM_ALGORITHM_DSA:
-    case KM_ALGORITHM_ECDSA:
-        response->SetResults(supported_export_formats);
-        break;
-    default:
-        response->results_length = 0;
-        break;
-    }
+    size_t count;
+    const keymaster_key_format_t* formats =
+        KeyFactoryRegistry::Get(algorithm)->SupportedExportFormats(&count);
+    response->SetResults(formats, count);
 }
 
 void GoogleKeymaster::GenerateKey(const GenerateKeyRequest& request,
@@ -196,7 +197,15 @@
     if (response == NULL)
         return;
 
-    UniquePtr<Key> key(Key::GenerateKey(request.key_description, logger(), &response->error));
+    keymaster_algorithm_t algorithm;
+    KeyFactory* factory = 0;
+    UniquePtr<Key> key;
+    if (!request.key_description.GetTagValue(TAG_ALGORITHM, &algorithm) ||
+        !(factory = KeyFactoryRegistry::Get(algorithm)))
+        response->error = KM_ERROR_UNSUPPORTED_ALGORITHM;
+    else
+        key.reset(factory->GenerateKey(request.key_description, logger(), &response->error));
+
     if (response->error != KM_ERROR_OK)
         return;
 
@@ -310,8 +319,16 @@
     if (response == NULL)
         return;
 
-    UniquePtr<Key> key(Key::ImportKey(request.key_description, request.key_format, request.key_data,
-                                      request.key_data_length, logger(), &response->error));
+    keymaster_algorithm_t algorithm;
+    KeyFactory* factory = 0;
+    UniquePtr<Key> key;
+    if (!request.key_description.GetTagValue(TAG_ALGORITHM, &algorithm) ||
+        !(factory = KeyFactoryRegistry::Get(algorithm)))
+        response->error = KM_ERROR_UNSUPPORTED_ALGORITHM;
+    else
+        key.reset(factory->ImportKey(request.key_description, request.key_format, request.key_data,
+                                     request.key_data_length, logger(), &response->error));
+
     if (response->error != KM_ERROR_OK)
         return;
 
@@ -369,7 +386,12 @@
     UniquePtr<UnencryptedKeyBlob> blob(LoadKeyBlob(key, client_params, error));
     if (*error != KM_ERROR_OK)
         return NULL;
-    return Key::CreateKey(*blob, logger(), error);
+
+    KeyFactory* factory = 0;
+    if ((factory = KeyFactoryRegistry::Get(blob->algorithm())))
+        return factory->LoadKey(*blob, logger(), error);
+    *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
+    return NULL;
 }
 
 UnencryptedKeyBlob* GoogleKeymaster::LoadKeyBlob(const keymaster_key_blob_t& key,
diff --git a/google_keymaster_test.cpp b/google_keymaster_test.cpp
index 9d280f0..91aef7b 100644
--- a/google_keymaster_test.cpp
+++ b/google_keymaster_test.cpp
@@ -379,8 +379,9 @@
     size_t len;
     keymaster_algorithm_t* algorithms;
     EXPECT_EQ(KM_ERROR_OK, device()->get_supported_algorithms(device(), &algorithms, &len));
-    EXPECT_TRUE(ResponseContains({KM_ALGORITHM_RSA, KM_ALGORITHM_ECDSA, KM_ALGORITHM_AES},
-                                 algorithms, len));
+    EXPECT_TRUE(ResponseContains(
+        {KM_ALGORITHM_RSA, KM_ALGORITHM_ECDSA, KM_ALGORITHM_AES, KM_ALGORITHM_HMAC}, algorithms,
+        len));
     free(algorithms);
 }
 
@@ -842,7 +843,7 @@
     string pk8_key = read_file("rsa_privkey_pk8.der");
     ASSERT_EQ(633U, pk8_key.size());
 
-    ASSERT_EQ(KM_ERROR_OK, ImportKey(ParamBuilder().SigningKey().NoDigestOrPadding(),
+    ASSERT_EQ(KM_ERROR_OK, ImportKey(ParamBuilder().RsaSigningKey().NoDigestOrPadding(),
                                      KM_KEY_FORMAT_PKCS8, pk8_key));
 
     // Check values derived from the key.
@@ -865,8 +866,7 @@
     ASSERT_EQ(633U, pk8_key.size());
     ASSERT_EQ(KM_ERROR_IMPORT_PARAMETER_MISMATCH,
               ImportKey(ParamBuilder()
-                            .SigningKey()
-                            .Option(TAG_KEY_SIZE, 2048)  // Doesn't match key
+                            .RsaSigningKey(2048)  // Size doesn't match key
                             .NoDigestOrPadding(),
                         KM_KEY_FORMAT_PKCS8, pk8_key));
 }
@@ -876,7 +876,7 @@
     ASSERT_EQ(633U, pk8_key.size());
     ASSERT_EQ(KM_ERROR_IMPORT_PARAMETER_MISMATCH,
               ImportKey(ParamBuilder()
-                            .SigningKey()
+                            .RsaSigningKey()
                             .Option(TAG_RSA_PUBLIC_EXPONENT, 3)  // Doesn't match key
                             .NoDigestOrPadding(),
                         KM_KEY_FORMAT_PKCS8, pk8_key));
@@ -886,7 +886,8 @@
     string pk8_key = read_file("ec_privkey_pk8.der");
     ASSERT_EQ(138U, pk8_key.size());
 
-    ASSERT_EQ(KM_ERROR_OK, ImportKey(ParamBuilder().SigningKey(), KM_KEY_FORMAT_PKCS8, pk8_key));
+    ASSERT_EQ(KM_ERROR_OK,
+              ImportKey(ParamBuilder().EcdsaSigningKey(), KM_KEY_FORMAT_PKCS8, pk8_key));
 
     // Check values derived from the key.
     EXPECT_TRUE(contains(sw_enforced(), TAG_ALGORITHM, KM_ALGORITHM_ECDSA));
@@ -906,8 +907,8 @@
     string pk8_key = read_file("ec_privkey_pk8.der");
     ASSERT_EQ(138U, pk8_key.size());
 
-    ASSERT_EQ(KM_ERROR_OK, ImportKey(ParamBuilder().SigningKey().Option(TAG_KEY_SIZE, 256),
-                                     KM_KEY_FORMAT_PKCS8, pk8_key));
+    ASSERT_EQ(KM_ERROR_OK,
+              ImportKey(ParamBuilder().EcdsaSigningKey(256), KM_KEY_FORMAT_PKCS8, pk8_key));
 
     // Check values derived from the key.
     EXPECT_TRUE(contains(sw_enforced(), TAG_ALGORITHM, KM_ALGORITHM_ECDSA));
@@ -927,8 +928,8 @@
     string pk8_key = read_file("ec_privkey_pk8.der");
     ASSERT_EQ(138U, pk8_key.size());
     ASSERT_EQ(KM_ERROR_IMPORT_PARAMETER_MISMATCH,
-              ImportKey(ParamBuilder().SigningKey().Option(TAG_KEY_SIZE, 224), KM_KEY_FORMAT_PKCS8,
-                        pk8_key));
+              ImportKey(ParamBuilder().EcdsaSigningKey(224),  // Size does not match key
+                        KM_KEY_FORMAT_PKCS8, pk8_key));
 }
 
 typedef KeymasterTest VersionTest;
diff --git a/hmac_key.cpp b/hmac_key.cpp
index 7314a68..d49f32c 100644
--- a/hmac_key.cpp
+++ b/hmac_key.cpp
@@ -23,6 +23,22 @@
 
 namespace keymaster {
 
+class HmacKeyFactory : public SymmetricKeyFactory {
+  public:
+    keymaster_algorithm_t registry_key() const { return KM_ALGORITHM_HMAC; }
+
+    virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+                         keymaster_error_t* error) {
+        return new HmacKey(blob, logger, error);
+    }
+
+    virtual SymmetricKey* CreateKey(const AuthorizationSet& auths, const Logger& logger) {
+        return new HmacKey(auths, logger);
+    }
+};
+
+static KeyFactoryRegistry::Registration<HmacKeyFactory> registration;
+
 Operation* HmacKey::CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error) {
     *error = KM_ERROR_OK;
 
diff --git a/include/keymaster/google_keymaster_messages.h b/include/keymaster/google_keymaster_messages.h
index 432d68f..ba9e1c4 100644
--- a/include/keymaster/google_keymaster_messages.h
+++ b/include/keymaster/google_keymaster_messages.h
@@ -111,13 +111,17 @@
     ~SupportedResponse() { delete[] results; }
 
     template <size_t N> void SetResults(const T(&arr)[N]) {
+        SetResults(arr, N);
+    }
+
+    void SetResults(const T* arr, size_t n) {
         delete[] results;
         results_length = 0;
-        results = dup_array(arr);
+        results = dup_array(arr, n);
         if (results == NULL) {
             error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
         } else {
-            results_length = N;
+            results_length = n;
             error = KM_ERROR_OK;
         }
     }
diff --git a/include/keymaster/google_keymaster_utils.h b/include/keymaster/google_keymaster_utils.h
index ae98407..bba0c36 100644
--- a/include/keymaster/google_keymaster_utils.h
+++ b/include/keymaster/google_keymaster_utils.h
@@ -62,15 +62,23 @@
 
 /**
  * Duplicate the array \p a.  The memory for the new array is allocated and the caller takes
+ * responsibility.
+ */
+template <typename T> inline T* dup_array(const T* a, size_t n) {
+    T* dup = new T[n];
+    if (dup != NULL)
+        for (size_t i = 0; i < n; ++i)
+            dup[i] = a[i];
+    return dup;
+}
+
+/**
+ * Duplicate the array \p a.  The memory for the new array is allocated and the caller takes
  * responsibility.  Note that the dup is necessarily returned as a pointer, so size is lost.  Call
  * array_length() on the original array to discover the size.
  */
 template <typename T, size_t N> inline T* dup_array(const T (&a)[N]) {
-    T* dup = new T[N];
-    if (dup != NULL) {
-        memcpy(dup, &a, array_size(a));
-    }
-    return dup;
+    return dup_array(a, N);
 }
 
 /**
@@ -136,7 +144,7 @@
 
     template <typename T>
     explicit Eraser(T& t)
-        : buf_(reinterpret_cast<uint8_t*>(&t)), size_(sizeof(t)) {}
+        : buf_(reinterpret_cast<uint8_t*> (&t)), size_(sizeof(t)) {}
 
     template <size_t N> explicit Eraser(uint8_t (&arr)[N]) : buf_(arr), size_(N) {}
 
diff --git a/key.cpp b/key.cpp
index 053c2d5..00affef 100644
--- a/key.cpp
+++ b/key.cpp
@@ -26,97 +26,16 @@
 
 namespace keymaster {
 
-struct PKCS8_PRIV_KEY_INFO_Delete {
-    void operator()(PKCS8_PRIV_KEY_INFO* p) const { PKCS8_PRIV_KEY_INFO_free(p); }
-};
+/* static */
+template <> KeyFactoryRegistry* KeyFactoryRegistry::instance_ptr = 0;
 
 Key::Key(const KeyBlob& blob, const Logger& logger) : logger_(logger) {
     authorizations_.push_back(blob.unenforced());
     authorizations_.push_back(blob.enforced());
 }
 
-/* static */
-Key* Key::CreateKey(const UnencryptedKeyBlob& blob, const Logger& logger,
-                    keymaster_error_t* error) {
-    switch (blob.algorithm()) {
-    case KM_ALGORITHM_RSA:
-        return new RsaKey(blob, logger, error);
-    case KM_ALGORITHM_ECDSA:
-        return new EcdsaKey(blob, logger, error);
-    case KM_ALGORITHM_AES:
-    case KM_ALGORITHM_HMAC:
-        return SymmetricKey::CreateKey(blob.algorithm(), blob, logger, error);
-    default:
-        *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
-        return NULL;
-    }
-}
-
-/* static */
-Key* Key::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
-                      keymaster_error_t* error) {
-    keymaster_algorithm_t algorithm;
-    if (!key_description.GetTagValue(TAG_ALGORITHM, &algorithm)) {
-        *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
-        return NULL;
-    }
-
-    switch (algorithm) {
-    case KM_ALGORITHM_RSA:
-        return RsaKey::GenerateKey(key_description, logger, error);
-    case KM_ALGORITHM_ECDSA:
-        return EcdsaKey::GenerateKey(key_description, logger, error);
-    case KM_ALGORITHM_AES:
-    case KM_ALGORITHM_HMAC:
-        return SymmetricKey::GenerateKey(algorithm, key_description, logger, error);
-    default:
-        *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
-        return NULL;
-    }
-}
-
-/* static */
-Key* Key::ImportKey(const AuthorizationSet& key_description, keymaster_key_format_t key_format,
-                    const uint8_t* key_data, size_t key_data_length, const Logger& logger,
-                    keymaster_error_t* error) {
-    *error = KM_ERROR_OK;
-
-    if (key_data == NULL || key_data_length <= 0) {
-        *error = KM_ERROR_INVALID_KEY_BLOB;
-        return NULL;
-    }
-
-    if (key_format != KM_KEY_FORMAT_PKCS8) {
-        *error = KM_ERROR_UNSUPPORTED_KEY_FORMAT;
-        return NULL;
-    }
-
-    UniquePtr<PKCS8_PRIV_KEY_INFO, PKCS8_PRIV_KEY_INFO_Delete> pkcs8(
-        d2i_PKCS8_PRIV_KEY_INFO(NULL, &key_data, key_data_length));
-    if (pkcs8.get() == NULL) {
-        *error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
-        return NULL;
-    }
-
-    UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(EVP_PKCS82PKEY(pkcs8.get()));
-    if (pkey.get() == NULL) {
-        *error = KM_ERROR_INVALID_KEY_BLOB;
-        return NULL;
-    }
-
-    UniquePtr<Key> key;
-    switch (EVP_PKEY_type(pkey->type)) {
-    case EVP_PKEY_RSA:
-        return RsaKey::ImportKey(key_description, pkey.get(), logger, error);
-    case EVP_PKEY_EC:
-        return EcdsaKey::ImportKey(key_description, pkey.get(), logger, error);
-    default:
-        *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
-        return NULL;
-    }
-
-    *error = KM_ERROR_UNIMPLEMENTED;
-    return NULL;
-}
+struct PKCS8_PRIV_KEY_INFO_Delete {
+    void operator()(PKCS8_PRIV_KEY_INFO* p) const { PKCS8_PRIV_KEY_INFO_free(p); }
+};
 
 }  // namespace keymaster
diff --git a/key.h b/key.h
index fab7816..77a8d96 100644
--- a/key.h
+++ b/key.h
@@ -21,22 +21,48 @@
 #include <keymaster/authorization_set.h>
 #include <keymaster/logger.h>
 
+#include "abstract_factory_registry.h"
+#include "unencrypted_key_blob.h"
+
 namespace keymaster {
 
+class Key;
+
+/**
+ * KeyFactory is a pure interface whose subclasses know how to construct a specific subclass of Key.
+ * There is a one to one correspondence between Key subclasses and KeyFactory subclasses.
+ */
+class KeyFactory {
+  public:
+    virtual ~KeyFactory() {}
+
+    // Required for registry
+    typedef keymaster_algorithm_t KeyType;
+    virtual keymaster_algorithm_t registry_key() const = 0;
+
+    // Factory methods.
+    virtual Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+                             keymaster_error_t* error) = 0;
+    virtual Key* ImportKey(const AuthorizationSet& key_description,
+                           keymaster_key_format_t key_format, const uint8_t* key_data,
+                           size_t key_data_length, const Logger& logger,
+                           keymaster_error_t* error) = 0;
+    virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+                         keymaster_error_t* error) = 0;
+
+    // Informational methods.
+    virtual const keymaster_key_format_t* SupportedImportFormats(size_t* format_count) = 0;
+    virtual const keymaster_key_format_t* SupportedExportFormats(size_t* format_count) = 0;
+};
+
+typedef AbstractFactoryRegistry<KeyFactory> KeyFactoryRegistry;
+
 class KeyBlob;
 class Operation;
 class UnencryptedKeyBlob;
 
 class Key {
   public:
-    static Key* CreateKey(const UnencryptedKeyBlob& blob, const Logger& logger,
-                          keymaster_error_t* error);
-    static Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
-                            keymaster_error_t* error);
-    static Key* ImportKey(const AuthorizationSet& key_description,
-                          keymaster_key_format_t key_format, const uint8_t* key_data,
-                          size_t key_data_length, const Logger& logger, keymaster_error_t* error);
-
     virtual ~Key() {}
     virtual Operation* CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error) = 0;
 
diff --git a/rsa_key.cpp b/rsa_key.cpp
index 2e27883..35fe82e 100644
--- a/rsa_key.cpp
+++ b/rsa_key.cpp
@@ -30,9 +30,23 @@
 const uint32_t RSA_DEFAULT_KEY_SIZE = 2048;
 const uint64_t RSA_DEFAULT_EXPONENT = 65537;
 
-/* static */
-RsaKey* RsaKey::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
-                            keymaster_error_t* error) {
+class RsaKeyFactory : public AsymmetricKeyFactory {
+  public:
+    virtual keymaster_algorithm_t registry_key() const { return KM_ALGORITHM_RSA; }
+    virtual Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+                             keymaster_error_t* error);
+    virtual Key* ImportKey(const AuthorizationSet& key_description,
+                           keymaster_key_format_t key_format, const uint8_t* key_data,
+                           size_t key_data_length, const Logger& logger, keymaster_error_t* error);
+    virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+                         keymaster_error_t* error) {
+        return new RsaKey(blob, logger, error);
+    }
+};
+static KeyFactoryRegistry::Registration<RsaKeyFactory> registration;
+
+Key* RsaKeyFactory::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+                                keymaster_error_t* error) {
     if (!error)
         return NULL;
 
@@ -47,7 +61,7 @@
         authorizations.push_back(Authorization(TAG_KEY_SIZE, key_size));
 
     UniquePtr<BIGNUM, BIGNUM_Delete> exponent(BN_new());
-    UniquePtr<RSA, RSA_Delete> rsa_key(RSA_new());
+    UniquePtr<RSA, RsaKey::RSA_Delete> rsa_key(RSA_new());
     UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(EVP_PKEY_new());
     if (rsa_key.get() == NULL || pkey.get() == NULL) {
         *error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
@@ -65,16 +79,24 @@
     return new_key;
 }
 
-/* static */
-RsaKey* RsaKey::ImportKey(const AuthorizationSet& key_description, EVP_PKEY* pkey,
-                          const Logger& logger, keymaster_error_t* error) {
+Key* RsaKeyFactory::ImportKey(const AuthorizationSet& key_description,
+                              keymaster_key_format_t key_format, const uint8_t* key_data,
+                              size_t key_data_length, const Logger& logger,
+                              keymaster_error_t* error) {
     if (!error)
         return NULL;
-    *error = KM_ERROR_UNKNOWN_ERROR;
 
-    UniquePtr<RSA, RSA_Delete> rsa_key(EVP_PKEY_get1_RSA(pkey));
-    if (!rsa_key.get())
+    UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(
+        ExtractEvpKey(key_format, KM_ALGORITHM_RSA, key_data, key_data_length, error));
+    if (*error != KM_ERROR_OK)
         return NULL;
+    assert(pkey.get());
+
+    UniquePtr<RSA, RsaKey::RSA_Delete> rsa_key(EVP_PKEY_get1_RSA(pkey.get()));
+    if (!rsa_key.get()) {
+        *error = KM_ERROR_UNKNOWN_ERROR;
+        return NULL;
+    }
 
     AuthorizationSet authorizations(key_description);
 
diff --git a/rsa_key.h b/rsa_key.h
index cb341e2..e98ece3 100644
--- a/rsa_key.h
+++ b/rsa_key.h
@@ -23,17 +23,16 @@
 
 namespace keymaster {
 
+class RsaKeyFactory;
+
 class RsaKey : public AsymmetricKey {
   public:
-    static RsaKey* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
-                               keymaster_error_t* error);
-    static RsaKey* ImportKey(const AuthorizationSet& key_description, EVP_PKEY* pkey,
-                             const Logger& logger, keymaster_error_t* error);
-    RsaKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
-
     virtual Operation* CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error);
 
   private:
+    friend class RsaKeyFactory;
+
+    RsaKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
     RsaKey(RSA* rsa_key, const AuthorizationSet& auths, const Logger& logger)
         : AsymmetricKey(auths, logger), rsa_key_(rsa_key) {}
 
diff --git a/symmetric_key.cpp b/symmetric_key.cpp
index 9678f9d..a850076 100644
--- a/symmetric_key.cpp
+++ b/symmetric_key.cpp
@@ -27,27 +27,13 @@
 
 namespace keymaster {
 
-/* static */
-SymmetricKey* SymmetricKey::GenerateKey(keymaster_algorithm_t algorithm,
-                                        const AuthorizationSet& key_description,
-                                        const Logger& logger, keymaster_error_t* error) {
+Key* SymmetricKeyFactory::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+                                      keymaster_error_t* error) {
     if (!error)
         return NULL;
-
     *error = KM_ERROR_OK;
 
-    UniquePtr<SymmetricKey> key;
-    switch (algorithm) {
-    case KM_ALGORITHM_AES:
-        key.reset(new AesKey(key_description, logger));
-        break;
-    case KM_ALGORITHM_HMAC:
-        key.reset(new HmacKey(key_description, logger));
-        break;
-    default:
-        *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
-        return NULL;
-    };
+    UniquePtr<SymmetricKey> key(CreateKey(key_description, logger));
 
     uint32_t key_size_bits;
     if (!key_description.GetTagValue(TAG_KEY_SIZE, &key_size_bits) || key_size_bits % 8 != 0) {
@@ -56,7 +42,7 @@
     }
 
     key->key_data_size_ = key_size_bits / 8;
-    if (key->key_data_size_ > MAX_KEY_SIZE) {
+    if (key->key_data_size_ > SymmetricKey::MAX_KEY_SIZE) {
         *error = KM_ERROR_UNSUPPORTED_KEY_SIZE;
         return NULL;
     }
@@ -71,39 +57,18 @@
     return key.release();
 }
 
-/* static */
-SymmetricKey* SymmetricKey::CreateKey(keymaster_algorithm_t algorithm,
-                                      const UnencryptedKeyBlob& blob, const Logger& logger,
-                                      keymaster_error_t* error) {
-    switch (algorithm) {
-    case KM_ALGORITHM_AES:
-        return new AesKey(blob, logger, error);
-    case KM_ALGORITHM_HMAC:
-        return new HmacKey(blob, logger, error);
-    default:
-        *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
-        return NULL;
-    }
-}
-
 SymmetricKey::SymmetricKey(const UnencryptedKeyBlob& blob, const Logger& logger,
                            keymaster_error_t* error)
     : Key(blob, logger), key_data_size_(blob.unencrypted_key_material_length()) {
+    memcpy(key_data_, blob.unencrypted_key_material(), key_data_size_);
     if (error)
-        *error = LoadKey(blob);
+        *error = KM_ERROR_OK;
 }
 
 SymmetricKey::~SymmetricKey() {
     memset_s(key_data_, 0, MAX_KEY_SIZE);
 }
 
-keymaster_error_t SymmetricKey::LoadKey(const UnencryptedKeyBlob& blob) {
-    assert(blob.unencrypted_key_material_length() == key_data_size_);
-    memcpy(key_data_, blob.unencrypted_key_material(), key_data_size_);
-
-    return KM_ERROR_OK;
-}
-
 keymaster_error_t SymmetricKey::key_material(UniquePtr<uint8_t[]>* key_material,
                                              size_t* size) const {
     *size = key_data_size_;
diff --git a/symmetric_key.h b/symmetric_key.h
index dc65349..bef921f 100644
--- a/symmetric_key.h
+++ b/symmetric_key.h
@@ -21,19 +21,38 @@
 
 namespace keymaster {
 
+class SymmetricKey;
+
+class SymmetricKeyFactory : public KeyFactory {
+    virtual Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+                             keymaster_error_t* error);
+    virtual Key* ImportKey(const AuthorizationSet&, keymaster_key_format_t, const uint8_t*, size_t,
+                           const Logger&, keymaster_error_t* error) {
+        *error = KM_ERROR_UNIMPLEMENTED;
+        return NULL;
+    }
+
+    virtual const keymaster_key_format_t* SupportedImportFormats(size_t* format_count) {
+        return NoFormats(format_count);
+    }
+    virtual const keymaster_key_format_t* SupportedExportFormats(size_t* format_count) {
+        return NoFormats(format_count);
+    };
+
+  private:
+    virtual SymmetricKey* CreateKey(const AuthorizationSet& auths, const Logger& logger) = 0;
+    const keymaster_key_format_t* NoFormats(size_t* format_count) {
+        *format_count = 0;
+        return NULL;
+    }
+};
+
 class SymmetricKey : public Key {
   public:
     static const int MAX_KEY_SIZE = 32;
     static const int MAX_MAC_LENGTH = 32;
     static const uint32_t MAX_CHUNK_LENGTH = 64 * 1024;
 
-    static SymmetricKey* GenerateKey(keymaster_algorithm_t algorithm,
-                                     const AuthorizationSet& key_description, const Logger& logger,
-                                     keymaster_error_t* error);
-    static SymmetricKey* CreateKey(keymaster_algorithm_t algorithm, const UnencryptedKeyBlob& blob,
-                                   const Logger& logger, keymaster_error_t* error);
-
-    SymmetricKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
     ~SymmetricKey();
 
     virtual keymaster_error_t key_material(UniquePtr<uint8_t[]>* key_material, size_t* size) const;
@@ -45,12 +64,16 @@
   protected:
     keymaster_error_t error_;
 
+    SymmetricKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
+
     const uint8_t* key_data() const { return key_data_; }
     size_t key_data_size() const { return key_data_size_; }
 
     SymmetricKey(const AuthorizationSet& auths, const Logger& logger) : Key(auths, logger) {}
 
   private:
+    friend SymmetricKeyFactory;
+
     keymaster_error_t LoadKey(const UnencryptedKeyBlob& blob);
 
     size_t key_data_size_;