Added mypy for static type checking
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 7cf7ed4..1cd92c2 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -4,41 +4,37 @@
from __future__ import print_function
-import unittest
-import sys
import functools
-from contextlib import contextmanager
-
+import io
import os
-from io import StringIO, BytesIO
+import sys
+import typing
+import unittest
+from contextlib import contextmanager, redirect_stdout, redirect_stderr
import rsa
import rsa.cli
import rsa.util
-def make_buffer() -> StringIO:
- buf = StringIO()
- buf.buffer = BytesIO()
- return buf
-
-
-def get_bytes_out(out: StringIO) -> bytes:
- # Python 3.x writes 'bytes' to stdout.buffer
- return out.buffer.getvalue()
-
-
@contextmanager
-def captured_output():
+def captured_output() -> typing.Generator:
"""Captures output to stdout and stderr"""
- new_out, new_err = make_buffer(), make_buffer()
- old_out, old_err = sys.stdout, sys.stderr
- try:
- sys.stdout, sys.stderr = new_out, new_err
- yield new_out, new_err
- finally:
- sys.stdout, sys.stderr = old_out, old_err
+ # mypy objects to assigning .buffer on a StringIO (done for buf_out and buf_err).
+ # This is test-only code and it works at runtime, hence the 'type: ignore' comments.
+ buf_out = io.StringIO()
+ buf_out.buffer = io.BytesIO() # type: ignore
+
+ buf_err = io.StringIO()
+ buf_err.buffer = io.BytesIO() # type: ignore
+
+ with redirect_stdout(buf_out), redirect_stderr(buf_err):  # sys.stdout/sys.stderr point at the buffers only inside this block
+ yield buf_out, buf_err
+
+
+
+def get_bytes_out(buf) -> bytes:
+ return buf.buffer.getvalue()  # bytes that were written to this stream's .buffer (a BytesIO)
@contextmanager