tcti: Get rid of unnecessary 4k buffer in TCTI header.

The device TCTI was being lazy and wasn't implementing some features
that the receive function should have. All of this resulted in a
completely unnecessary 4k buffer being allocated as part of every TCTI
context.

In order to support a feature that allows a caller to discover the size
of the response before allocating the response buffer we had to fix
this. The receive algorithm in the device TCTI now reads the response
header first to get the size of the response. If the caller hasn't
supplied a resposne buffer we just return them the size in the supplied
parameter. If they have supplied a response buffer and it's large enough
to hold the response we grab the rest of the response and put it
directly into the caller supplied buffer.

As part of getting this 4k back we had to do some refactoring as well to
get the header parsing function into a module not specific to the socket
TCTI code. This commit also adds a unit test for the new feature of the
receive function.

Signed-off-by: Philip Tricca <philip.b.tricca@intel.com>
diff --git a/Makefile-test.am b/Makefile-test.am
index dbe5c74..b4cebef 100644
--- a/Makefile-test.am
+++ b/Makefile-test.am
@@ -114,7 +114,7 @@
 
 if UNIT
 test_unit_util_CFLAGS = $(CMOCKA_CFLAGS) $(AM_CFLAGS)
-test_unit_util_LDADD = $(CMOCKA_LIBS) $(libutil)
+test_unit_util_LDADD = $(CMOCKA_LIBS) $(libutil) $(libmarshal)
 test_unit_util_LDFLAGS = -Wl,--wrap=write
 test_unit_util_SOURCES = test/unit/util.c
 
@@ -130,7 +130,7 @@
     tcti/sockets.c tcti/sockets.h test/unit/tcti-socket.c
 
 test_unit_socket_CFLAGS  = $(CMOCKA_CFLAGS) $(AM_CFLAGS)
-test_unit_socket_LDADD   = $(CMOCKA_LIBS) $(libutil)
+test_unit_socket_LDADD   = $(CMOCKA_LIBS) $(libutil) $(libmarshal)
 test_unit_socket_LDFLAGS = -Wl,--wrap=connect,--wrap=socket
 test_unit_socket_SOURCES = test/unit/socket.c tcti/sockets.c tcti/sockets.h
 
diff --git a/tcti/tcti.c b/tcti/tcti.c
index e90a7cb..9dd409b 100644
--- a/tcti/tcti.c
+++ b/tcti/tcti.c
@@ -91,7 +91,7 @@
     if (tcti_intel->state != TCTI_STATE_RECEIVE) {
         return TSS2_TCTI_RC_BAD_SEQUENCE;
     }
-    if (response_buffer == NULL || response_size == NULL) {
+    if (response_buffer == NULL && response_size == NULL) {
         return TSS2_TCTI_RC_BAD_REFERENCE;
     }
 
@@ -135,3 +135,38 @@
 {
     return TSS2_TCTI_RC_NOT_IMPLEMENTED;
 }
+
+TSS2_RC
+parse_header (
+    const uint8_t *buf,
+    tpm_header_t *header)
+{
+    TSS2_RC rc;
+    size_t offset = 0;
+
+    LOG_TRACE ("Parsing header from buffer: 0x%" PRIxPTR, (uintptr_t)buf);
+    rc = Tss2_MU_TPM2_ST_Unmarshal (buf,
+                                    TPM_HEADER_SIZE,
+                                    &offset,
+                                    &header->tag);
+    if (rc != TSS2_RC_SUCCESS) {
+        LOG_ERROR ("Failed to unmarshal tag.");
+        return rc;
+    }
+    rc = Tss2_MU_UINT32_Unmarshal (buf,
+                                   TPM_HEADER_SIZE,
+                                   &offset,
+                                   &header->size);
+    if (rc != TSS2_RC_SUCCESS) {
+        LOG_ERROR ("Failed to unmarshal command size.");
+        return rc;
+    }
+    rc = Tss2_MU_UINT32_Unmarshal (buf,
+                                   TPM_HEADER_SIZE,
+                                   &offset,
+                                   &header->code);
+    if (rc != TSS2_RC_SUCCESS) {
+        LOG_ERROR ("Failed to unmarshal command code.");
+    }
+    return rc;
+}
diff --git a/tcti/tcti.h b/tcti/tcti.h
index 041525b..bc05246 100644
--- a/tcti/tcti.h
+++ b/tcti/tcti.h
@@ -42,6 +42,7 @@
 
 #include <errno.h>
 #include <sapi/tpm20.h>
+#include <stdbool.h>
 
 #if defined(__linux__) || defined(__unix__) || defined(__APPLE__)
 #include <sys/socket.h>
@@ -105,6 +106,7 @@
 typedef struct {
     TSS2_TCTI_CONTEXT_COMMON_V2 v2;
     tcti_state_t state;
+    tpm_header_t header;
 
     struct {
         UINT32 reserved: 1; /* Used to be debugMsgEnabled which is deprecated */
@@ -125,7 +127,6 @@
 
     /* File descriptor for device file if real TPM is being used. */
     int devFile;
-    unsigned char responseBuffer[4096];
 } TSS2_TCTI_CONTEXT_INTEL;
 
 /*
@@ -182,5 +183,14 @@
     int fd,
     const uint8_t *buf,
     size_t size);
+/*
+ * Utility to function to parse the first 10 bytes of a buffer and populate
+ * the 'header' structure with the results. The provided buffer is assumed to
+ * be at least 10 bytes long.
+ */
+TSS2_RC
+parse_header (
+    const uint8_t *buf,
+    tpm_header_t *header);
 
 #endif
diff --git a/tcti/tcti_device.c b/tcti/tcti_device.c
index 1c27055..22d6611 100644
--- a/tcti/tcti_device.c
+++ b/tcti/tcti_device.c
@@ -90,7 +90,6 @@
     TSS2_TCTI_CONTEXT_INTEL *tcti_intel = tcti_context_intel_cast (tctiContext);
     TSS2_RC rc = TSS2_RC_SUCCESS;
     ssize_t  size;
-    unsigned int i;
 
     rc = tcti_receive_checks (tctiContext, response_size, response_buffer);
     if (rc != TSS2_RC_SUCCESS) {
@@ -103,45 +102,57 @@
         return TSS2_TCTI_RC_BAD_VALUE;
     }
 
-    if (tcti_intel->status.tagReceived == 0) {
+    /* Read header first to get size of response. */
+    if (tcti_intel->header.size == 0) {
+        uint8_t header_buf [TPM_HEADER_SIZE];
+        LOG_INFO ("Header not yet received, reading %zd byte header from fd %d",
+                  sizeof (header_buf), tcti_intel->devFile);
         size = TEMP_RETRY (read (tcti_intel->devFile,
-                                 tcti_intel->responseBuffer,
-                                 4096));
+                                 header_buf,
+                                 sizeof (header_buf)));
         if (size < 0) {
-            LOG_ERROR("send failed with error: %d", errno);
+            LOG_WARNING ("Failed to read response header. %d: %s",
+                         errno, strerror (errno));
             rc = TSS2_TCTI_RC_IO_ERROR;
             goto retLocalTpmReceive;
-        } else {
-            tcti_intel->status.tagReceived = 1;
-            tcti_intel->responseSize = size;
         }
-
-        tcti_intel->responseSize = size;
+        LOGBLOB_DEBUG (header_buf, TPM_HEADER_SIZE, "Response header received");
+        rc = parse_header (header_buf, &tcti_intel->header);
+        if (rc != TSS2_RC_SUCCESS) {
+            return rc;
+        }
+        LOG_INFO ("Received response header with size: %" PRIu32,
+                  tcti_intel->header.size);
     }
 
+    *response_size = tcti_intel->header.size;
     if (response_buffer == NULL) {
-        *response_size = tcti_intel->responseSize;
+        LOG_DEBUG ("response_buffer is null, returning size: %zd", *response_size);
         goto retLocalTpmReceive;
     }
-
-    if (*response_size < tcti_intel->responseSize) {
+    if (*response_size < tcti_intel->header.size) {
+        LOG_WARNING ("Size of user supplied response buffer %zd is less than "
+                     "the size of the response buffer: %" PRIu32,
+                     *response_size, tcti_intel->header.size);
         rc = TSS2_TCTI_RC_INSUFFICIENT_BUFFER;
-        *response_size = tcti_intel->responseSize;
         goto retLocalTpmReceive;
     }
-
-    *response_size = tcti_intel->responseSize;
-
-    for (i = 0; i < *response_size; i++) {
-        response_buffer[i] = tcti_intel->responseBuffer[i];
+    /* Read the rest of the response, minus the header that we already jave. */
+    size = TEMP_RETRY (read (tcti_intel->devFile,
+                             response_buffer,
+                             tcti_intel->header.size - TPM_HEADER_SIZE));
+    if (size < 0) {
+        LOG_WARNING ("Failed to read response body. %d: %s",
+                     errno, strerror (errno));
+        rc = TSS2_TCTI_RC_IO_ERROR;
+        goto retLocalTpmReceive;
     }
 
     LOGBLOB_DEBUG(response_buffer, tcti_intel->responseSize, "Response Received");
 
-    tcti_intel->status.commandSent = 0;
-
 retLocalTpmReceive:
     if (rc == TSS2_RC_SUCCESS && response_buffer != NULL ) {
+        tcti_intel->header.size = 0;
         tcti_intel->state = TCTI_STATE_TRANSMIT;
     }
 
@@ -220,6 +231,7 @@
     TSS2_TCTI_SET_LOCALITY (tctiContext) = tcti_device_set_locality;
     TSS2_TCTI_MAKE_STICKY (tctiContext) = tcti_make_sticky_not_implemented;
     tcti_intel->state = TCTI_STATE_TRANSMIT;
+    memset (&tcti_intel->header, 0, sizeof (tcti_intel->header));
 
     tcti_intel->status.locality = 3;
     tcti_intel->status.commandSent = 0;
diff --git a/tcti/tcti_socket.c b/tcti/tcti_socket.c
index 86b0338..b594cbf 100644
--- a/tcti/tcti_socket.c
+++ b/tcti/tcti_socket.c
@@ -59,46 +59,6 @@
 }
 
 /*
- * Utility to function to parse the first 10 bytes of a buffer and populate
- * the 'header' structure with the results. The provided buffer is assumed to
- * be at least 10 bytes long.
- */
-TSS2_RC
-parse_header (
-    const uint8_t *buf,
-    tpm_header_t *header)
-{
-    TSS2_RC rc;
-    size_t offset = 0;
-
-    LOG_TRACE ("Parsing header from buffer: 0x%" PRIxPTR, (uintptr_t)buf);
-    rc = Tss2_MU_TPM2_ST_Unmarshal (buf,
-                                    TPM_HEADER_SIZE,
-                                    &offset,
-                                    &header->tag);
-    if (rc != TSS2_RC_SUCCESS) {
-        LOG_ERROR ("Failed to unmarshal tag.");
-        return rc;
-    }
-    rc = Tss2_MU_UINT32_Unmarshal (buf,
-                                   TPM_HEADER_SIZE,
-                                   &offset,
-                                   &header->size);
-    if (rc != TSS2_RC_SUCCESS) {
-        LOG_ERROR ("Failed to unmarshal command size.");
-        return rc;
-    }
-    rc = Tss2_MU_UINT32_Unmarshal (buf,
-                                   TPM_HEADER_SIZE,
-                                   &offset,
-                                   &header->code);
-    if (rc != TSS2_RC_SUCCESS) {
-        LOG_ERROR ("Failed to unmarshal command code.");
-    }
-    return rc;
-}
-
-/*
  * This fucntion is used to send the simulator a sort of command message
  * that tells it we're about to send it a TPM command. This requires that
  * we first send it a 4 byte code that's defined by the simulator. Then
diff --git a/test/unit/tcti-device.c b/test/unit/tcti-device.c
index 4eb5907..b7f4890 100644
--- a/test/unit/tcti-device.c
+++ b/test/unit/tcti-device.c
@@ -36,9 +36,13 @@
 }
 /* wrap functions for read & write required to test receive / transmit */
 ssize_t
-__wrap_read (int fd, void *buffer, size_t count)
+__wrap_read (int fd, void *buf, size_t count)
 {
-    return mock_type (ssize_t);
+    ssize_t ret = mock_type (ssize_t);
+    uint8_t *buf_in = mock_type (uint8_t*);
+
+    memcpy (buf, buf_in, ret);
+    return ret;
 }
 ssize_t
 __wrap_write (int fd, const void *buffer, size_t buffer_size)
@@ -111,7 +115,7 @@
  * data received.
  */
 static void
-tcti_device_receive_success (void **state)
+tcti_device_receive_one_call_success (void **state)
 {
     data_t *data = *state;
     TSS2_RC rc;
@@ -119,7 +123,10 @@
 
     /* Keep state machine check in `receive` from returning error. */
     tcti_intel->state = TCTI_STATE_RECEIVE;
-    will_return (__wrap_read, data->data_size);
+    will_return (__wrap_read, TPM_HEADER_SIZE);
+    will_return (__wrap_read, data->buffer);
+    will_return (__wrap_read, data->data_size - TPM_HEADER_SIZE);
+    will_return (__wrap_read, &data->buffer [TPM_HEADER_SIZE]);
     rc = Tss2_Tcti_Receive (data->ctx,
                             &data->buffer_size,
                             data->buffer,
@@ -127,6 +134,33 @@
     assert_true (rc == TSS2_RC_SUCCESS);
     assert_int_equal (data->data_size, data->buffer_size);
 }
+static void
+tcti_device_receive_two_call_success (void **state)
+{
+    data_t *data = *state;
+    TSS2_RC rc;
+    TSS2_TCTI_CONTEXT_INTEL *tcti_intel = tcti_context_intel_cast (data->ctx);
+    size_t size = 0;
+
+    /* Keep state machine check in `receive` from returning error. */
+    tcti_intel->state = TCTI_STATE_RECEIVE;
+    will_return (__wrap_read, TPM_HEADER_SIZE);
+    will_return (__wrap_read, data->buffer);
+    will_return (__wrap_read, data->data_size - TPM_HEADER_SIZE);
+    will_return (__wrap_read, &data->buffer [TPM_HEADER_SIZE]);
+    rc = Tss2_Tcti_Receive (data->ctx,
+                            &size,
+                            NULL,
+                            TSS2_TCTI_TIMEOUT_BLOCK);
+    printf ("got size: %zd", size);
+    assert_int_equal (size, data->data_size);
+    assert_int_equal (rc, TSS2_RC_SUCCESS);
+    rc = Tss2_Tcti_Receive (data->ctx,
+                            &data->buffer_size,
+                            data->buffer,
+                            TSS2_TCTI_TIMEOUT_BLOCK);
+    assert_true (rc == TSS2_RC_SUCCESS);
+}
 /*
  * A test case for a successful call to the transmit function. This requires
  * that the context and the cmmand buffer be valid. The only indication of
@@ -151,9 +185,12 @@
     const struct CMUnitTest tests[] = {
         cmocka_unit_test (tcti_device_init_all_null_test),
         cmocka_unit_test(tcti_device_init_size_test),
-        cmocka_unit_test_setup_teardown (tcti_device_receive_success,
+        cmocka_unit_test_setup_teardown (tcti_device_receive_one_call_success,
                                   tcti_device_setup_with_command,
                                   tcti_device_teardown),
+        cmocka_unit_test_setup_teardown (tcti_device_receive_two_call_success,
+                                         tcti_device_setup_with_command,
+                                         tcti_device_teardown),
         cmocka_unit_test_setup_teardown (tcti_device_transmit_success,
                                   tcti_device_setup_with_command,
                                   tcti_device_teardown),