Add authorization enforcement to AndroidKeymaster.

Note: Moving List.h into system/keymaster is unfortunate, but required
to allow Trusty to use it.  b/22088154 tracks cleaning this up.

Bug: 19511945
Change-Id: Ia1dfe5fda5ea78935611b0a7656b323770edcbae
diff --git a/List.h b/List.h
new file mode 100644
index 0000000..403cd7f
--- /dev/null
+++ b/List.h
@@ -0,0 +1,332 @@
+/*
+ * Copyright (C) 2005 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//
+// Templated list class.  Normally we'd use STL, but we don't have that.
+// This class mimics STL's interfaces.
+//
+// Objects are copied into the list with the '=' operator or with copy-
+// construction, so if the compiler's auto-generated versions won't work for
+// you, define your own.
+//
+// The only class you want to use from here is "List".
+//
+#ifndef _LIBS_UTILS_LIST_H
+#define _LIBS_UTILS_LIST_H
+
+#include <stddef.h>
+#include <stdint.h>
+
+namespace android {
+
+/*
+ * Doubly-linked list.  Instantiate with "List<MyClass> myList".
+ *
+ * Objects added to the list are copied using the assignment operator,
+ * so this must be defined.
+ */
+template<typename T> 
+class List 
+{
+protected:
+    /*
+     * One element in the list.
+     */
+    class _Node {
+    public:
+        explicit _Node(const T& val) : mVal(val) {}
+        ~_Node() {}
+        inline T& getRef() { return mVal; }
+        inline const T& getRef() const { return mVal; }
+        inline _Node* getPrev() const { return mpPrev; }
+        inline _Node* getNext() const { return mpNext; }
+        inline void setVal(const T& val) { mVal = val; }
+        inline void setPrev(_Node* ptr) { mpPrev = ptr; }
+        inline void setNext(_Node* ptr) { mpNext = ptr; }
+    private:
+        friend class List;
+        friend class _ListIterator;
+        T           mVal;
+        _Node*      mpPrev;
+        _Node*      mpNext;
+    };
+
+    /*
+     * Iterator for walking through the list.
+     */
+    
+    template <typename TYPE>
+    struct CONST_ITERATOR {
+        typedef _Node const * NodePtr;
+        typedef const TYPE Type;
+    };
+    
+    template <typename TYPE>
+    struct NON_CONST_ITERATOR {
+        typedef _Node* NodePtr;
+        typedef TYPE Type;
+    };
+    
+    template<
+        typename U,
+        template <class> class Constness
+    > 
+    class _ListIterator {
+        typedef _ListIterator<U, Constness>     _Iter;
+        typedef typename Constness<U>::NodePtr  _NodePtr;
+        typedef typename Constness<U>::Type     _Type;
+
+        explicit _ListIterator(_NodePtr ptr) : mpNode(ptr) {}
+
+    public:
+        _ListIterator() {}
+        _ListIterator(const _Iter& rhs) : mpNode(rhs.mpNode) {}
+        ~_ListIterator() {}
+        
+        // this will handle conversions from iterator to const_iterator
+        // (and also all convertible iterators)
+        // Here, in this implementation, the iterators can be converted
+        // if the nodes can be converted
+        template<typename V> explicit 
+        _ListIterator(const V& rhs) : mpNode(rhs.mpNode) {}
+        
+
+        /*
+         * Dereference operator.  Used to get at the juicy insides.
+         */
+        _Type& operator*() const { return mpNode->getRef(); }
+        _Type* operator->() const { return &(mpNode->getRef()); }
+
+        /*
+         * Iterator comparison.
+         */
+        inline bool operator==(const _Iter& right) const { 
+            return mpNode == right.mpNode; }
+        
+        inline bool operator!=(const _Iter& right) const { 
+            return mpNode != right.mpNode; }
+
+        /*
+         * handle comparisons between iterator and const_iterator
+         */
+        template<typename OTHER>
+        inline bool operator==(const OTHER& right) const { 
+            return mpNode == right.mpNode; }
+        
+        template<typename OTHER>
+        inline bool operator!=(const OTHER& right) const { 
+            return mpNode != right.mpNode; }
+
+        /*
+         * Incr/decr, used to move through the list.
+         */
+        inline _Iter& operator++() {     // pre-increment
+            mpNode = mpNode->getNext();
+            return *this;
+        }
+        const _Iter operator++(int) {    // post-increment
+            _Iter tmp(*this);
+            mpNode = mpNode->getNext();
+            return tmp;
+        }
+        inline _Iter& operator--() {     // pre-increment
+            mpNode = mpNode->getPrev();
+            return *this;
+        }
+        const _Iter operator--(int) {   // post-increment
+            _Iter tmp(*this);
+            mpNode = mpNode->getPrev();
+            return tmp;
+        }
+
+        inline _NodePtr getNode() const { return mpNode; }
+
+        _NodePtr mpNode;    /* should be private, but older gcc fails */
+    private:
+        friend class List;
+    };
+
+public:
+    List() {
+        prep();
+    }
+    List(const List<T>& src) {      // copy-constructor
+        prep();
+        insert(begin(), src.begin(), src.end());
+    }
+    virtual ~List() {
+        clear();
+        delete[] (unsigned char*) mpMiddle;
+    }
+
+    typedef _ListIterator<T, NON_CONST_ITERATOR> iterator;
+    typedef _ListIterator<T, CONST_ITERATOR> const_iterator;
+
+    List<T>& operator=(const List<T>& right);
+
+    /* returns true if the list is empty */
+    inline bool empty() const { return mpMiddle->getNext() == mpMiddle; }
+
+    /* return #of elements in list */
+    size_t size() const {
+        return size_t(distance(begin(), end()));
+    }
+
+    /*
+     * Return the first element or one past the last element.  The
+     * _Node* we're returning is converted to an "iterator" by a
+     * constructor in _ListIterator.
+     */
+    inline iterator begin() { 
+        return iterator(mpMiddle->getNext()); 
+    }
+    inline const_iterator begin() const { 
+        return const_iterator(const_cast<_Node const*>(mpMiddle->getNext())); 
+    }
+    inline iterator end() { 
+        return iterator(mpMiddle); 
+    }
+    inline const_iterator end() const { 
+        return const_iterator(const_cast<_Node const*>(mpMiddle)); 
+    }
+
+    /* add the object to the head or tail of the list */
+    void push_front(const T& val) { insert(begin(), val); }
+    void push_back(const T& val) { insert(end(), val); }
+
+    /* insert before the current node; returns iterator at new node */
+    iterator insert(iterator posn, const T& val) 
+    {
+        _Node* newNode = new _Node(val);        // alloc & copy-construct
+        newNode->setNext(posn.getNode());
+        newNode->setPrev(posn.getNode()->getPrev());
+        posn.getNode()->getPrev()->setNext(newNode);
+        posn.getNode()->setPrev(newNode);
+        return iterator(newNode);
+    }
+
+    /* insert a range of elements before the current node */
+    void insert(iterator posn, const_iterator first, const_iterator last) {
+        for ( ; first != last; ++first)
+            insert(posn, *first);
+    }
+
+    /* remove one entry; returns iterator at next node */
+    iterator erase(iterator posn) {
+        _Node* pNext = posn.getNode()->getNext();
+        _Node* pPrev = posn.getNode()->getPrev();
+        pPrev->setNext(pNext);
+        pNext->setPrev(pPrev);
+        delete posn.getNode();
+        return iterator(pNext);
+    }
+
+    /* remove a range of elements */
+    iterator erase(iterator first, iterator last) {
+        while (first != last)
+            erase(first++);     // don't erase than incr later!
+        return iterator(last);
+    }
+
+    /* remove all contents of the list */
+    void clear() {
+        _Node* pCurrent = mpMiddle->getNext();
+        _Node* pNext;
+
+        while (pCurrent != mpMiddle) {
+            pNext = pCurrent->getNext();
+            delete pCurrent;
+            pCurrent = pNext;
+        }
+        mpMiddle->setPrev(mpMiddle);
+        mpMiddle->setNext(mpMiddle);
+    }
+
+    /*
+     * Measure the distance between two iterators.  On exist, "first"
+     * will be equal to "last".  The iterators must refer to the same
+     * list.
+     *
+     * FIXME: This is actually a generic iterator function. It should be a 
+     * template function at the top-level with specializations for things like
+     * vector<>, which can just do pointer math). Here we limit it to
+     * _ListIterator of the same type but different constness.
+     */
+    template<
+        typename U,
+        template <class> class CL,
+        template <class> class CR
+    > 
+    ptrdiff_t distance(
+            _ListIterator<U, CL> first, _ListIterator<U, CR> last) const 
+    {
+        ptrdiff_t count = 0;
+        while (first != last) {
+            ++first;
+            ++count;
+        }
+        return count;
+    }
+
+private:
+    /*
+     * I want a _Node but don't need it to hold valid data.  More
+     * to the point, I don't want T's constructor to fire, since it
+     * might have side-effects or require arguments.  So, we do this
+     * slightly uncouth storage alloc.
+     */
+    void prep() {
+        mpMiddle = (_Node*) new unsigned char[sizeof(_Node)];
+        mpMiddle->setPrev(mpMiddle);
+        mpMiddle->setNext(mpMiddle);
+    }
+
+    /*
+     * This node plays the role of "pointer to head" and "pointer to tail".
+     * It sits in the middle of a circular list of nodes.  The iterator
+     * runs around the circle until it encounters this one.
+     */
+    _Node*      mpMiddle;
+};
+
+/*
+ * Assignment operator.
+ *
+ * The simplest way to do this would be to clear out the target list and
+ * fill it with the source.  However, we can speed things along by
+ * re-using existing elements.
+ */
+template<class T>
+List<T>& List<T>::operator=(const List<T>& right)
+{
+    if (this == &right)
+        return *this;       // self-assignment
+    iterator firstDst = begin();
+    iterator lastDst = end();
+    const_iterator firstSrc = right.begin();
+    const_iterator lastSrc = right.end();
+    while (firstSrc != lastSrc && firstDst != lastDst)
+        *firstDst++ = *firstSrc++;
+    if (firstSrc == lastSrc)        // ran out of elements in source?
+        erase(firstDst, lastDst);   // yes, erase any extras
+    else
+        insert(lastDst, firstSrc, lastSrc);     // copy remaining over
+    return *this;
+}
+
+}; // namespace android
+
+#endif // _LIBS_UTILS_LIST_H
diff --git a/Makefile b/Makefile
index 88131e8..83acdae 100644
--- a/Makefile
+++ b/Makefile
@@ -237,6 +237,7 @@
 	integrity_assured_key_blob.o \
 	key.o \
 	keymaster0_engine.o \
+	keymaster_enforcement.o \
 	logger.o \
 	ocb.o \
 	ocb_utils.o \
diff --git a/android_keymaster.cpp b/android_keymaster.cpp
index 7ebc8de..2bc6ea7 100644
--- a/android_keymaster.cpp
+++ b/android_keymaster.cpp
@@ -234,21 +234,30 @@
     if (!factory)
         return;
 
-    response->error = KM_ERROR_INCOMPATIBLE_PURPOSE;
-    if (!key->authorizations().Contains(TAG_PURPOSE, request.purpose) &&
-        !factory->is_public_key_operation())
-        return;
-
     UniquePtr<Operation> operation(
         factory->CreateOperation(*key, request.additional_params, &response->error));
     if (operation.get() == NULL)
         return;
 
+    if (context_->enforcement_policy()) {
+        km_id_t key_id;
+        response->error = KM_ERROR_UNKNOWN_ERROR;
+        if (!context_->enforcement_policy()->CreateKeyId(request.key_blob, &key_id))
+            return;
+        operation->set_key_id(key_id);
+        response->error = context_->enforcement_policy()->AuthorizeOperation(
+            request.purpose, key_id, key->authorizations(), request.additional_params,
+            0 /* op_handle */, true /* is_begin_operation */);
+        if (response->error != KM_ERROR_OK)
+            return;
+    }
+
     response->output_params.Clear();
     response->error = operation->Begin(request.additional_params, &response->output_params);
     if (response->error != KM_ERROR_OK)
         return;
 
+    operation->SetAuthorizations(key->authorizations());
     response->error = operation_table_->Add(operation.release(), &response->op_handle);
 }
 
@@ -262,6 +271,14 @@
     if (operation == NULL)
         return;
 
+    if (context_->enforcement_policy()) {
+        response->error = context_->enforcement_policy()->AuthorizeOperation(
+            operation->purpose(), operation->key_id(), operation->authorizations(),
+            request.additional_params, request.op_handle, false /* is_begin_operation */);
+        if (response->error != KM_ERROR_OK)
+            return;
+    }
+
     response->error =
         operation->Update(request.additional_params, request.input, &response->output_params,
                           &response->output, &response->input_consumed);
@@ -281,6 +298,14 @@
     if (operation == NULL)
         return;
 
+    if (context_->enforcement_policy()) {
+        response->error = context_->enforcement_policy()->AuthorizeOperation(
+            operation->purpose(), operation->key_id(), operation->authorizations(),
+            request.additional_params, request.op_handle, false /* is_begin_operation */);
+        if (response->error != KM_ERROR_OK)
+            return;
+    }
+
     response->error = operation->Finish(request.additional_params, request.signature,
                                         &response->output_params, &response->output);
     operation_table_->Delete(request.op_handle);
diff --git a/android_keymaster_test.cpp b/android_keymaster_test.cpp
index f7b37d3..d2ef08d 100644
--- a/android_keymaster_test.cpp
+++ b/android_keymaster_test.cpp
@@ -19,10 +19,13 @@
 #include <vector>
 
 #include <hardware/keymaster0.h>
+#include <keymaster/key_factory.h>
+#include <keymaster/soft_keymaster_context.h>
 #include <keymaster/soft_keymaster_device.h>
 #include <keymaster/softkeymaster.h>
 
 #include "android_keymaster_test_utils.h"
+#include "keymaster0_engine.h"
 
 using std::ifstream;
 using std::istreambuf_iterator;
@@ -43,11 +46,36 @@
 
 StdoutLogger logger;
 
+class TestKeymasterEnforcement : public KeymasterEnforcement {
+  public:
+    TestKeymasterEnforcement() : KeymasterEnforcement(3, 3) {}
+
+    virtual bool activation_date_valid(uint64_t /* activation_date */) const { return true; }
+    virtual bool expiration_date_passed(uint64_t /* expiration_date */) const { return false; }
+    virtual bool auth_token_timed_out(const hw_auth_token_t& /* token */,
+                                      uint32_t /* timeout */) const {
+        return false;
+    }
+    virtual uint32_t get_current_time() const { return 0; }
+    virtual bool ValidateTokenSignature(const hw_auth_token_t& /* token */) const { return true; }
+};
+
+class TestKeymasterContext : public SoftKeymasterContext {
+  public:
+    TestKeymasterContext(keymaster0_device_t* keymaster0 = nullptr)
+        : SoftKeymasterContext(keymaster0) {}
+
+    KeymasterEnforcement* enforcement_policy() override { return &test_policy_; }
+
+  private:
+    TestKeymasterEnforcement test_policy_;
+};
+
 class SoftKeymasterTestInstanceCreator : public Keymaster1TestInstanceCreator {
   public:
     keymaster1_device_t* CreateDevice() const override {
         std::cerr << "Creating software-only device" << std::endl;
-        SoftKeymasterDevice* device = new SoftKeymasterDevice;
+        SoftKeymasterDevice* device = new SoftKeymasterDevice(new TestKeymasterContext);
         return device->keymaster_device();
     }
 
@@ -77,7 +105,8 @@
 
         counting_keymaster0_device_ = new Keymaster0CountingWrapper(keymaster0_device);
 
-        SoftKeymasterDevice* keymaster = new SoftKeymasterDevice(counting_keymaster0_device_);
+        SoftKeymasterDevice* keymaster =
+            new SoftKeymasterDevice(new TestKeymasterContext(counting_keymaster0_device_));
         return keymaster->keymaster_device();
     }
 
@@ -641,7 +670,10 @@
                                            .RsaEncryptionKey(256, 3)
                                            .Digest(KM_DIGEST_NONE)
                                            .Padding(KM_PAD_NONE)));
-    ASSERT_EQ(KM_ERROR_INCOMPATIBLE_PURPOSE, BeginOperation(KM_PURPOSE_SIGN));
+    AuthorizationSet begin_params(client_params());
+    begin_params.push_back(TAG_PADDING, KM_PAD_NONE);
+    begin_params.push_back(TAG_DIGEST, KM_DIGEST_NONE);
+    ASSERT_EQ(KM_ERROR_INCOMPATIBLE_PURPOSE, BeginOperation(KM_PURPOSE_SIGN, begin_params));
 
     if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
         EXPECT_EQ(2, GetParam()->keymaster0_calls());
@@ -1939,7 +1971,10 @@
 TEST_P(EncryptionOperationsTest, RsaEncryptWithSigningKey) {
     ASSERT_EQ(KM_ERROR_OK,
               GenerateKey(AuthorizationSetBuilder().RsaSigningKey(256, 3).Padding(KM_PAD_NONE)));
-    ASSERT_EQ(KM_ERROR_INCOMPATIBLE_PURPOSE, BeginOperation(KM_PURPOSE_DECRYPT));
+
+    AuthorizationSet begin_params(client_params());
+    begin_params.push_back(TAG_PADDING, KM_PAD_NONE);
+    ASSERT_EQ(KM_ERROR_INCOMPATIBLE_PURPOSE, BeginOperation(KM_PURPOSE_DECRYPT, begin_params));
 
     if (GetParam()->algorithm_in_hardware(KM_ALGORITHM_RSA))
         EXPECT_EQ(2, GetParam()->keymaster0_calls());
@@ -2746,6 +2781,51 @@
     EXPECT_EQ(0, GetParam()->keymaster0_calls());
 }
 
+typedef Keymaster1Test MaxOperationsTest;
+INSTANTIATE_TEST_CASE_P(AndroidKeymasterTest, MaxOperationsTest, test_params);
+
+TEST_P(MaxOperationsTest, TestLimit) {
+    ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
+                                           .AesEncryptionKey(128)
+                                           .EcbMode()
+                                           .Authorization(TAG_PADDING, KM_PAD_NONE)
+                                           .Authorization(TAG_MAX_USES_PER_BOOT, 3)));
+
+    string message = "1234567890123456";
+    string ciphertext1 = EncryptMessage(message, KM_MODE_ECB, KM_PAD_NONE);
+    string ciphertext2 = EncryptMessage(message, KM_MODE_ECB, KM_PAD_NONE);
+    string ciphertext3 = EncryptMessage(message, KM_MODE_ECB, KM_PAD_NONE);
+
+    // Fourth time should fail.
+    AuthorizationSet begin_params(client_params());
+    begin_params.push_back(TAG_BLOCK_MODE, KM_MODE_ECB);
+    begin_params.push_back(TAG_PADDING, KM_PAD_NONE);
+    EXPECT_EQ(KM_ERROR_KEY_MAX_OPS_EXCEEDED, BeginOperation(KM_PURPOSE_ENCRYPT, begin_params));
+
+    EXPECT_EQ(0, GetParam()->keymaster0_calls());
+}
+
+TEST_P(MaxOperationsTest, TestAbort) {
+    ASSERT_EQ(KM_ERROR_OK, GenerateKey(AuthorizationSetBuilder()
+                                           .AesEncryptionKey(128)
+                                           .EcbMode()
+                                           .Authorization(TAG_PADDING, KM_PAD_NONE)
+                                           .Authorization(TAG_MAX_USES_PER_BOOT, 3)));
+
+    string message = "1234567890123456";
+    string ciphertext1 = EncryptMessage(message, KM_MODE_ECB, KM_PAD_NONE);
+    string ciphertext2 = EncryptMessage(message, KM_MODE_ECB, KM_PAD_NONE);
+    string ciphertext3 = EncryptMessage(message, KM_MODE_ECB, KM_PAD_NONE);
+
+    // Fourth time should fail.
+    AuthorizationSet begin_params(client_params());
+    begin_params.push_back(TAG_BLOCK_MODE, KM_MODE_ECB);
+    begin_params.push_back(TAG_PADDING, KM_PAD_NONE);
+    EXPECT_EQ(KM_ERROR_KEY_MAX_OPS_EXCEEDED, BeginOperation(KM_PURPOSE_ENCRYPT, begin_params));
+
+    EXPECT_EQ(0, GetParam()->keymaster0_calls());
+}
+
 typedef Keymaster1Test AddEntropyTest;
 INSTANTIATE_TEST_CASE_P(AndroidKeymasterTest, AddEntropyTest, test_params);
 
diff --git a/include/keymaster/keymaster_context.h b/include/keymaster/keymaster_context.h
index be4f8d7..338b408 100644
--- a/include/keymaster/keymaster_context.h
+++ b/include/keymaster/keymaster_context.h
@@ -20,6 +20,7 @@
 #include <assert.h>
 
 #include <hardware/keymaster_defs.h>
+#include <keymaster/keymaster_enforcement.h>
 
 namespace keymaster {
 
@@ -121,6 +122,11 @@
      */
     virtual keymaster_error_t GenerateRandom(uint8_t* buf, size_t length) const = 0;
 
+    /**
+     * Return the enforcement policy for this context, or null if no enforcement should be done.
+     */
+    virtual KeymasterEnforcement* enforcement_policy() = 0;
+
   private:
     // Uncopyable.
     KeymasterContext(const KeymasterContext&);
diff --git a/include/keymaster/keymaster_enforcement.h b/include/keymaster/keymaster_enforcement.h
index a0fccdf..69ef5e3 100644
--- a/include/keymaster/keymaster_enforcement.h
+++ b/include/keymaster/keymaster_enforcement.h
@@ -19,8 +19,6 @@
 
 #include <stdio.h>
 
-#include <utils/List.h>
-
 #include <keymaster/authorization_set.h>
 
 namespace keymaster {
@@ -35,17 +33,16 @@
      */
 };
 
-class KeymasterEnforcement {
+class AccessTimeMap;
+class AccessCountMap;
 
+class KeymasterEnforcement {
   public:
     /**
-     * Construct a KeymasterEnforcement.  Takes ownership of the context.
+     * Construct a KeymasterEnforcement.
      */
-    explicit KeymasterEnforcement(uint32_t max_access_time_map_size,
-                                  uint32_t max_access_count_map_size)
-        : access_time_map_(max_access_time_map_size), access_count_map_(max_access_count_map_size) {
-    }
-    virtual ~KeymasterEnforcement() {}
+    KeymasterEnforcement(uint32_t max_access_time_map_size, uint32_t max_access_count_map_size);
+    virtual ~KeymasterEnforcement();
 
     /**
      * Iterates through the authorization set and returns the corresponding keymaster error. Will
@@ -159,52 +156,10 @@
                           const keymaster_operation_handle_t op_handle,
                           bool is_begin_operation) const;
 
-    class AccessTimeMap {
-      public:
-        AccessTimeMap(uint32_t max_size) : max_size_(max_size) {}
-
-        /* If the key is found, returns true and fills \p last_access_time.  If not found returns
-         * false. */
-        bool LastKeyAccessTime(km_id_t keyid, uint32_t* last_access_time) const;
-
-        /* Updates the last key access time with the currentTime parameter.  Adds the key if
-         * needed, returning false if key cannot be added because list is full. */
-        bool UpdateKeyAccessTime(km_id_t keyid, uint32_t current_time, uint32_t timeout);
-
-      private:
-        struct AccessTime {
-            km_id_t keyid;
-            uint32_t access_time;
-            uint32_t timeout;
-        };
-        android::List<AccessTime> last_access_list_;
-        const uint32_t max_size_;
-    };
-
-    class AccessCountMap {
-      public:
-        AccessCountMap(uint32_t max_size) : max_size_(max_size) {}
-
-        /* If the key is found, returns true and fills \p count.  If not found returns
-         * false. */
-        bool KeyAccessCount(km_id_t keyid, uint32_t* count) const;
-
-        /* Increments key access count, adding an entry if the key has never been used.  Returns
-         * false if the list has reached maximum size. */
-        bool IncrementKeyAccessCount(km_id_t keyid);
-
-      private:
-        struct AccessCount {
-            km_id_t keyid;
-            uint64_t access_count;
-        };
-        android::List<AccessCount> access_count_list_;
-        const uint32_t max_size_;
-    };
-
-    AccessTimeMap access_time_map_;
-    AccessCountMap access_count_map_;
+    AccessTimeMap* access_time_map_;
+    AccessCountMap* access_count_map_;
 };
+
 }; /* namespace keymaster */
 
 #endif  // ANDROID_LIBRARY_KEYMASTER_ENFORCEMENT_H
diff --git a/include/keymaster/soft_keymaster_context.h b/include/keymaster/soft_keymaster_context.h
index b0a4c1e..8f6fe2d 100644
--- a/include/keymaster/soft_keymaster_context.h
+++ b/include/keymaster/soft_keymaster_context.h
@@ -52,6 +52,11 @@
     keymaster_error_t AddRngEntropy(const uint8_t* buf, size_t length) const override;
     keymaster_error_t GenerateRandom(uint8_t* buf, size_t length) const override;
 
+    KeymasterEnforcement* enforcement_policy() override {
+        // SoftKeymaster does no enforcement; it's all done by Keystore.
+        return nullptr;
+    }
+
   private:
     keymaster_error_t ParseOldSoftkeymasterBlob(const KeymasterKeyBlob& blob,
                                                 KeymasterKeyBlob* key_material,
diff --git a/include/keymaster/soft_keymaster_device.h b/include/keymaster/soft_keymaster_device.h
index a4f85c4..75e0066 100644
--- a/include/keymaster/soft_keymaster_device.h
+++ b/include/keymaster/soft_keymaster_device.h
@@ -45,8 +45,20 @@
  */
 class SoftKeymasterDevice {
   public:
+    /**
+     * Create a SoftKeymasterDevice wrapping the specified HW keymaster0 device, which may be NULL.
+     *
+     * Uses SoftKeymaserContext.
+     */
     SoftKeymasterDevice(keymaster0_device_t* keymaster0_device = nullptr);
 
+    /**
+     * Create a SoftKeymasterDevice that uses the specified KeymasterContext.
+     *
+     * TODO(swillden): Refactor SoftKeymasterDevice construction to make all components injectable.
+     */
+    SoftKeymasterDevice(KeymasterContext* context);
+
     hw_device_t* hw_device();
     keymaster1_device_t* keymaster_device();
 
@@ -56,6 +68,8 @@
     }
 
   private:
+    void initialize(keymaster0_device_t* keymaster0_device);
+
     static void StoreDefaultNewKeyParams(AuthorizationSet* auth_set);
     static keymaster_error_t GetPkcs8KeyAlgorithm(const uint8_t* key, size_t key_length,
                                                   keymaster_algorithm_t* algorithm);
diff --git a/keymaster_enforcement.cpp b/keymaster_enforcement.cpp
index 8273cf5..2dc0d01 100644
--- a/keymaster_enforcement.cpp
+++ b/keymaster_enforcement.cpp
@@ -17,20 +17,64 @@
 #include <keymaster/keymaster_enforcement.h>
 
 #include <assert.h>
+#include <limits.h>
 #include <string.h>
 
-#include <limits>
-
 #include <openssl/evp.h>
 
 #include <hardware/hw_auth_token.h>
 #include <keymaster/android_keymaster_utils.h>
 #include <keymaster/logger.h>
 
+#include "List.h"
+
 using android::List;
 
 namespace keymaster {
 
+class AccessTimeMap {
+  public:
+    AccessTimeMap(uint32_t max_size) : max_size_(max_size) {}
+
+    /* If the key is found, returns true and fills \p last_access_time.  If not found returns
+     * false. */
+    bool LastKeyAccessTime(km_id_t keyid, uint32_t* last_access_time) const;
+
+    /* Updates the last key access time with the currentTime parameter.  Adds the key if
+     * needed, returning false if key cannot be added because list is full. */
+    bool UpdateKeyAccessTime(km_id_t keyid, uint32_t current_time, uint32_t timeout);
+
+  private:
+    struct AccessTime {
+        km_id_t keyid;
+        uint32_t access_time;
+        uint32_t timeout;
+    };
+    android::List<AccessTime> last_access_list_;
+    const uint32_t max_size_;
+};
+
+class AccessCountMap {
+  public:
+    AccessCountMap(uint32_t max_size) : max_size_(max_size) {}
+
+    /* If the key is found, returns true and fills \p count.  If not found returns
+     * false. */
+    bool KeyAccessCount(km_id_t keyid, uint32_t* count) const;
+
+    /* Increments key access count, adding an entry if the key has never been used.  Returns
+     * false if the list has reached maximum size. */
+    bool IncrementKeyAccessCount(km_id_t keyid);
+
+  private:
+    struct AccessCount {
+        km_id_t keyid;
+        uint64_t access_count;
+    };
+    android::List<AccessCount> access_count_list_;
+    const uint32_t max_size_;
+};
+
 bool is_public_key_algorithm(const AuthorizationSet& auth_set) {
     keymaster_algorithm_t algorithm;
     return auth_set.GetTagValue(TAG_ALGORITHM, &algorithm) &&
@@ -65,6 +109,16 @@
     return purpose == KM_PURPOSE_DECRYPT || purpose == KM_PURPOSE_VERIFY;
 }
 
+KeymasterEnforcement::KeymasterEnforcement(uint32_t max_access_time_map_size,
+                                           uint32_t max_access_count_map_size)
+    : access_time_map_(new (std::nothrow) AccessTimeMap(max_access_time_map_size)),
+      access_count_map_(new (std::nothrow) AccessCountMap(max_access_count_map_size)) {}
+
+KeymasterEnforcement::~KeymasterEnforcement() {
+    delete access_time_map_;
+    delete access_count_map_;
+}
+
 keymaster_error_t KeymasterEnforcement::AuthorizeOperation(const keymaster_purpose_t purpose,
                                                            const km_id_t keyid,
                                                            const AuthorizationSet& auth_set,
@@ -273,15 +327,28 @@
         operation_params.find(KM_TAG_NONCE) != -1)
         return KM_ERROR_CALLER_NONCE_PROHIBITED;
 
-    if (min_ops_timeout != UINT32_MAX &&
-        !access_time_map_.UpdateKeyAccessTime(keyid, get_current_time(), min_ops_timeout)) {
-        LOG_E("Rate-limited keys table full.  Entries will time out.", 0);
-        return KM_ERROR_TOO_MANY_OPERATIONS;
+    if (min_ops_timeout != UINT32_MAX) {
+        if (!access_time_map_) {
+            LOG_S("Rate-limited keys table not allocated.  Rate-limited keys disabled", 0);
+            return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+        }
+
+        if (!access_time_map_->UpdateKeyAccessTime(keyid, get_current_time(), min_ops_timeout)) {
+            LOG_E("Rate-limited keys table full.  Entries will time out.", 0);
+            return KM_ERROR_TOO_MANY_OPERATIONS;
+        }
     }
 
-    if (update_access_count && !access_count_map_.IncrementKeyAccessCount(keyid)) {
-        LOG_E("Usage count-limited keys table full, until reboot.", 0);
-        return KM_ERROR_TOO_MANY_OPERATIONS;
+    if (update_access_count) {
+        if (!access_count_map_) {
+            LOG_S("Usage-count limited keys tabel not allocated.  Count-limited keys disabled", 0);
+            return KM_ERROR_MEMORY_ALLOCATION_FAILED;
+        }
+
+        if (!access_count_map_->IncrementKeyAccessCount(keyid)) {
+            LOG_E("Usage count-limited keys table full, until reboot.", 0);
+            return KM_ERROR_TOO_MANY_OPERATIONS;
+        }
     }
 
     return KM_ERROR_OK;
@@ -316,15 +383,21 @@
 }
 
 bool KeymasterEnforcement::MinTimeBetweenOpsPassed(uint32_t min_time_between, const km_id_t keyid) {
+    if (!access_time_map_)
+        return false;
+
     uint32_t last_access_time;
-    if (!access_time_map_.LastKeyAccessTime(keyid, &last_access_time))
+    if (!access_time_map_->LastKeyAccessTime(keyid, &last_access_time))
         return true;
     return min_time_between <= static_cast<int64_t>(get_current_time()) - last_access_time;
 }
 
 bool KeymasterEnforcement::MaxUsesPerBootNotExceeded(const km_id_t keyid, uint32_t max_uses) {
+    if (!access_count_map_)
+        return false;
+
     uint32_t key_access_count;
-    if (!access_count_map_.KeyAccessCount(keyid, &key_access_count))
+    if (!access_count_map_->KeyAccessCount(keyid, &key_access_count))
         return true;
     return key_access_count < max_uses;
 }
@@ -406,8 +479,7 @@
     return true;
 }
 
-bool KeymasterEnforcement::AccessTimeMap::LastKeyAccessTime(km_id_t keyid,
-                                                            uint32_t* last_access_time) const {
+bool AccessTimeMap::LastKeyAccessTime(km_id_t keyid, uint32_t* last_access_time) const {
     for (auto& entry : last_access_list_)
         if (entry.keyid == keyid) {
             *last_access_time = entry.access_time;
@@ -416,8 +488,7 @@
     return false;
 }
 
-bool KeymasterEnforcement::AccessTimeMap::UpdateKeyAccessTime(km_id_t keyid, uint32_t current_time,
-                                                              uint32_t timeout) {
+bool AccessTimeMap::UpdateKeyAccessTime(km_id_t keyid, uint32_t current_time, uint32_t timeout) {
     List<AccessTime>::iterator iter;
     for (iter = last_access_list_.begin(); iter != last_access_list_.end();) {
         if (iter->keyid == keyid) {
@@ -444,7 +515,7 @@
     return true;
 }
 
-bool KeymasterEnforcement::AccessCountMap::KeyAccessCount(km_id_t keyid, uint32_t* count) const {
+bool AccessCountMap::KeyAccessCount(km_id_t keyid, uint32_t* count) const {
     for (auto& entry : access_count_list_)
         if (entry.keyid == keyid) {
             *count = entry.access_count;
@@ -453,14 +524,15 @@
     return false;
 }
 
-template <typename T> T max_value(T) {
-    return std::numeric_limits<T>::max();
-}
-
-bool KeymasterEnforcement::AccessCountMap::IncrementKeyAccessCount(km_id_t keyid) {
+bool AccessCountMap::IncrementKeyAccessCount(km_id_t keyid) {
     for (auto& entry : access_count_list_)
         if (entry.keyid == keyid) {
-            if (entry.access_count < max_value(entry.access_count))
+            // Note that the 'if' below will always be true because KM_TAG_MAX_USES_PER_BOOT is a
+            // uint32_t, and as soon as entry.access_count reaches the specified maximum value
+            // operation requests will be rejected and access_count won't be incremented any more.
+            // And, besides, UINT64_MAX is huge.  But we ensure that it doesn't wrap anyway, out of
+            // an abundance of caution.
+            if (entry.access_count < UINT64_MAX)
                 ++entry.access_count;
             return true;
         }
diff --git a/operation.h b/operation.h
index a950b42..74948fa 100644
--- a/operation.h
+++ b/operation.h
@@ -21,8 +21,9 @@
 #include <stdint.h>
 #include <stdlib.h>
 
-#include <keymaster/android_keymaster_utils.h>
 #include <hardware/keymaster_defs.h>
+#include <keymaster/android_keymaster_utils.h>
+#include <keymaster/authorization_set.h>
 #include <keymaster/logger.h>
 
 namespace keymaster {
@@ -91,6 +92,14 @@
 
     keymaster_purpose_t purpose() const { return purpose_; }
 
+    void set_key_id(uint64_t key_id) { key_id_ = key_id; }
+    uint64_t key_id() const { return key_id_; }
+
+    void SetAuthorizations(const AuthorizationSet& auths) {
+        key_auths_.Reinitialize(auths.data(), auths.size());
+    }
+    const AuthorizationSet authorizations() { return key_auths_; }
+
     virtual keymaster_error_t Begin(const AuthorizationSet& input_params,
                                     AuthorizationSet* output_params) = 0;
     virtual keymaster_error_t Update(const AuthorizationSet& input_params, const Buffer& input,
@@ -102,6 +111,8 @@
 
   private:
     const keymaster_purpose_t purpose_;
+    AuthorizationSet key_auths_;
+    uint64_t key_id_;
 };
 
 }  // namespace keymaster
diff --git a/soft_keymaster_device.cpp b/soft_keymaster_device.cpp
index 852abb2..dba91ee 100644
--- a/soft_keymaster_device.cpp
+++ b/soft_keymaster_device.cpp
@@ -60,6 +60,15 @@
 
 SoftKeymasterDevice::SoftKeymasterDevice(keymaster0_device_t* keymaster0_device)
     : impl_(new AndroidKeymaster(new SoftKeymasterContext(keymaster0_device), 16)) {
+    initialize(keymaster0_device);
+}
+
+SoftKeymasterDevice::SoftKeymasterDevice(KeymasterContext* context)
+    : impl_(new AndroidKeymaster(context, 16)) {
+    initialize(nullptr);
+}
+
+void SoftKeymasterDevice::initialize(keymaster0_device_t* keymaster0_device) {
     static_assert(std::is_standard_layout<SoftKeymasterDevice>::value,
                   "SoftKeymasterDevice must be standard layout");
     static_assert(offsetof(SoftKeymasterDevice, device_) == 0,