pw_tokenizer: Improve database parsing error messages

- Improve the error messages from failing to parse a token database.
- Add tests for the DatabaseFile class.

Change-Id: Ic39c435bebd906d56572d3501309cb32fa8bd128
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/23001
Reviewed-by: Anthony DiGirolamo <tonymd@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_tokenizer/py/pw_tokenizer/database.py b/pw_tokenizer/py/pw_tokenizer/database.py
index 3bc3310..2c879ae 100755
--- a/pw_tokenizer/py/pw_tokenizer/database.py
+++ b/pw_tokenizer/py/pw_tokenizer/database.py
@@ -373,13 +373,18 @@
                     paths.update(expand_paths_or_globs(value))
 
             for path in paths:
-                try:
-                    databases.append(load_token_database(path))
-                except:
-                    _LOG.exception('Failed to load token database %s', path)
-                    raise
-        except (FileNotFoundError, ValueError) as err:
+                databases.append(load_token_database(path))
+        except tokens.DatabaseFormatError as err:
+            parser.error(
+                f'argument elf_or_token_database: {path} is not a supported '
+                'token database file. Only ELF files or token databases (CSV '
+                f'or binary format) are supported. {err}. ')
+        except FileNotFoundError as err:
             parser.error(f'argument elf_or_token_database: {err}')
+        except:  # pylint: disable=bare-except
+            _LOG.exception('Failed to load token database %s', path)
+            parser.error('argument elf_or_token_database: '
+                         f'Error occurred while loading token database {path}')
 
         setattr(namespace, self.dest, databases)
 
diff --git a/pw_tokenizer/py/pw_tokenizer/tokens.py b/pw_tokenizer/py/pw_tokenizer/tokens.py
index a0c9b66..570b021 100644
--- a/pw_tokenizer/py/pw_tokenizer/tokens.py
+++ b/pw_tokenizer/py/pw_tokenizer/tokens.py
@@ -345,6 +345,10 @@
 BINARY_FORMAT = _BinaryFileFormat()
 
 
+class DatabaseFormatError(Exception):
+    """Failed to parse a token database file."""
+
+
 def file_is_binary_database(fd: BinaryIO) -> bool:
     """True if the file starts with the binary token database magic string."""
     try:
@@ -356,15 +360,37 @@
         return False
 
 
+def _check_that_file_is_csv_database(path: Path) -> None:
+    """Raises an error unless the path appears to be a CSV token database."""
+    try:
+        with path.open('rb') as fd:
+            data = fd.read(8)  # Read 8 bytes, which should be the first token.
+
+        if not data:
+            return  # File is empty, which is valid CSV.
+
+        if len(data) != 8:
+            raise DatabaseFormatError(
+                f'Attempted to read {path} as a CSV token database, but the '
+                f'file is too short ({len(data)} B)')
+
+        # Make sure the first 8 chars are a valid hexadecimal number.
+        _ = int(data.decode(), 16)
+    except (IOError, UnicodeDecodeError, ValueError) as err:
+        raise DatabaseFormatError(
+            f'Encountered error while reading {path} as a CSV token database'
+        ) from err
+
+
 def parse_binary(fd: BinaryIO) -> Iterable[TokenizedStringEntry]:
     """Parses TokenizedStringEntries from a binary token database file."""
     magic, entry_count = BINARY_FORMAT.header.unpack(
         fd.read(BINARY_FORMAT.header.size))
 
     if magic != BINARY_FORMAT.magic:
-        raise ValueError(
-            'Magic number mismatch (found {!r}, expected {!r})'.format(
-                magic, BINARY_FORMAT.magic))
+        raise DatabaseFormatError(
+            f'Binary token database magic number mismatch (found {magic!r}, '
+            f'expected {BINARY_FORMAT.magic!r}) while reading from {fd}')
 
     entries = []
 
@@ -441,6 +467,7 @@
                 return
 
         # Read the path as a CSV file.
+        _check_that_file_is_csv_database(self.path)
         with self.path.open('r', newline='') as file:
             super().__init__(parse_csv(file))
             self._export = write_csv
diff --git a/pw_tokenizer/py/tokens_test.py b/pw_tokenizer/py/tokens_test.py
index 0154ca7..58014e2 100755
--- a/pw_tokenizer/py/tokens_test.py
+++ b/pw_tokenizer/py/tokens_test.py
@@ -17,6 +17,8 @@
 import datetime
 import io
 import logging
+from pathlib import Path
+import tempfile
 import unittest
 
 from pw_tokenizer import tokens
@@ -87,7 +89,7 @@
 """
 
 
-def read_db_from_csv(csv_str):
+def read_db_from_csv(csv_str: str) -> tokens.Database:
     with io.StringIO(csv_str) as csv_db:
         return tokens.Database(tokens.parse_csv(csv_db))
 
@@ -385,6 +387,47 @@
         self.assertEqual(str(db), CSV_DATABASE)
 
 
+class TestDatabaseFile(unittest.TestCase):
+    """Tests the DatabaseFile class."""
+    def setUp(self):
+        file = tempfile.NamedTemporaryFile(delete=False)
+        file.close()
+        self._path = Path(file.name)
+
+    def tearDown(self):
+        self._path.unlink()
+
+    def test_update_csv_file(self):
+        self._path.write_text(CSV_DATABASE)
+        db = tokens.DatabaseFile(self._path)
+        self.assertEqual(str(db), CSV_DATABASE)
+
+        db.add([tokens.TokenizedStringEntry(0xffffffff, 'New entry!')])
+
+        db.write_to_file()
+
+        self.assertEqual(self._path.read_text(),
+                         CSV_DATABASE + 'ffffffff,          ,"New entry!"\n')
+
+    def test_csv_file_too_short_raises_exception(self):
+        self._path.write_text('1234')
+
+        with self.assertRaises(tokens.DatabaseFormatError):
+            tokens.DatabaseFile(self._path)
+
+    def test_csv_invalid_format_raises_exception(self):
+        self._path.write_text('MK34567890')
+
+        with self.assertRaises(tokens.DatabaseFormatError):
+            tokens.DatabaseFile(self._path)
+
+    def test_csv_not_utf8(self):
+        self._path.write_bytes(b'\x80' * 20)
+
+        with self.assertRaises(tokens.DatabaseFormatError):
+            tokens.DatabaseFile(self._path)
+
+
 class TestFilter(unittest.TestCase):
     """Tests the filtering functionality."""
     def setUp(self):