Merge pull request #6029 from yang-g/hpack_table

Use base64 encoded length for decoder table size calculation of binary headers
diff --git a/src/core/ext/transport/chttp2/transport/hpack_encoder.c b/src/core/ext/transport/chttp2/transport/hpack_encoder.c
index 807cb5c..555027c 100644
--- a/src/core/ext/transport/chttp2/transport/hpack_encoder.c
+++ b/src/core/ext/transport/chttp2/transport/hpack_encoder.c
@@ -49,6 +49,7 @@
 #include "src/core/ext/transport/chttp2/transport/hpack_table.h"
 #include "src/core/ext/transport/chttp2/transport/timeout_encoding.h"
 #include "src/core/ext/transport/chttp2/transport/varint.h"
+#include "src/core/lib/transport/metadata.h"
 #include "src/core/lib/transport/static_metadata.h"
 
 #define HASH_FRAGMENT_1(x) ((x)&255)
@@ -182,8 +183,7 @@
   uint32_t key_hash = elem->key->hash;
   uint32_t elem_hash = GRPC_MDSTR_KV_HASH(key_hash, elem->value->hash);
   uint32_t new_index = c->tail_remote_index + c->table_elems + 1;
-  size_t elem_size = 32 + GPR_SLICE_LENGTH(elem->key->slice) +
-                     GPR_SLICE_LENGTH(elem->value->slice);
+  size_t elem_size = grpc_mdelem_get_size_in_hpack_table(elem);
 
   GPR_ASSERT(elem_size < 65536);
 
@@ -399,8 +399,7 @@
   }
 
   /* should this elem be in the table? */
-  decoder_space_usage = 32 + GPR_SLICE_LENGTH(elem->key->slice) +
-                        GPR_SLICE_LENGTH(elem->value->slice);
+  decoder_space_usage = grpc_mdelem_get_size_in_hpack_table(elem);
   should_add_elem = decoder_space_usage < MAX_DECODER_SPACE_USAGE &&
                     c->filter_elems[HASH_FRAGMENT_1(elem_hash)] >=
                         c->filter_elems_sum / ONE_ON_ADD_PROBABILITY;
diff --git a/src/core/lib/transport/metadata.c b/src/core/lib/transport/metadata.c
index d037c61..5d9e7e7 100644
--- a/src/core/lib/transport/metadata.c
+++ b/src/core/lib/transport/metadata.c
@@ -38,6 +38,7 @@
 #include <string.h>
 
 #include <grpc/compression.h>
+#include <grpc/grpc.h>
 #include <grpc/support/alloc.h>
 #include <grpc/support/atm.h>
 #include <grpc/support/log.h>
@@ -79,6 +80,7 @@
 
 typedef void (*destroy_user_data_func)(void *user_data);
 
+#define SIZE_IN_DECODER_TABLE_NOT_SET -1
 /* Shadow structure for grpc_mdstr for non-static values */
 typedef struct internal_string {
   /* must be byte compatible with grpc_mdstr */
@@ -93,6 +95,8 @@
 
   gpr_slice base64_and_huffman;
 
+  gpr_atm size_in_decoder_table;
+
   struct internal_string *bucket_next;
 } internal_string;
 
@@ -413,6 +417,7 @@
   }
   s->has_base64_and_huffman_encoded = 0;
   s->hash = hash;
+  s->size_in_decoder_table = SIZE_IN_DECODER_TABLE_NOT_SET;
   s->bucket_next = shard->strs[idx];
   shard->strs[idx] = s;
 
@@ -582,6 +587,39 @@
       grpc_mdstr_from_string(key), grpc_mdstr_from_buffer(value, value_length));
 }
 
+/* Length of RAW_LENGTH bytes once encoded as base64 without padding. */
+static size_t get_base64_encoded_size(size_t raw_length) {
+  return 4 * (raw_length / 3) + (raw_length % 3 ? raw_length % 3 + 1 : 0);
+}
+
+/* Returns the size ELEM is accounted for in an HPACK table: a fixed 32-byte
+   per-entry overhead, plus the key length, plus the value length.  When
+   grpc_is_binary_header() accepts the key, the value is counted at its
+   base64-encoded length; otherwise it is counted at its raw length.
+
+   The result is computed on every call and never cached on the value
+   string: which length applies is decided by the key, so a size stored on
+   a value that is also referenced by an element with a different kind of
+   key would be wrong for one of the two elements.
+   NOTE(review): assumes value strings can be shared between elements --
+   confirm against the string interning code. */
+size_t grpc_mdelem_get_size_in_hpack_table(grpc_mdelem *elem) {
+  const size_t key_len = GPR_SLICE_LENGTH(elem->key->slice);
+  const size_t raw_value_len = GPR_SLICE_LENGTH(elem->value->slice);
+  const size_t overhead_and_key = 32 + key_len;
+  size_t value_len_in_table;
+
+  /* Static and non-static values are sized the same way, so no
+     is_mdstr_static() distinction is needed here. */
+  if (grpc_is_binary_header(
+          (const char *)GPR_SLICE_START_PTR(elem->key->slice), key_len)) {
+    value_len_in_table = get_base64_encoded_size(raw_value_len);
+  } else {
+    value_len_in_table = raw_value_len;
+  }
+  return overhead_and_key + value_len_in_table;
+}
+
 grpc_mdelem *grpc_mdelem_ref(grpc_mdelem *gmd DEBUG_ARGS) {
   internal_metadata *md = (internal_metadata *)gmd;
   if (is_mdelem_static(gmd)) return gmd;
diff --git a/src/core/lib/transport/metadata.h b/src/core/lib/transport/metadata.h
index 6a02437..8630b0e 100644
--- a/src/core/lib/transport/metadata.h
+++ b/src/core/lib/transport/metadata.h
@@ -110,6 +110,8 @@
                                                 const uint8_t *value,
                                                 size_t value_length);
 
+size_t grpc_mdelem_get_size_in_hpack_table(grpc_mdelem *elem);
+
 /* Mutator and accessor for grpc_mdelem user data. The destructor function
    is used as a type tag and is checked during user_data fetch. */
 void *grpc_mdelem_get_user_data(grpc_mdelem *md,
diff --git a/test/core/transport/chttp2/hpack_encoder_test.c b/test/core/transport/chttp2/hpack_encoder_test.c
index 5ff592c..186bb64 100644
--- a/test/core/transport/chttp2/hpack_encoder_test.c
+++ b/test/core/transport/chttp2/hpack_encoder_test.c
@@ -184,6 +184,37 @@
   verify(0, 0, 0, "000007 0104 deadbeef 40 026161 026261", 1, "aa", "ba");
 }
 
+/* Encodes a one-element batch {key: value} through g_compressor and checks
+   that the compressor's table_size grew by exactly the element's HPACK size. */
+static void verify_table_size_change_match_elem_size(const char *key,
+                                                     const char *value) {
+  grpc_mdelem *md = grpc_mdelem_from_strings(key, value);
+  const size_t expected_growth = grpc_mdelem_get_size_in_hpack_table(md);
+  const size_t size_before = g_compressor.table_size;
+  /* Build a batch whose list holds just this one element. */
+  grpc_linked_mdelem *node = gpr_malloc(sizeof(*node));
+  grpc_metadata_batch batch;
+  grpc_metadata_batch_init(&batch);
+  node->md = md;
+  node->prev = node->next = NULL;
+  batch.list.head = batch.list.tail = node;
+  gpr_slice_buffer encoded;
+  gpr_slice_buffer_init(&encoded);
+  grpc_transport_one_way_stats stats;
+  memset(&stats, 0, sizeof(stats));
+  grpc_chttp2_encode_header(&g_compressor, 0xdeadbeef, &batch, 0, &stats,
+                            &encoded);
+  gpr_slice_buffer_destroy(&encoded);
+  grpc_metadata_batch_destroy(&batch);
+  GPR_ASSERT(g_compressor.table_size == size_before + expected_growth);
+  gpr_free(node);
+}
+
+static void test_encode_header_size(void) { /* table growth == elem size */
+  verify_table_size_change_match_elem_size("hello", "world"); /* plain key */
+  verify_table_size_change_match_elem_size("hello-bin", "world"); /* -bin key */
+}
+
 static void run_test(void (*test)(), const char *name) {
   gpr_log(GPR_INFO, "RUN TEST: %s", name);
   grpc_chttp2_hpack_compressor_init(&g_compressor);
@@ -198,6 +229,7 @@
   grpc_init();
   TEST(test_basic_headers);
   TEST(test_decode_table_overflow);
+  TEST(test_encode_header_size);
   grpc_shutdown();
   for (i = 0; i < num_to_delete; i++) {
     gpr_free(to_delete[i]);
diff --git a/test/core/transport/metadata_test.c b/test/core/transport/metadata_test.c
index 809fa87..a4e9694 100644
--- a/test/core/transport/metadata_test.c
+++ b/test/core/transport/metadata_test.c
@@ -34,6 +34,7 @@
 #include "src/core/lib/transport/metadata.h"
 
 #include <stdio.h>
+#include <string.h>
 
 #include <grpc/grpc.h>
 #include <grpc/support/alloc.h>
@@ -42,6 +43,7 @@
 
 #include "src/core/ext/transport/chttp2/transport/bin_encoder.h"
 #include "src/core/lib/support/string.h"
+#include "src/core/lib/transport/static_metadata.h"
 #include "test/core/util/test_config.h"
 
 #define LOG_TEST(x) gpr_log(GPR_INFO, "%s", x)
@@ -260,6 +262,54 @@
   grpc_shutdown();
 }
 
+/* Checks that the HPACK size is 32 + strlen(key) + strlen(value). */
+static void verify_ascii_header_size(const char *key, const char *value) {
+  grpc_mdelem *md = grpc_mdelem_from_strings(key, value);
+  const size_t want = 32 + strlen(key) + strlen(value);
+  GPR_ASSERT(grpc_mdelem_get_size_in_hpack_table(md) == want);
+  GRPC_MDELEM_UNREF(md);
+}
+
+/* Checks that a binary header is sized using its base64-encoded value. */
+static void verify_binary_header_size(const char *key, const uint8_t *value,
+                                      size_t value_len) {
+  const size_t key_len = strlen(key);
+  grpc_mdelem *md = grpc_mdelem_from_string_and_buffer(key, value, value_len);
+  GPR_ASSERT(grpc_is_binary_header(key, key_len));
+  const size_t actual = grpc_mdelem_get_size_in_hpack_table(md);
+  gpr_slice raw = gpr_slice_from_copied_buffer((const char *)value, value_len);
+  gpr_slice encoded = grpc_chttp2_base64_encode(raw);
+  GPR_ASSERT(actual == 32 + key_len + GPR_SLICE_LENGTH(encoded));
+  gpr_slice_unref(encoded);
+  gpr_slice_unref(raw);
+  GRPC_MDELEM_UNREF(md);
+}
+
+#define BUFFER_SIZE 64
+/* Checks grpc_mdelem_get_size_in_hpack_table() for ASCII headers and for
+   binary headers with every value length in [0, BUFFER_SIZE). */
+static void test_mdelem_sizes_in_hpack(void) {
+  LOG_TEST("test_mdelem_size");
+  grpc_init();
+  uint8_t binary_value[BUFFER_SIZE] = {0};
+  for (size_t i = 0; i < BUFFER_SIZE; i++) { /* size_t: no wrap past 255 */
+    binary_value[i] = (uint8_t)i;
+  }
+  verify_ascii_header_size("hello", "world");
+  verify_ascii_header_size("hello", "worldxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
+  verify_ascii_header_size(":scheme", "http");
+  for (size_t i = 0; i < BUFFER_SIZE; i++) {
+    verify_binary_header_size("hello-bin", binary_value, i);
+  }
+
+  /* Reuse the bytes of a static metadata string as a binary value. */
+  const char *static_metadata = grpc_static_metadata_strings[0];
+  const size_t static_len = strlen(static_metadata);
+  GPR_ASSERT(static_len <= sizeof(binary_value)); /* memcpy must fit */
+  memcpy(binary_value, static_metadata, static_len);
+  verify_binary_header_size("hello-bin", binary_value, static_len);
+  grpc_shutdown();
+}
+
 int main(int argc, char **argv) {
   grpc_test_init(argc, argv);
   test_no_op();
@@ -272,5 +322,6 @@
   test_slices_work();
   test_base64_and_huffman_works();
   test_user_data_works();
+  test_mdelem_sizes_in_hpack();
   return 0;
 }