Refactor key creation to use a registry of key factories.
Change-Id: I6ebab7b44e4a5dbea282397ab8aca437e71bdca0
diff --git a/aes_key.cpp b/aes_key.cpp
index 5eff739..e692711 100644
--- a/aes_key.cpp
+++ b/aes_key.cpp
@@ -26,6 +26,22 @@
namespace keymaster {
+class AesKeyFactory : public SymmetricKeyFactory {
+ public:
+ keymaster_algorithm_t registry_key() const { return KM_ALGORITHM_AES; }
+
+ virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+ keymaster_error_t* error) {
+ return new AesKey(blob, logger, error);
+ }
+
+ virtual SymmetricKey* CreateKey(const AuthorizationSet& auths, const Logger& logger) {
+ return new AesKey(auths, logger);
+ }
+};
+
+static KeyFactoryRegistry::Registration<AesKeyFactory> registration;
+
Operation* AesKey::CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error) {
keymaster_block_mode_t block_mode;
if (!authorizations().GetTagValue(TAG_BLOCK_MODE, &block_mode)) {
diff --git a/asymmetric_key.cpp b/asymmetric_key.cpp
index 064cd7e..65249fe 100644
--- a/asymmetric_key.cpp
+++ b/asymmetric_key.cpp
@@ -14,16 +14,79 @@
* limitations under the License.
*/
+#include "asymmetric_key.h"
+
#include <openssl/x509.h>
#include <hardware/keymaster_defs.h>
-#include "asymmetric_key.h"
+#include "ecdsa_key.h"
#include "openssl_utils.h"
+#include "rsa_key.h"
#include "unencrypted_key_blob.h"
namespace keymaster {
+struct PKCS8_PRIV_KEY_INFO_Delete {
+ void operator()(PKCS8_PRIV_KEY_INFO* p) const { PKCS8_PRIV_KEY_INFO_free(p); }
+};
+
+EVP_PKEY* AsymmetricKeyFactory::ExtractEvpKey(keymaster_key_format_t key_format,
+ keymaster_algorithm_t expected_algorithm,
+ const uint8_t* key_data, size_t key_data_length,
+ keymaster_error_t* error) {
+ *error = KM_ERROR_OK;
+
+ if (key_data == NULL || key_data_length <= 0) {
+ *error = KM_ERROR_INVALID_KEY_BLOB;
+ return NULL;
+ }
+
+ if (key_format != KM_KEY_FORMAT_PKCS8) {
+ *error = KM_ERROR_UNSUPPORTED_KEY_FORMAT;
+ return NULL;
+ }
+
+ UniquePtr<PKCS8_PRIV_KEY_INFO, PKCS8_PRIV_KEY_INFO_Delete> pkcs8(
+ d2i_PKCS8_PRIV_KEY_INFO(NULL, &key_data, key_data_length));
+ if (pkcs8.get() == NULL) {
+ *error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
+ return NULL;
+ }
+
+ UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(EVP_PKCS82PKEY(pkcs8.get()));
+ if (pkey.get() == NULL || EVP_PKEY_type(pkey->type) != convert_to_evp(expected_algorithm)) {
+ *error = KM_ERROR_INVALID_KEY_BLOB;
+ return NULL;
+ }
+
+ return pkey.release();
+}
+
+static const keymaster_key_format_t supported_import_formats[] = {KM_KEY_FORMAT_PKCS8};
+const keymaster_key_format_t* AsymmetricKeyFactory::SupportedImportFormats(size_t* format_count) {
+ *format_count = array_length(supported_import_formats);
+ return supported_import_formats;
+}
+
+static const keymaster_key_format_t supported_export_formats[] = {KM_KEY_FORMAT_X509};
+const keymaster_key_format_t* AsymmetricKeyFactory::SupportedExportFormats(size_t* format_count) {
+ *format_count = array_length(supported_export_formats);
+ return supported_export_formats;
+}
+
+/* static */
+int AsymmetricKeyFactory::convert_to_evp(keymaster_algorithm_t algorithm) {
+ switch (algorithm) {
+ case KM_ALGORITHM_RSA:
+ return EVP_PKEY_RSA;
+ case KM_ALGORITHM_ECDSA:
+ return EVP_PKEY_EC;
+ default:
+ return -1;
+ };
+}
+
keymaster_error_t AsymmetricKey::LoadKey(const UnencryptedKeyBlob& blob) {
UniquePtr<EVP_PKEY, EVP_PKEY_Delete> evp_key(EVP_PKEY_new());
if (evp_key.get() == NULL)
diff --git a/asymmetric_key.h b/asymmetric_key.h
index ead8189..b715077 100644
--- a/asymmetric_key.h
+++ b/asymmetric_key.h
@@ -23,6 +23,19 @@
namespace keymaster {
+class AsymmetricKeyFactory : public KeyFactory {
+ protected:
+ EVP_PKEY* ExtractEvpKey(keymaster_key_format_t key_format,
+ keymaster_algorithm_t expected_algorithm, const uint8_t* key_data,
+ size_t key_data_length, keymaster_error_t* error);
+
+ virtual const keymaster_key_format_t* SupportedImportFormats(size_t* format_count);
+ virtual const keymaster_key_format_t* SupportedExportFormats(size_t* format_count);
+
+ private:
+ static int convert_to_evp(keymaster_algorithm_t algorithm);
+};
+
class AsymmetricKey : public Key {
public:
protected:
diff --git a/ecdsa_key.cpp b/ecdsa_key.cpp
index 84d9649..b499fdb 100644
--- a/ecdsa_key.cpp
+++ b/ecdsa_key.cpp
@@ -23,9 +23,32 @@
const uint32_t ECDSA_DEFAULT_KEY_SIZE = 224;
-/* static */
-EcdsaKey* EcdsaKey::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
- keymaster_error_t* error) {
+class EcdsaKeyFactory : public AsymmetricKeyFactory {
+ public:
+ virtual keymaster_algorithm_t registry_key() const { return KM_ALGORITHM_ECDSA; }
+
+ virtual Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+ keymaster_error_t* error);
+ virtual Key* ImportKey(const AuthorizationSet& key_description,
+ keymaster_key_format_t key_format, const uint8_t* key_data,
+ size_t key_data_length, const Logger& logger, keymaster_error_t* error);
+ virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+ keymaster_error_t* error) {
+ return new EcdsaKey(blob, logger, error);
+ }
+
+ private:
+ static EC_GROUP* choose_group(size_t key_size_bits);
+ static keymaster_error_t get_group_size(const EC_GROUP& group, size_t* key_size_bits);
+
+ struct EC_GROUP_Delete {
+ void operator()(EC_GROUP* p) { EC_GROUP_free(p); }
+ };
+};
+static KeyFactoryRegistry::Registration<EcdsaKeyFactory> registration;
+
+Key* EcdsaKeyFactory::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+ keymaster_error_t* error) {
if (!error)
return NULL;
@@ -35,7 +58,7 @@
if (!authorizations.GetTagValue(TAG_KEY_SIZE, &key_size))
authorizations.push_back(Authorization(TAG_KEY_SIZE, key_size));
- UniquePtr<EC_KEY, ECDSA_Delete> ecdsa_key(EC_KEY_new());
+ UniquePtr<EC_KEY, EcdsaKey::ECDSA_Delete> ecdsa_key(EC_KEY_new());
UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(EVP_PKEY_new());
if (ecdsa_key.get() == NULL || pkey.get() == NULL) {
*error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
@@ -65,24 +88,32 @@
return new_key;
}
-/* static */
-EcdsaKey* EcdsaKey::ImportKey(const AuthorizationSet& key_description, EVP_PKEY* pkey,
- const Logger& logger, keymaster_error_t* error) {
+Key* EcdsaKeyFactory::ImportKey(const AuthorizationSet& key_description,
+ keymaster_key_format_t key_format, const uint8_t* key_data,
+ size_t key_data_length, const Logger& logger,
+ keymaster_error_t* error) {
if (!error)
return NULL;
- *error = KM_ERROR_UNKNOWN_ERROR;
- UniquePtr<EC_KEY, ECDSA_Delete> ecdsa_key(EVP_PKEY_get1_EC_KEY(pkey));
- if (!ecdsa_key.get())
+ UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(
+ ExtractEvpKey(key_format, KM_ALGORITHM_ECDSA, key_data, key_data_length, error));
+ if (*error != KM_ERROR_OK)
return NULL;
+ assert(pkey.get());
- AuthorizationSet authorizations(key_description);
+ UniquePtr<EC_KEY, EcdsaKey::ECDSA_Delete> ecdsa_key(EVP_PKEY_get1_EC_KEY(pkey.get()));
+ if (!ecdsa_key.get()) {
+ *error = KM_ERROR_UNKNOWN_ERROR;
+ return NULL;
+ }
size_t extracted_key_size_bits;
*error = get_group_size(*EC_KEY_get0_group(ecdsa_key.get()), &extracted_key_size_bits);
if (*error != KM_ERROR_OK)
return NULL;
+ AuthorizationSet authorizations(key_description);
+
uint32_t key_size_bits;
if (authorizations.GetTagValue(TAG_KEY_SIZE, &key_size_bits)) {
// key_size_bits specified, make sure it matches the key.
@@ -113,7 +144,7 @@
}
/* static */
-EC_GROUP* EcdsaKey::choose_group(size_t key_size_bits) {
+EC_GROUP* EcdsaKeyFactory::choose_group(size_t key_size_bits) {
switch (key_size_bits) {
case 224:
return EC_GROUP_new_by_curve_name(NID_secp224r1);
@@ -134,7 +165,7 @@
}
/* static */
-keymaster_error_t EcdsaKey::get_group_size(const EC_GROUP& group, size_t* key_size_bits) {
+keymaster_error_t EcdsaKeyFactory::get_group_size(const EC_GROUP& group, size_t* key_size_bits) {
switch (EC_GROUP_get_curve_name(&group)) {
case NID_secp224r1:
*key_size_bits = 224;
diff --git a/ecdsa_key.h b/ecdsa_key.h
index b657a65..19df1ed 100644
--- a/ecdsa_key.h
+++ b/ecdsa_key.h
@@ -23,23 +23,19 @@
namespace keymaster {
+class EcdsaKeyFactory;
+
class EcdsaKey : public AsymmetricKey {
public:
- static EcdsaKey* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
- keymaster_error_t* error);
- static EcdsaKey* ImportKey(const AuthorizationSet& key_description, EVP_PKEY* pkey,
- const Logger& logger, keymaster_error_t* error);
- EcdsaKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
-
virtual Operation* CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error);
private:
+ friend EcdsaKeyFactory;
+
+ EcdsaKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
EcdsaKey(EC_KEY* ecdsa_key, const AuthorizationSet auths, const Logger& logger)
: AsymmetricKey(auths, logger), ecdsa_key_(ecdsa_key) {}
- static EC_GROUP* choose_group(size_t key_size_bits);
- static keymaster_error_t get_group_size(const EC_GROUP& group, size_t* key_size_bits);
-
virtual int evp_key_type() { return EVP_PKEY_EC; }
virtual bool InternalToEvp(EVP_PKEY* pkey) const;
virtual bool EvpToInternal(const EVP_PKEY* pkey);
@@ -48,10 +44,6 @@
void operator()(EC_KEY* p) { EC_KEY_free(p); }
};
- struct EC_GROUP_Delete {
- void operator()(EC_GROUP* p) { EC_GROUP_free(p); }
- };
-
UniquePtr<EC_KEY, ECDSA_Delete> ecdsa_key_;
};
diff --git a/google_keymaster.cpp b/google_keymaster.cpp
index f89b6f7..2024c90 100644
--- a/google_keymaster.cpp
+++ b/google_keymaster.cpp
@@ -44,6 +44,7 @@
if (operation_table_.get() == NULL)
operation_table_size_ = 0;
}
+
GoogleKeymaster::~GoogleKeymaster() {
for (size_t i = 0; i < operation_table_size_; ++i)
if (operation_table_[i].operation != NULL)
@@ -60,14 +61,13 @@
// methods that return the same information. They'll get out of sync. Best to put the knowledge in
// the keytypes and provide some mechanism for GoogleKeymaster to query the keytypes for the
// information.
-
-keymaster_algorithm_t supported_algorithms[] = {
- KM_ALGORITHM_RSA, KM_ALGORITHM_ECDSA, KM_ALGORITHM_AES,
-};
+//
+// UPDATE: This TODO has been completed for supported algorithms. It still needs to be done for
+// modes, padding, etc. This will be done with a registry of operation factories.
template <typename T>
bool check_supported(keymaster_algorithm_t algorithm, SupportedResponse<T>* response) {
- if (!array_contains(supported_algorithms, algorithm)) {
+ if (KeyFactoryRegistry::Get(algorithm) == NULL) {
response->error = KM_ERROR_UNSUPPORTED_ALGORITHM;
return false;
}
@@ -88,7 +88,24 @@
SupportedResponse<keymaster_algorithm_t>* response) const {
if (response == NULL)
return;
- response->SetResults(supported_algorithms);
+
+ size_t factory_count = 0;
+ const KeyFactory** factories = KeyFactoryRegistry::GetAll(&factory_count);
+ assert(factories != NULL);
+ assert(factory_count > 0);
+
+ UniquePtr<keymaster_algorithm_t[]> algorithms(new keymaster_algorithm_t[factory_count]);
+ if (!algorithms.get()) {
+ response->error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
+ return;
+ }
+
+ for (size_t i = 0; i < factory_count; ++i)
+ algorithms[i] = factories[i]->registry_key();
+
+ response->results = algorithms.release();
+ response->results_length = factory_count;
+ response->error = KM_ERROR_OK;
}
void GoogleKeymaster::SupportedBlockModes(
@@ -153,42 +170,26 @@
}
}
-keymaster_key_format_t supported_import_formats[] = {KM_KEY_FORMAT_PKCS8};
void GoogleKeymaster::SupportedImportFormats(
keymaster_algorithm_t algorithm, SupportedResponse<keymaster_key_format_t>* response) const {
if (response == NULL || !check_supported(algorithm, response))
return;
- response->error = KM_ERROR_OK;
- switch (algorithm) {
- case KM_ALGORITHM_RSA:
- case KM_ALGORITHM_DSA:
- case KM_ALGORITHM_ECDSA:
- response->SetResults(supported_import_formats);
- break;
- default:
- response->results_length = 0;
- break;
- }
+ size_t count;
+ const keymaster_key_format_t* formats =
+ KeyFactoryRegistry::Get(algorithm)->SupportedImportFormats(&count);
+ response->SetResults(formats, count);
}
-keymaster_key_format_t supported_export_formats[] = {KM_KEY_FORMAT_X509};
void GoogleKeymaster::SupportedExportFormats(
keymaster_algorithm_t algorithm, SupportedResponse<keymaster_key_format_t>* response) const {
if (response == NULL || !check_supported(algorithm, response))
return;
- response->error = KM_ERROR_OK;
- switch (algorithm) {
- case KM_ALGORITHM_RSA:
- case KM_ALGORITHM_DSA:
- case KM_ALGORITHM_ECDSA:
- response->SetResults(supported_export_formats);
- break;
- default:
- response->results_length = 0;
- break;
- }
+ size_t count;
+ const keymaster_key_format_t* formats =
+ KeyFactoryRegistry::Get(algorithm)->SupportedExportFormats(&count);
+ response->SetResults(formats, count);
}
void GoogleKeymaster::GenerateKey(const GenerateKeyRequest& request,
@@ -196,7 +197,15 @@
if (response == NULL)
return;
- UniquePtr<Key> key(Key::GenerateKey(request.key_description, logger(), &response->error));
+ keymaster_algorithm_t algorithm;
+ KeyFactory* factory = 0;
+ UniquePtr<Key> key;
+ if (!request.key_description.GetTagValue(TAG_ALGORITHM, &algorithm) ||
+ !(factory = KeyFactoryRegistry::Get(algorithm)))
+ response->error = KM_ERROR_UNSUPPORTED_ALGORITHM;
+ else
+ key.reset(factory->GenerateKey(request.key_description, logger(), &response->error));
+
if (response->error != KM_ERROR_OK)
return;
@@ -310,8 +319,16 @@
if (response == NULL)
return;
- UniquePtr<Key> key(Key::ImportKey(request.key_description, request.key_format, request.key_data,
- request.key_data_length, logger(), &response->error));
+ keymaster_algorithm_t algorithm;
+ KeyFactory* factory = 0;
+ UniquePtr<Key> key;
+ if (!request.key_description.GetTagValue(TAG_ALGORITHM, &algorithm) ||
+ !(factory = KeyFactoryRegistry::Get(algorithm)))
+ response->error = KM_ERROR_UNSUPPORTED_ALGORITHM;
+ else
+ key.reset(factory->ImportKey(request.key_description, request.key_format, request.key_data,
+ request.key_data_length, logger(), &response->error));
+
if (response->error != KM_ERROR_OK)
return;
@@ -369,7 +386,12 @@
UniquePtr<UnencryptedKeyBlob> blob(LoadKeyBlob(key, client_params, error));
if (*error != KM_ERROR_OK)
return NULL;
- return Key::CreateKey(*blob, logger(), error);
+
+ KeyFactory* factory = 0;
+ if ((factory = KeyFactoryRegistry::Get(blob->algorithm())))
+ return factory->LoadKey(*blob, logger(), error);
+ *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
+ return NULL;
}
UnencryptedKeyBlob* GoogleKeymaster::LoadKeyBlob(const keymaster_key_blob_t& key,
diff --git a/google_keymaster_test.cpp b/google_keymaster_test.cpp
index 9d280f0..91aef7b 100644
--- a/google_keymaster_test.cpp
+++ b/google_keymaster_test.cpp
@@ -379,8 +379,9 @@
size_t len;
keymaster_algorithm_t* algorithms;
EXPECT_EQ(KM_ERROR_OK, device()->get_supported_algorithms(device(), &algorithms, &len));
- EXPECT_TRUE(ResponseContains({KM_ALGORITHM_RSA, KM_ALGORITHM_ECDSA, KM_ALGORITHM_AES},
- algorithms, len));
+ EXPECT_TRUE(ResponseContains(
+ {KM_ALGORITHM_RSA, KM_ALGORITHM_ECDSA, KM_ALGORITHM_AES, KM_ALGORITHM_HMAC}, algorithms,
+ len));
free(algorithms);
}
@@ -842,7 +843,7 @@
string pk8_key = read_file("rsa_privkey_pk8.der");
ASSERT_EQ(633U, pk8_key.size());
- ASSERT_EQ(KM_ERROR_OK, ImportKey(ParamBuilder().SigningKey().NoDigestOrPadding(),
+ ASSERT_EQ(KM_ERROR_OK, ImportKey(ParamBuilder().RsaSigningKey().NoDigestOrPadding(),
KM_KEY_FORMAT_PKCS8, pk8_key));
// Check values derived from the key.
@@ -865,8 +866,7 @@
ASSERT_EQ(633U, pk8_key.size());
ASSERT_EQ(KM_ERROR_IMPORT_PARAMETER_MISMATCH,
ImportKey(ParamBuilder()
- .SigningKey()
- .Option(TAG_KEY_SIZE, 2048) // Doesn't match key
+ .RsaSigningKey(2048) // Size doesn't match key
.NoDigestOrPadding(),
KM_KEY_FORMAT_PKCS8, pk8_key));
}
@@ -876,7 +876,7 @@
ASSERT_EQ(633U, pk8_key.size());
ASSERT_EQ(KM_ERROR_IMPORT_PARAMETER_MISMATCH,
ImportKey(ParamBuilder()
- .SigningKey()
+ .RsaSigningKey()
.Option(TAG_RSA_PUBLIC_EXPONENT, 3) // Doesn't match key
.NoDigestOrPadding(),
KM_KEY_FORMAT_PKCS8, pk8_key));
@@ -886,7 +886,8 @@
string pk8_key = read_file("ec_privkey_pk8.der");
ASSERT_EQ(138U, pk8_key.size());
- ASSERT_EQ(KM_ERROR_OK, ImportKey(ParamBuilder().SigningKey(), KM_KEY_FORMAT_PKCS8, pk8_key));
+ ASSERT_EQ(KM_ERROR_OK,
+ ImportKey(ParamBuilder().EcdsaSigningKey(), KM_KEY_FORMAT_PKCS8, pk8_key));
// Check values derived from the key.
EXPECT_TRUE(contains(sw_enforced(), TAG_ALGORITHM, KM_ALGORITHM_ECDSA));
@@ -906,8 +907,8 @@
string pk8_key = read_file("ec_privkey_pk8.der");
ASSERT_EQ(138U, pk8_key.size());
- ASSERT_EQ(KM_ERROR_OK, ImportKey(ParamBuilder().SigningKey().Option(TAG_KEY_SIZE, 256),
- KM_KEY_FORMAT_PKCS8, pk8_key));
+ ASSERT_EQ(KM_ERROR_OK,
+ ImportKey(ParamBuilder().EcdsaSigningKey(256), KM_KEY_FORMAT_PKCS8, pk8_key));
// Check values derived from the key.
EXPECT_TRUE(contains(sw_enforced(), TAG_ALGORITHM, KM_ALGORITHM_ECDSA));
@@ -927,8 +928,8 @@
string pk8_key = read_file("ec_privkey_pk8.der");
ASSERT_EQ(138U, pk8_key.size());
ASSERT_EQ(KM_ERROR_IMPORT_PARAMETER_MISMATCH,
- ImportKey(ParamBuilder().SigningKey().Option(TAG_KEY_SIZE, 224), KM_KEY_FORMAT_PKCS8,
- pk8_key));
+ ImportKey(ParamBuilder().EcdsaSigningKey(224), // Size does not match key
+ KM_KEY_FORMAT_PKCS8, pk8_key));
}
typedef KeymasterTest VersionTest;
diff --git a/hmac_key.cpp b/hmac_key.cpp
index 7314a68..d49f32c 100644
--- a/hmac_key.cpp
+++ b/hmac_key.cpp
@@ -23,6 +23,22 @@
namespace keymaster {
+class HmacKeyFactory : public SymmetricKeyFactory {
+ public:
+ keymaster_algorithm_t registry_key() const { return KM_ALGORITHM_HMAC; }
+
+ virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+ keymaster_error_t* error) {
+ return new HmacKey(blob, logger, error);
+ }
+
+ virtual SymmetricKey* CreateKey(const AuthorizationSet& auths, const Logger& logger) {
+ return new HmacKey(auths, logger);
+ }
+};
+
+static KeyFactoryRegistry::Registration<HmacKeyFactory> registration;
+
Operation* HmacKey::CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error) {
*error = KM_ERROR_OK;
diff --git a/include/keymaster/google_keymaster_messages.h b/include/keymaster/google_keymaster_messages.h
index 432d68f..ba9e1c4 100644
--- a/include/keymaster/google_keymaster_messages.h
+++ b/include/keymaster/google_keymaster_messages.h
@@ -111,13 +111,17 @@
~SupportedResponse() { delete[] results; }
template <size_t N> void SetResults(const T(&arr)[N]) {
+ SetResults(arr, N);
+ }
+
+ void SetResults(const T* arr, size_t n) {
delete[] results;
results_length = 0;
- results = dup_array(arr);
+ results = dup_array(arr, n);
if (results == NULL) {
error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
} else {
- results_length = N;
+ results_length = n;
error = KM_ERROR_OK;
}
}
diff --git a/include/keymaster/google_keymaster_utils.h b/include/keymaster/google_keymaster_utils.h
index ae98407..bba0c36 100644
--- a/include/keymaster/google_keymaster_utils.h
+++ b/include/keymaster/google_keymaster_utils.h
@@ -62,15 +62,23 @@
/**
* Duplicate the array \p a. The memory for the new array is allocated and the caller takes
+ * responsibility.
+ */
+template <typename T> inline T* dup_array(const T* a, size_t n) {
+ T* dup = new T[n];
+ if (dup != NULL)
+ for (size_t i = 0; i < n; ++i)
+ dup[i] = a[i];
+ return dup;
+}
+
+/**
+ * Duplicate the array \p a. The memory for the new array is allocated and the caller takes
* responsibility. Note that the dup is necessarily returned as a pointer, so size is lost. Call
* array_length() on the original array to discover the size.
*/
template <typename T, size_t N> inline T* dup_array(const T (&a)[N]) {
- T* dup = new T[N];
- if (dup != NULL) {
- memcpy(dup, &a, array_size(a));
- }
- return dup;
+ return dup_array(a, N);
}
/**
@@ -136,7 +144,7 @@
template <typename T>
explicit Eraser(T& t)
- : buf_(reinterpret_cast<uint8_t*>(&t)), size_(sizeof(t)) {}
+ : buf_(reinterpret_cast<uint8_t*> (&t)), size_(sizeof(t)) {}
template <size_t N> explicit Eraser(uint8_t (&arr)[N]) : buf_(arr), size_(N) {}
diff --git a/key.cpp b/key.cpp
index 053c2d5..00affef 100644
--- a/key.cpp
+++ b/key.cpp
@@ -26,97 +26,16 @@
namespace keymaster {
-struct PKCS8_PRIV_KEY_INFO_Delete {
- void operator()(PKCS8_PRIV_KEY_INFO* p) const { PKCS8_PRIV_KEY_INFO_free(p); }
-};
+/* static */
+template <> KeyFactoryRegistry* KeyFactoryRegistry::instance_ptr = 0;
Key::Key(const KeyBlob& blob, const Logger& logger) : logger_(logger) {
authorizations_.push_back(blob.unenforced());
authorizations_.push_back(blob.enforced());
}
-/* static */
-Key* Key::CreateKey(const UnencryptedKeyBlob& blob, const Logger& logger,
- keymaster_error_t* error) {
- switch (blob.algorithm()) {
- case KM_ALGORITHM_RSA:
- return new RsaKey(blob, logger, error);
- case KM_ALGORITHM_ECDSA:
- return new EcdsaKey(blob, logger, error);
- case KM_ALGORITHM_AES:
- case KM_ALGORITHM_HMAC:
- return SymmetricKey::CreateKey(blob.algorithm(), blob, logger, error);
- default:
- *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
- return NULL;
- }
-}
-
-/* static */
-Key* Key::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
- keymaster_error_t* error) {
- keymaster_algorithm_t algorithm;
- if (!key_description.GetTagValue(TAG_ALGORITHM, &algorithm)) {
- *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
- return NULL;
- }
-
- switch (algorithm) {
- case KM_ALGORITHM_RSA:
- return RsaKey::GenerateKey(key_description, logger, error);
- case KM_ALGORITHM_ECDSA:
- return EcdsaKey::GenerateKey(key_description, logger, error);
- case KM_ALGORITHM_AES:
- case KM_ALGORITHM_HMAC:
- return SymmetricKey::GenerateKey(algorithm, key_description, logger, error);
- default:
- *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
- return NULL;
- }
-}
-
-/* static */
-Key* Key::ImportKey(const AuthorizationSet& key_description, keymaster_key_format_t key_format,
- const uint8_t* key_data, size_t key_data_length, const Logger& logger,
- keymaster_error_t* error) {
- *error = KM_ERROR_OK;
-
- if (key_data == NULL || key_data_length <= 0) {
- *error = KM_ERROR_INVALID_KEY_BLOB;
- return NULL;
- }
-
- if (key_format != KM_KEY_FORMAT_PKCS8) {
- *error = KM_ERROR_UNSUPPORTED_KEY_FORMAT;
- return NULL;
- }
-
- UniquePtr<PKCS8_PRIV_KEY_INFO, PKCS8_PRIV_KEY_INFO_Delete> pkcs8(
- d2i_PKCS8_PRIV_KEY_INFO(NULL, &key_data, key_data_length));
- if (pkcs8.get() == NULL) {
- *error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
- return NULL;
- }
-
- UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(EVP_PKCS82PKEY(pkcs8.get()));
- if (pkey.get() == NULL) {
- *error = KM_ERROR_INVALID_KEY_BLOB;
- return NULL;
- }
-
- UniquePtr<Key> key;
- switch (EVP_PKEY_type(pkey->type)) {
- case EVP_PKEY_RSA:
- return RsaKey::ImportKey(key_description, pkey.get(), logger, error);
- case EVP_PKEY_EC:
- return EcdsaKey::ImportKey(key_description, pkey.get(), logger, error);
- default:
- *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
- return NULL;
- }
-
- *error = KM_ERROR_UNIMPLEMENTED;
- return NULL;
-}
+struct PKCS8_PRIV_KEY_INFO_Delete {
+ void operator()(PKCS8_PRIV_KEY_INFO* p) const { PKCS8_PRIV_KEY_INFO_free(p); }
+};
} // namespace keymaster
diff --git a/key.h b/key.h
index fab7816..77a8d96 100644
--- a/key.h
+++ b/key.h
@@ -21,22 +21,48 @@
#include <keymaster/authorization_set.h>
#include <keymaster/logger.h>
+#include "abstract_factory_registry.h"
+#include "unencrypted_key_blob.h"
+
namespace keymaster {
+class Key;
+
+/**
+ * KeyFactory is a pure interface whose subclasses know how to construct a specific subclass of Key.
+ * There is a one to one correspondence between Key subclasses and KeyFactory subclasses.
+ */
+class KeyFactory {
+ public:
+ virtual ~KeyFactory() {}
+
+ // Required for registry
+ typedef keymaster_algorithm_t KeyType;
+ virtual keymaster_algorithm_t registry_key() const = 0;
+
+ // Factory methods.
+ virtual Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+ keymaster_error_t* error) = 0;
+ virtual Key* ImportKey(const AuthorizationSet& key_description,
+ keymaster_key_format_t key_format, const uint8_t* key_data,
+ size_t key_data_length, const Logger& logger,
+ keymaster_error_t* error) = 0;
+ virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+ keymaster_error_t* error) = 0;
+
+ // Informational methods.
+ virtual const keymaster_key_format_t* SupportedImportFormats(size_t* format_count) = 0;
+ virtual const keymaster_key_format_t* SupportedExportFormats(size_t* format_count) = 0;
+};
+
+typedef AbstractFactoryRegistry<KeyFactory> KeyFactoryRegistry;
+
class KeyBlob;
class Operation;
class UnencryptedKeyBlob;
class Key {
public:
- static Key* CreateKey(const UnencryptedKeyBlob& blob, const Logger& logger,
- keymaster_error_t* error);
- static Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
- keymaster_error_t* error);
- static Key* ImportKey(const AuthorizationSet& key_description,
- keymaster_key_format_t key_format, const uint8_t* key_data,
- size_t key_data_length, const Logger& logger, keymaster_error_t* error);
-
virtual ~Key() {}
virtual Operation* CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error) = 0;
diff --git a/rsa_key.cpp b/rsa_key.cpp
index 2e27883..35fe82e 100644
--- a/rsa_key.cpp
+++ b/rsa_key.cpp
@@ -30,9 +30,23 @@
const uint32_t RSA_DEFAULT_KEY_SIZE = 2048;
const uint64_t RSA_DEFAULT_EXPONENT = 65537;
-/* static */
-RsaKey* RsaKey::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
- keymaster_error_t* error) {
+class RsaKeyFactory : public AsymmetricKeyFactory {
+ public:
+ virtual keymaster_algorithm_t registry_key() const { return KM_ALGORITHM_RSA; }
+ virtual Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+ keymaster_error_t* error);
+ virtual Key* ImportKey(const AuthorizationSet& key_description,
+ keymaster_key_format_t key_format, const uint8_t* key_data,
+ size_t key_data_length, const Logger& logger, keymaster_error_t* error);
+ virtual Key* LoadKey(const UnencryptedKeyBlob& blob, const Logger& logger,
+ keymaster_error_t* error) {
+ return new RsaKey(blob, logger, error);
+ }
+};
+static KeyFactoryRegistry::Registration<RsaKeyFactory> registration;
+
+Key* RsaKeyFactory::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+ keymaster_error_t* error) {
if (!error)
return NULL;
@@ -47,7 +61,7 @@
authorizations.push_back(Authorization(TAG_KEY_SIZE, key_size));
UniquePtr<BIGNUM, BIGNUM_Delete> exponent(BN_new());
- UniquePtr<RSA, RSA_Delete> rsa_key(RSA_new());
+ UniquePtr<RSA, RsaKey::RSA_Delete> rsa_key(RSA_new());
UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(EVP_PKEY_new());
if (rsa_key.get() == NULL || pkey.get() == NULL) {
*error = KM_ERROR_MEMORY_ALLOCATION_FAILED;
@@ -65,16 +79,24 @@
return new_key;
}
-/* static */
-RsaKey* RsaKey::ImportKey(const AuthorizationSet& key_description, EVP_PKEY* pkey,
- const Logger& logger, keymaster_error_t* error) {
+Key* RsaKeyFactory::ImportKey(const AuthorizationSet& key_description,
+ keymaster_key_format_t key_format, const uint8_t* key_data,
+ size_t key_data_length, const Logger& logger,
+ keymaster_error_t* error) {
if (!error)
return NULL;
- *error = KM_ERROR_UNKNOWN_ERROR;
- UniquePtr<RSA, RSA_Delete> rsa_key(EVP_PKEY_get1_RSA(pkey));
- if (!rsa_key.get())
+ UniquePtr<EVP_PKEY, EVP_PKEY_Delete> pkey(
+ ExtractEvpKey(key_format, KM_ALGORITHM_RSA, key_data, key_data_length, error));
+ if (*error != KM_ERROR_OK)
return NULL;
+ assert(pkey.get());
+
+ UniquePtr<RSA, RsaKey::RSA_Delete> rsa_key(EVP_PKEY_get1_RSA(pkey.get()));
+ if (!rsa_key.get()) {
+ *error = KM_ERROR_UNKNOWN_ERROR;
+ return NULL;
+ }
AuthorizationSet authorizations(key_description);
diff --git a/rsa_key.h b/rsa_key.h
index cb341e2..e98ece3 100644
--- a/rsa_key.h
+++ b/rsa_key.h
@@ -23,17 +23,16 @@
namespace keymaster {
+class RsaKeyFactory;
+
class RsaKey : public AsymmetricKey {
public:
- static RsaKey* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
- keymaster_error_t* error);
- static RsaKey* ImportKey(const AuthorizationSet& key_description, EVP_PKEY* pkey,
- const Logger& logger, keymaster_error_t* error);
- RsaKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
-
virtual Operation* CreateOperation(keymaster_purpose_t purpose, keymaster_error_t* error);
private:
+ friend class RsaKeyFactory;
+
+ RsaKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
RsaKey(RSA* rsa_key, const AuthorizationSet& auths, const Logger& logger)
: AsymmetricKey(auths, logger), rsa_key_(rsa_key) {}
diff --git a/symmetric_key.cpp b/symmetric_key.cpp
index 9678f9d..a850076 100644
--- a/symmetric_key.cpp
+++ b/symmetric_key.cpp
@@ -27,27 +27,13 @@
namespace keymaster {
-/* static */
-SymmetricKey* SymmetricKey::GenerateKey(keymaster_algorithm_t algorithm,
- const AuthorizationSet& key_description,
- const Logger& logger, keymaster_error_t* error) {
+Key* SymmetricKeyFactory::GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+ keymaster_error_t* error) {
if (!error)
return NULL;
-
*error = KM_ERROR_OK;
- UniquePtr<SymmetricKey> key;
- switch (algorithm) {
- case KM_ALGORITHM_AES:
- key.reset(new AesKey(key_description, logger));
- break;
- case KM_ALGORITHM_HMAC:
- key.reset(new HmacKey(key_description, logger));
- break;
- default:
- *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
- return NULL;
- };
+ UniquePtr<SymmetricKey> key(CreateKey(key_description, logger));
uint32_t key_size_bits;
if (!key_description.GetTagValue(TAG_KEY_SIZE, &key_size_bits) || key_size_bits % 8 != 0) {
@@ -56,7 +42,7 @@
}
key->key_data_size_ = key_size_bits / 8;
- if (key->key_data_size_ > MAX_KEY_SIZE) {
+ if (key->key_data_size_ > SymmetricKey::MAX_KEY_SIZE) {
*error = KM_ERROR_UNSUPPORTED_KEY_SIZE;
return NULL;
}
@@ -71,39 +57,18 @@
return key.release();
}
-/* static */
-SymmetricKey* SymmetricKey::CreateKey(keymaster_algorithm_t algorithm,
- const UnencryptedKeyBlob& blob, const Logger& logger,
- keymaster_error_t* error) {
- switch (algorithm) {
- case KM_ALGORITHM_AES:
- return new AesKey(blob, logger, error);
- case KM_ALGORITHM_HMAC:
- return new HmacKey(blob, logger, error);
- default:
- *error = KM_ERROR_UNSUPPORTED_ALGORITHM;
- return NULL;
- }
-}
-
SymmetricKey::SymmetricKey(const UnencryptedKeyBlob& blob, const Logger& logger,
keymaster_error_t* error)
: Key(blob, logger), key_data_size_(blob.unencrypted_key_material_length()) {
+ memcpy(key_data_, blob.unencrypted_key_material(), key_data_size_);
if (error)
- *error = LoadKey(blob);
+ *error = KM_ERROR_OK;
}
SymmetricKey::~SymmetricKey() {
memset_s(key_data_, 0, MAX_KEY_SIZE);
}
-keymaster_error_t SymmetricKey::LoadKey(const UnencryptedKeyBlob& blob) {
- assert(blob.unencrypted_key_material_length() == key_data_size_);
- memcpy(key_data_, blob.unencrypted_key_material(), key_data_size_);
-
- return KM_ERROR_OK;
-}
-
keymaster_error_t SymmetricKey::key_material(UniquePtr<uint8_t[]>* key_material,
size_t* size) const {
*size = key_data_size_;
diff --git a/symmetric_key.h b/symmetric_key.h
index dc65349..bef921f 100644
--- a/symmetric_key.h
+++ b/symmetric_key.h
@@ -21,19 +21,38 @@
namespace keymaster {
+class SymmetricKey;
+
+class SymmetricKeyFactory : public KeyFactory {
+ virtual Key* GenerateKey(const AuthorizationSet& key_description, const Logger& logger,
+ keymaster_error_t* error);
+ virtual Key* ImportKey(const AuthorizationSet&, keymaster_key_format_t, const uint8_t*, size_t,
+ const Logger&, keymaster_error_t* error) {
+ *error = KM_ERROR_UNIMPLEMENTED;
+ return NULL;
+ }
+
+ virtual const keymaster_key_format_t* SupportedImportFormats(size_t* format_count) {
+ return NoFormats(format_count);
+ }
+ virtual const keymaster_key_format_t* SupportedExportFormats(size_t* format_count) {
+ return NoFormats(format_count);
+ };
+
+ private:
+ virtual SymmetricKey* CreateKey(const AuthorizationSet& auths, const Logger& logger) = 0;
+ const keymaster_key_format_t* NoFormats(size_t* format_count) {
+ *format_count = 0;
+ return NULL;
+ }
+};
+
class SymmetricKey : public Key {
public:
static const int MAX_KEY_SIZE = 32;
static const int MAX_MAC_LENGTH = 32;
static const uint32_t MAX_CHUNK_LENGTH = 64 * 1024;
- static SymmetricKey* GenerateKey(keymaster_algorithm_t algorithm,
- const AuthorizationSet& key_description, const Logger& logger,
- keymaster_error_t* error);
- static SymmetricKey* CreateKey(keymaster_algorithm_t algorithm, const UnencryptedKeyBlob& blob,
- const Logger& logger, keymaster_error_t* error);
-
- SymmetricKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
~SymmetricKey();
virtual keymaster_error_t key_material(UniquePtr<uint8_t[]>* key_material, size_t* size) const;
@@ -45,12 +64,16 @@
protected:
keymaster_error_t error_;
+ SymmetricKey(const UnencryptedKeyBlob& blob, const Logger& logger, keymaster_error_t* error);
+
const uint8_t* key_data() const { return key_data_; }
size_t key_data_size() const { return key_data_size_; }
SymmetricKey(const AuthorizationSet& auths, const Logger& logger) : Key(auths, logger) {}
private:
+ friend SymmetricKeyFactory;
+
keymaster_error_t LoadKey(const UnencryptedKeyBlob& blob);
size_t key_data_size_;