pw_tokenizer: Move Base64 functions to class

- Move the Base64 tokenization functions to the tokenizer class, which
  is much cleaner to work with.
- Have AutoUpdatingDetokenizer derive from Detokenizer. This makes
  working with detokenizers simpler.

Change-Id: Ic6bd9354c34f21a9931c83200e7c98e05911b6a2
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/47761
Reviewed-by: Keir Mierle <keir@google.com>
Commit-Queue: Keir Mierle <keir@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_tokenizer/docs.rst b/pw_tokenizer/docs.rst
index e2d714a..8967f12 100644
--- a/pw_tokenizer/docs.rst
+++ b/pw_tokenizer/docs.rst
@@ -876,8 +876,8 @@
 
 Decoding
 --------
-Base64 decoding and detokenizing is supported in the Python detokenizer through
-the ``detokenize_base64`` and related functions.
+The Python ``Detokenizer`` class supprts decoding and detokenizing prefixed
+Base64 messages with ``detokenize_base64`` and related methods.
 
 .. tip::
   The Python detokenization tools support recursive detokenization for prefixed
@@ -1019,7 +1019,7 @@
   * Provide simple wrapper shell scripts that fill in arguments for the
     project. For example, point ``detokenize.py`` to the project's token
     databases.
-  * Use ``pw_tokenizer.AutoReloadingDetokenizer`` to decode in
+  * Use ``pw_tokenizer.AutoUpdatingDetokenizer`` to decode in
     continuously-running tools, so that users don't have to restart the tool
     when the token database updates.
   * Integrate detokenization everywhere it is needed. Integrating the tools
diff --git a/pw_tokenizer/py/BUILD.gn b/pw_tokenizer/py/BUILD.gn
index e9e4468..00128d1 100644
--- a/pw_tokenizer/py/BUILD.gn
+++ b/pw_tokenizer/py/BUILD.gn
@@ -31,8 +31,6 @@
     "pw_tokenizer/encode.py",
     "pw_tokenizer/serial_detokenizer.py",
     "pw_tokenizer/tokens.py",
-    "tokenized_string_decoding_test_data.py",
-    "varint_test_data.py",
   ]
   tests = [
     "database_test.py",
@@ -40,9 +38,12 @@
     "detokenize_test.py",
     "elf_reader_test.py",
     "encode_test.py",
+    "tokenized_string_decoding_test_data.py",
     "tokens_test.py",
+    "varint_test_data.py",
   ]
   inputs = [
+    "elf_reader_test_binary.elf",
     "example_binary_with_tokenized_strings.elf",
     "example_legacy_binary_with_tokenized_strings.elf",
   ]
diff --git a/pw_tokenizer/py/detokenize_test.py b/pw_tokenizer/py/detokenize_test.py
index 1de6160..23f0f64 100755
--- a/pw_tokenizer/py/detokenize_test.py
+++ b/pw_tokenizer/py/detokenize_test.py
@@ -526,22 +526,21 @@
     def test_detokenize_base64_live(self):
         for data, expected in self.TEST_CASES:
             output = io.BytesIO()
-            detokenize.detokenize_base64_live(self.detok, io.BytesIO(data),
-                                              output, '$')
+            self.detok.detokenize_base64_live(io.BytesIO(data), output, '$')
 
             self.assertEqual(expected, output.getvalue())
 
     def test_detokenize_base64_to_file(self):
         for data, expected in self.TEST_CASES:
             output = io.BytesIO()
-            detokenize.detokenize_base64_to_file(self.detok, data, output, '$')
+            self.detok.detokenize_base64_to_file(data, output, '$')
 
             self.assertEqual(expected, output.getvalue())
 
     def test_detokenize_base64(self):
         for data, expected in self.TEST_CASES:
-            self.assertEqual(
-                expected, detokenize.detokenize_base64(self.detok, data, b'$'))
+            self.assertEqual(expected,
+                             self.detok.detokenize_base64(data, b'$'))
 
 
 class DetokenizeBase64InfiniteRecursion(unittest.TestCase):
@@ -559,28 +558,24 @@
     def test_detokenize_self_recursion(self):
         for depth in range(5):
             self.assertEqual(
-                detokenize.detokenize_base64(self.detok,
-                                             b'This one is deep: $AAAAAA==',
+                self.detok.detokenize_base64(b'This one is deep: $AAAAAA==',
                                              recursion=depth),
                 b'This one is deep: $AAAAAA==')
 
     def test_detokenize_self_recursion_default(self):
         self.assertEqual(
-            detokenize.detokenize_base64(self.detok,
-                                         b'This one is deep: $AAAAAA=='),
+            self.detok.detokenize_base64(b'This one is deep: $AAAAAA=='),
             b'This one is deep: $AAAAAA==')
 
     def test_detokenize_cyclic_recursion_even(self):
         self.assertEqual(
-            detokenize.detokenize_base64(self.detok,
-                                         b'I said "$AQAAAA=="',
-                                         recursion=2), b'I said "$AgAAAA=="')
+            self.detok.detokenize_base64(b'I said "$AQAAAA=="', recursion=2),
+            b'I said "$AgAAAA=="')
 
     def test_detokenize_cyclic_recursion_odd(self):
         self.assertEqual(
-            detokenize.detokenize_base64(self.detok,
-                                         b'I said "$AQAAAA=="',
-                                         recursion=3), b'I said "$AwAAAA=="')
+            self.detok.detokenize_base64(b'I said "$AQAAAA=="', recursion=3),
+            b'I said "$AwAAAA=="')
 
 
 if __name__ == '__main__':
diff --git a/pw_tokenizer/py/pw_tokenizer/detokenize.py b/pw_tokenizer/py/pw_tokenizer/detokenize.py
index 770a4f0..f9b459f 100755
--- a/pw_tokenizer/py/pw_tokenizer/detokenize.py
+++ b/pw_tokenizer/py/pw_tokenizer/detokenize.py
@@ -56,9 +56,12 @@
         os.path.abspath(__file__))))
     from pw_tokenizer import database, decode, encode, tokens
 
-ENCODED_TOKEN = struct.Struct('<I')
 _LOG = logging.getLogger('pw_tokenizer')
 
+ENCODED_TOKEN = struct.Struct('<I')
+BASE64_PREFIX = encode.BASE64_PREFIX.encode()
+DEFAULT_RECURSION = 9
+
 
 class DetokenizedString:
     """A detokenized string, with all results if there are collisions."""
@@ -179,12 +182,17 @@
           show_errors: if True, an error message is used in place of the %
               conversion specifier when an argument fails to decode
         """
-        self.database = database.load_token_database(*token_database_or_elf)
         self.show_errors = show_errors
 
         # Cache FormatStrings for faster lookup & formatting.
         self._cache: Dict[int, List[_TokenizedFormatString]] = {}
 
+        self._initialize_database(token_database_or_elf)
+
+    def _initialize_database(self, token_sources: Iterable) -> None:
+        self.database = database.load_token_database(*token_sources)
+        self._cache.clear()
+
     def lookup(self, token: int) -> List[_TokenizedFormatString]:
         """Returns (TokenizedStringEntry, FormatString) list for matches."""
         try:
@@ -207,8 +215,88 @@
         return DetokenizedString(token, self.lookup(token), encoded_message,
                                  self.show_errors)
 
+    def detokenize_base64(self,
+                          data: bytes,
+                          prefix: Union[str, bytes] = BASE64_PREFIX,
+                          recursion: int = DEFAULT_RECURSION) -> bytes:
+        """Decodes and replaces prefixed Base64 messages in the provided data.
 
-class AutoUpdatingDetokenizer:
+        Args:
+          data: the binary data to decode
+          prefix: one-character byte string that signals the start of a message
+          recursion: how many levels to recursively decode
+
+        Returns:
+          copy of the data with all recognized tokens decoded
+        """
+        output = io.BytesIO()
+        self.detokenize_base64_to_file(data, output, prefix, recursion)
+        return output.getvalue()
+
+    def detokenize_base64_to_file(self,
+                                  data: bytes,
+                                  output: BinaryIO,
+                                  prefix: Union[str, bytes] = BASE64_PREFIX,
+                                  recursion: int = DEFAULT_RECURSION) -> None:
+        """Decodes prefixed Base64 messages in data; decodes to output file."""
+        prefix = prefix.encode() if isinstance(prefix, str) else prefix
+        output.write(
+            _base64_message_regex(prefix).sub(
+                self._detokenize_prefixed_base64(prefix, recursion), data))
+
+    def detokenize_base64_live(self,
+                               input_file: BinaryIO,
+                               output: BinaryIO,
+                               prefix: Union[str, bytes] = BASE64_PREFIX,
+                               recursion: int = DEFAULT_RECURSION) -> None:
+        """Reads chars one-at-a-time, decoding messages; SLOW for big files."""
+        prefix_bytes = prefix.encode() if isinstance(prefix, str) else prefix
+
+        base64_message = _base64_message_regex(prefix_bytes)
+
+        def transform(data: bytes) -> bytes:
+            return base64_message.sub(
+                self._detokenize_prefixed_base64(prefix_bytes, recursion),
+                data)
+
+        for message in PrefixedMessageDecoder(
+                prefix,
+                string.ascii_letters + string.digits + '+/-_=').transform(
+                    input_file, transform):
+            output.write(message)
+
+            # Flush each line to prevent delays when piping between processes.
+            if b'\n' in message:
+                output.flush()
+
+    def _detokenize_prefixed_base64(
+            self, prefix: bytes,
+            recursion: int) -> Callable[[Match[bytes]], bytes]:
+        """Returns a function that decodes prefixed Base64."""
+        def decode_and_detokenize(match: Match[bytes]) -> bytes:
+            """Decodes prefixed base64 with this detokenizer."""
+            original = match.group(0)
+
+            try:
+                detokenized_string = self.detokenize(
+                    base64.b64decode(original[1:], validate=True))
+                if detokenized_string.matches():
+                    result = str(detokenized_string).encode()
+
+                    if recursion > 0 and original != result:
+                        result = self.detokenize_base64(
+                            result, prefix, recursion - 1)
+
+                    return result
+            except binascii.Error:
+                pass
+
+            return original
+
+        return decode_and_detokenize
+
+
+class AutoUpdatingDetokenizer(Detokenizer):
     """Loads and updates a detokenizer from database paths."""
     class _DatabasePath:
         """Tracks the modified time of a path or file object."""
@@ -243,22 +331,19 @@
         self.paths = tuple(self._DatabasePath(path) for path in paths_or_files)
         self.min_poll_period_s = min_poll_period_s
         self._last_checked_time: float = time.time()
-        self._detokenizer = Detokenizer(*(path.load() for path in self.paths))
+        super().__init__(*(path.load() for path in self.paths))
 
-    def detokenize(self, data: bytes) -> DetokenizedString:
-        """Updates the token database if it has changed, then detokenizes."""
+    def _reload_if_changed(self) -> None:
         if time.time() - self._last_checked_time >= self.min_poll_period_s:
             self._last_checked_time = time.time()
 
             if any(path.updated() for path in self.paths):
                 _LOG.info('Changes detected; reloading token database')
-                self._detokenizer = Detokenizer(*(path.load()
-                                                  for path in self.paths))
+                self._initialize_database(path.load() for path in self.paths)
 
-        return self._detokenizer.detokenize(data)
-
-
-_Detokenizer = Union[Detokenizer, AutoUpdatingDetokenizer]
+    def lookup(self, token: int) -> List[_TokenizedFormatString]:
+        self._reload_if_changed()
+        return super().lookup(token)
 
 
 class PrefixedMessageDecoder:
@@ -328,37 +413,6 @@
             yield transform(chunk) if is_message else chunk
 
 
-def _detokenize_prefixed_base64(
-        detokenizer: _Detokenizer, prefix: bytes,
-        recursion: int) -> Callable[[Match[bytes]], bytes]:
-    """Returns a function that decodes prefixed Base64 with the detokenizer."""
-    def decode_and_detokenize(match: Match[bytes]) -> bytes:
-        """Decodes prefixed base64 with the provided detokenizer."""
-        original = match.group(0)
-
-        try:
-            detokenized_string = detokenizer.detokenize(
-                base64.b64decode(original[1:], validate=True))
-            if detokenized_string.matches():
-                result = str(detokenized_string).encode()
-
-                if recursion > 0 and original != result:
-                    result = detokenize_base64(detokenizer, result, prefix,
-                                               recursion - 1)
-
-                return result
-        except binascii.Error:
-            pass
-
-        return original
-
-    return decode_and_detokenize
-
-
-BASE64_PREFIX = encode.BASE64_PREFIX.encode()
-DEFAULT_RECURSION = 9
-
-
 def _base64_message_regex(prefix: bytes) -> Pattern[bytes]:
     """Returns a regular expression for prefixed base64 tokenized strings."""
     return re.compile(
@@ -370,64 +424,14 @@
             br'(?:[A-Za-z0-9+/\-_]{3}=|[A-Za-z0-9+/\-_]{2}==)?'))
 
 
-def detokenize_base64_live(detokenizer: _Detokenizer,
-                           input_file: BinaryIO,
-                           output: BinaryIO,
-                           prefix: Union[str, bytes] = BASE64_PREFIX,
-                           recursion: int = DEFAULT_RECURSION) -> None:
-    """Reads chars one-at-a-time and decodes messages; SLOW for big files."""
-    prefix_bytes = prefix.encode() if isinstance(prefix, str) else prefix
-
-    base64_message = _base64_message_regex(prefix_bytes)
-
-    def transform(data: bytes) -> bytes:
-        return base64_message.sub(
-            _detokenize_prefixed_base64(detokenizer, prefix_bytes, recursion),
-            data)
-
-    for message in PrefixedMessageDecoder(
-            prefix, string.ascii_letters + string.digits + '+/-_=').transform(
-                input_file, transform):
-        output.write(message)
-
-        # Flush each line to prevent delays when piping between processes.
-        if b'\n' in message:
-            output.flush()
-
-
-def detokenize_base64_to_file(detokenizer: _Detokenizer,
-                              data: bytes,
-                              output: BinaryIO,
-                              prefix: Union[str, bytes] = BASE64_PREFIX,
-                              recursion: int = DEFAULT_RECURSION) -> None:
-    """Decodes prefixed Base64 messages in data; decodes to an output file."""
-    prefix = prefix.encode() if isinstance(prefix, str) else prefix
-    output.write(
-        _base64_message_regex(prefix).sub(
-            _detokenize_prefixed_base64(detokenizer, prefix, recursion), data))
-
-
-def detokenize_base64(detokenizer: _Detokenizer,
+def detokenize_base64(detokenizer: Detokenizer,
                       data: bytes,
                       prefix: Union[str, bytes] = BASE64_PREFIX,
                       recursion: int = DEFAULT_RECURSION) -> bytes:
-    """Decodes and replaces prefixed Base64 messages in the provided data.
-
-    Args:
-      detokenizer: the detokenizer with which to decode messages
-      data: the binary data to decode
-      prefix: one-character byte string that signals the start of a message
-      recursion: how many levels to recursively decode
-
-    Returns:
-      copy of the data with all recognized tokens decoded
-    """
-    output = io.BytesIO()
-    detokenize_base64_to_file(detokenizer, data, output, prefix, recursion)
-    return output.getvalue()
+    return detokenizer.detokenize_base64(data, prefix, recursion)
 
 
-def _follow_and_detokenize_file(detokenizer: _Detokenizer,
+def _follow_and_detokenize_file(detokenizer: Detokenizer,
                                 file: BinaryIO,
                                 output: BinaryIO,
                                 prefix: Union[str, bytes],
@@ -438,7 +442,7 @@
         while True:
             data = file.read()
             if data:
-                detokenize_base64_to_file(detokenizer, data, output, prefix)
+                detokenizer.detokenize_base64_to_file(data, output, prefix)
                 output.flush()
             else:
                 time.sleep(poll_period_s)
@@ -463,11 +467,11 @@
         _follow_and_detokenize_file(detokenizer, input_file, output, prefix)
     elif input_file.seekable():
         # Process seekable files all at once, which is MUCH faster.
-        detokenize_base64_to_file(detokenizer, input_file.read(), output,
-                                  prefix)
+        detokenizer.detokenize_base64_to_file(input_file.read(), output,
+                                              prefix)
     else:
         # For non-seekable inputs (e.g. pipes), read one character at a time.
-        detokenize_base64_live(detokenizer, input_file, output, prefix)
+        detokenizer.detokenize_base64_live(input_file, output, prefix)
 
 
 def _parse_args() -> argparse.Namespace:
diff --git a/pw_tokenizer/py/pw_tokenizer/serial_detokenizer.py b/pw_tokenizer/py/pw_tokenizer/serial_detokenizer.py
index e010225..234ca64 100644
--- a/pw_tokenizer/py/pw_tokenizer/serial_detokenizer.py
+++ b/pw_tokenizer/py/pw_tokenizer/serial_detokenizer.py
@@ -74,8 +74,7 @@
     serial_device = serial.Serial(port=device, baudrate=baudrate)
 
     try:
-        detokenize.detokenize_base64_live(detokenizer, serial_device, output,
-                                          prefix)
+        detokenizer.detokenize_base64_live(serial_device, output, prefix)
     except KeyboardInterrupt:
         output.flush()