Issue #1470548: XMLGenerator now works with binary output streams.
diff --git a/Lib/test/test_sax.py b/Lib/test/test_sax.py
index a96f0a8..218d519 100644
--- a/Lib/test/test_sax.py
+++ b/Lib/test/test_sax.py
@@ -13,7 +13,7 @@
 from xml.sax.expatreader import create_parser
 from xml.sax.handler import feature_namespaces
 from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
-from io import StringIO
+from io import BytesIO, StringIO
 import os.path
 import shutil
 from test import support
@@ -173,31 +173,39 @@
 
 # ===== XMLGenerator
 
-start = '<?xml version="1.0" encoding="iso-8859-1"?>\n'
-
-class XmlgenTest(unittest.TestCase):
+class XmlgenTest:
+    """XMLGenerator tests shared by several kinds of output stream.
+
+    This is a plain mixin rather than a TestCase, so unittest does not
+    collect it directly.  Concrete subclasses also inherit from
+    unittest.TestCase and must provide:
+
+    * ``ioclass`` -- called with no arguments to create the output stream;
+    * ``xml(doc, encoding=...)`` -- the expected ``getvalue()`` result for
+      *doc*, including the leading XML declaration.
+    """
     def test_xmlgen_basic(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
         gen.startDocument()
         gen.startElement("doc", {})
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc></doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc></doc>"))
 
     def test_xmlgen_basic_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
         gen.startDocument()
         gen.startElement("doc", {})
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc/>")
+        self.assertEqual(result.getvalue(), self.xml("<doc/>"))
 
     def test_xmlgen_content(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -206,10 +204,10 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))
 
     def test_xmlgen_content_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -218,10 +216,10 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))
 
     def test_xmlgen_pi(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -230,10 +228,11 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<?test data?><doc></doc>")
+        self.assertEqual(result.getvalue(),
+            self.xml("<?test data?><doc></doc>"))
 
     def test_xmlgen_content_escape(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -243,10 +242,10 @@
         gen.endDocument()
 
         self.assertEqual(result.getvalue(),
-            start + "<doc>&lt;huhei&amp;</doc>")
+            self.xml("<doc>&lt;huhei&amp;</doc>"))
 
     def test_xmlgen_attr_escape(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -260,13 +259,50 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start +
-            ("<doc a='\"'><e a=\"'\"></e>"
-             "<e a=\"'&quot;\"></e>"
-             "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
+        self.assertEqual(result.getvalue(), self.xml(
+            "<doc a='\"'><e a=\"'\"></e>"
+            "<e a=\"'&quot;\"></e>"
+            "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
+
+    def test_xmlgen_encoding(self):
+        # U+20AC is encodable in every codec below; the fallback to a
+        # character reference is covered by test_xmlgen_unencodable.
+        encodings = ('iso-8859-15', 'utf-8', 'utf-8-sig',
+                     'utf-16', 'utf-16be', 'utf-16le',
+                     'utf-32', 'utf-32be', 'utf-32le')
+        for encoding in encodings:
+            result = self.ioclass()
+            gen = XMLGenerator(result, encoding=encoding)
+
+            gen.startDocument()
+            gen.startElement("doc", {"a": '\u20ac'})
+            gen.characters("\u20ac")
+            gen.endElement("doc")
+            gen.endDocument()
+
+            # Pass the encoding as the failure message so a mismatch
+            # identifies which codec broke.
+            self.assertEqual(result.getvalue(),
+                self.xml('<doc a="\u20ac">\u20ac</doc>', encoding=encoding),
+                encoding)
+
+    def test_xmlgen_unencodable(self):
+        # U+20AC has no ASCII encoding, so it must come out as a
+        # character reference in both the attribute and the content.
+        result = self.ioclass()
+        gen = XMLGenerator(result, encoding='ascii')
+
+        gen.startDocument()
+        gen.startElement("doc", {"a": '\u20ac'})
+        gen.characters("\u20ac")
+        gen.endElement("doc")
+        gen.endDocument()
+
+        self.assertEqual(result.getvalue(),
+            self.xml('<doc a="&#8364;">&#8364;</doc>', encoding='ascii'))
 
     def test_xmlgen_ignorable(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -275,10 +304,10 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc> </doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))
 
     def test_xmlgen_ignorable_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -287,10 +316,10 @@
         gen.endElement("doc")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc> </doc>")
+        self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))
 
     def test_xmlgen_ns(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -303,12 +332,12 @@
         gen.endPrefixMapping("ns1")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + \
-           ('<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
+        self.assertEqual(result.getvalue(), self.xml(
+           '<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
                                          ns_uri))
 
     def test_xmlgen_ns_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -321,12 +350,12 @@
         gen.endPrefixMapping("ns1")
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start + \
-           ('<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' %
+        self.assertEqual(result.getvalue(), self.xml(
+           '<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' %
                                          ns_uri))
 
     def test_1463026_1(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -334,10 +363,10 @@
         gen.endElementNS((None, 'a'), 'a')
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start+'<a b="c"></a>')
+        self.assertEqual(result.getvalue(), self.xml('<a b="c"></a>'))
 
     def test_1463026_1_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -345,10 +374,10 @@
         gen.endElementNS((None, 'a'), 'a')
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start+'<a b="c"/>')
+        self.assertEqual(result.getvalue(), self.xml('<a b="c"/>'))
 
     def test_1463026_2(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -358,10 +387,10 @@
         gen.endPrefixMapping(None)
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start+'<a xmlns="qux"></a>')
+        self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"></a>'))
 
     def test_1463026_2_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -371,10 +400,10 @@
         gen.endPrefixMapping(None)
         gen.endDocument()
 
-        self.assertEqual(result.getvalue(), start+'<a xmlns="qux"/>')
+        self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"/>'))
 
     def test_1463026_3(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -385,10 +414,10 @@
         gen.endDocument()
 
         self.assertEqual(result.getvalue(),
-            start+'<my:a xmlns:my="qux" b="c"></my:a>')
+            self.xml('<my:a xmlns:my="qux" b="c"></my:a>'))
 
     def test_1463026_3_empty(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result, short_empty_elements=True)
 
         gen.startDocument()
@@ -399,7 +428,7 @@
         gen.endDocument()
 
         self.assertEqual(result.getvalue(),
-            start+'<my:a xmlns:my="qux" b="c"/>')
+            self.xml('<my:a xmlns:my="qux" b="c"/>'))
 
     def test_5027_1(self):
         # The xml prefix (as in xml:lang below) is reserved and bound by
@@ -416,13 +445,13 @@
 
         parser = make_parser()
         parser.setFeature(feature_namespaces, True)
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
         parser.setContentHandler(gen)
         parser.parse(test_xml)
 
         self.assertEqual(result.getvalue(),
-                         start + (
+                         self.xml(
                          '<a:g1 xmlns:a="http://example.com/ns">'
                           '<a:g2 xml:lang="en">Hello</a:g2>'
                          '</a:g1>'))
@@ -435,7 +464,7 @@
         #
         # This test demonstrates the bug by direct manipulation of the
         # XMLGenerator.
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -450,15 +479,72 @@
         gen.endDocument()
 
         self.assertEqual(result.getvalue(),
-                         start + (
+                         self.xml(
                          '<a:g1 xmlns:a="http://example.com/ns">'
                           '<a:g2 xml:lang="en">Hello</a:g2>'
                          '</a:g1>'))
 
+    def test_no_close_file(self):
+        # The caller's stream must still be open after the generator created
+        # inside func() has gone out of scope.
+        result = self.ioclass()
+        def func(out):
+            gen = XMLGenerator(out)
+            gen.startDocument()
+            gen.startElement("doc", {})
+        func(result)
+        self.assertFalse(result.closed)
+
+class StringXmlgenTest(XmlgenTest, unittest.TestCase):
+    # Text stream: expected output is a str; the encoding only shows up in
+    # the XML declaration.
+    ioclass = StringIO
+
+    def xml(self, doc, encoding='iso-8859-1'):
+        return '<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, doc)
+
+    # Rebinding the inherited test to None stops unittest from collecting
+    # it here.  Presumably a text stream never applies the output encoding,
+    # so no character references would be produced -- confirm in saxutils.
+    test_xmlgen_unencodable = None
+
+class BytesXmlgenTest(XmlgenTest, unittest.TestCase):
+    # Binary stream: expected output is the declaration plus document,
+    # encoded with unencodable characters turned into character references.
+    ioclass = BytesIO
+
+    def xml(self, doc, encoding='iso-8859-1'):
+        return ('<?xml version="1.0" encoding="%s"?>\n%s' %
+                (encoding, doc)).encode(encoding, 'xmlcharrefreplace')
+
+class WriterXmlgenTest(BytesXmlgenTest):
+    # A bare writer object outside the io class hierarchy: a list that
+    # collects the written chunks and joins them in getvalue().
+    class ioclass(list):
+        write = list.append
+        closed = False
+
+        # NOTE(review): seekable()/tell() are presumably consulted to decide
+        # whether a BOM must be emitted (utf-16/utf-32/utf-8-sig cases in
+        # test_xmlgen_encoding) -- confirm against xml.sax.saxutils.
+        def seekable(self):
+            return True
+
+        def tell(self):
+            # return 0 at start and not 0 after start
+            return len(self)
+
+        def getvalue(self):
+            return b''.join(self)
+
+
+# Expected XML declaration prefix for the byte-stream tests below.
+start = b'<?xml version="1.0" encoding="iso-8859-1"?>\n'
+
 
 class XMLFilterBaseTest(unittest.TestCase):
     def test_filter_basic(self):
-        result = StringIO()
+        result = BytesIO()
         gen = XMLGenerator(result)
         filter = XMLFilterBase()
         filter.setContentHandler(gen)
@@ -470,7 +541,7 @@
         filter.endElement("doc")
         filter.endDocument()
 
-        self.assertEqual(result.getvalue(), start + "<doc>content </doc>")
+        self.assertEqual(result.getvalue(), start + b"<doc>content </doc>")
 
 # ===========================================================================
 #
@@ -478,7 +549,7 @@
 #
 # ===========================================================================
 
-with open(TEST_XMLFILE_OUT) as f:
+with open(TEST_XMLFILE_OUT, 'rb') as f:
     xml_test_out = f.read()
 
 class ExpatReaderTest(XmlTestBase):
@@ -487,11 +558,11 @@
 
     def test_expat_file(self):
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
-        with open(TEST_XMLFILE) as f:
+        with open(TEST_XMLFILE, 'rb') as f:
             parser.parse(f)
 
         self.assertEqual(result.getvalue(), xml_test_out)
@@ -503,7 +574,7 @@
         self.addCleanup(support.unlink, fname)
 
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
@@ -547,13 +618,13 @@
 
         def resolveEntity(self, publicId, systemId):
             inpsrc = InputSource()
-            inpsrc.setByteStream(StringIO("<entity/>"))
+            inpsrc.setByteStream(BytesIO(b"<entity/>"))
             return inpsrc
 
     def test_expat_entityresolver(self):
         parser = create_parser()
         parser.setEntityResolver(self.TestEntityResolver())
-        result = StringIO()
+        result = BytesIO()
         parser.setContentHandler(XMLGenerator(result))
 
         parser.feed('<!DOCTYPE doc [\n')
@@ -563,7 +634,7 @@
         parser.close()
 
         self.assertEqual(result.getvalue(), start +
-                         "<doc><entity></entity></doc>")
+                         b"<doc><entity></entity></doc>")
 
     # ===== Attributes support
 
@@ -632,7 +703,7 @@
 
     def test_expat_inpsource_filename(self):
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
@@ -642,7 +713,7 @@
 
     def test_expat_inpsource_sysid(self):
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
@@ -657,7 +728,7 @@
         self.addCleanup(support.unlink, fname)
 
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
@@ -667,12 +738,12 @@
 
     def test_expat_inpsource_stream(self):
         parser = create_parser()
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
 
         parser.setContentHandler(xmlgen)
         inpsrc = InputSource()
-        with open(TEST_XMLFILE) as f:
+        with open(TEST_XMLFILE, 'rb') as f:
             inpsrc.setByteStream(f)
             parser.parse(inpsrc)
 
@@ -681,7 +752,7 @@
     # ===== IncrementalParser support
 
     def test_expat_incremental(self):
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -690,10 +761,10 @@
         parser.feed("</doc>")
         parser.close()
 
-        self.assertEqual(result.getvalue(), start + "<doc></doc>")
+        self.assertEqual(result.getvalue(), start + b"<doc></doc>")
 
     def test_expat_incremental_reset(self):
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -701,7 +772,7 @@
         parser.feed("<doc>")
         parser.feed("text")
 
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser.setContentHandler(xmlgen)
         parser.reset()
@@ -711,12 +782,12 @@
         parser.feed("</doc>")
         parser.close()
 
-        self.assertEqual(result.getvalue(), start + "<doc>text</doc>")
+        self.assertEqual(result.getvalue(), start + b"<doc>text</doc>")
 
     # ===== Locator support
 
     def test_expat_locator_noinfo(self):
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -730,7 +801,7 @@
         self.assertEqual(parser.getLineNumber(), 1)
 
     def test_expat_locator_withinfo(self):
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -745,7 +816,7 @@
         shutil.copyfile(TEST_XMLFILE, fname)
         self.addCleanup(support.unlink, fname)
 
-        result = StringIO()
+        result = BytesIO()
         xmlgen = XMLGenerator(result)
         parser = create_parser()
         parser.setContentHandler(xmlgen)
@@ -766,7 +837,7 @@
         parser = create_parser()
         parser.setContentHandler(ContentHandler()) # do nothing
         source = InputSource()
-        source.setByteStream(StringIO("<foo bar foobar>"))   #ill-formed
+        source.setByteStream(BytesIO(b"<foo bar foobar>"))   #ill-formed
         name = "a file name"
         source.setSystemId(name)
         try:
@@ -857,7 +928,10 @@
 def test_main():
     run_unittest(MakeParserTest,
                  SaxutilsTest,
-                 XmlgenTest,
+                 StringXmlgenTest,
+                 BytesXmlgenTest,
+                 WriterXmlgenTest,
+                 XMLFilterBaseTest,
                  ExpatReaderTest,
                  ErrorReportingTest,
                  XmlReaderTest)