Issue #1470548: XMLGenerator now works with UTF-16 and UTF-32 encodings.
diff --git a/Lib/test/test_sax.py b/Lib/test/test_sax.py
index a16c821..825c16a 100644
--- a/Lib/test/test_sax.py
+++ b/Lib/test/test_sax.py
@@ -14,6 +14,7 @@
from xml.sax.handler import feature_namespaces
from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
from cStringIO import StringIO
+import io
import os.path
import shutil
import test.test_support as support
@@ -170,9 +171,9 @@
start = '<?xml version="1.0" encoding="iso-8859-1"?>\n'
-class XmlgenTest(unittest.TestCase):
+class XmlgenTest:
def test_xmlgen_basic(self):
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
gen.startDocument()
gen.startElement("doc", {})
@@ -182,7 +183,7 @@
self.assertEqual(result.getvalue(), start + "<doc></doc>")
def test_xmlgen_content(self):
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
gen.startDocument()
@@ -194,7 +195,7 @@
self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>")
def test_xmlgen_pi(self):
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
gen.startDocument()
@@ -206,7 +207,7 @@
self.assertEqual(result.getvalue(), start + "<?test data?><doc></doc>")
def test_xmlgen_content_escape(self):
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
gen.startDocument()
@@ -219,7 +220,7 @@
             start + "<doc>&lt;huhei&amp;</doc>")
 
     def test_xmlgen_attr_escape(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -238,8 +239,41 @@
             "<e a=\"'&quot;\"></e>"
             "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
 
+    def test_xmlgen_encoding(self):
+        encodings = ('iso-8859-15', 'utf-8',
+                     'utf-16be', 'utf-16le',
+                     'utf-32be', 'utf-32le')
+        for encoding in encodings:
+            result = self.ioclass()
+            gen = XMLGenerator(result, encoding=encoding)
+
+            gen.startDocument()
+            gen.startElement("doc", {"a": u'\u20ac'})
+            gen.characters(u"\u20ac")
+            gen.endElement("doc")
+            gen.endDocument()
+
+            self.assertEqual(result.getvalue(), (
+                u'<?xml version="1.0" encoding="%s"?>\n'
+                u'<doc a="\u20ac">\u20ac</doc>' % encoding
+                ).encode(encoding, 'xmlcharrefreplace'))
+
+    def test_xmlgen_unencodable(self):
+        result = self.ioclass()
+        gen = XMLGenerator(result, encoding='ascii')
+
+        gen.startDocument()
+        gen.startElement("doc", {"a": u'\u20ac'})
+        gen.characters(u"\u20ac")
+        gen.endElement("doc")
+        gen.endDocument()
+
+        self.assertEqual(result.getvalue(),
+            '<?xml version="1.0" encoding="ascii"?>\n'
+            '<doc a="&#8364;">&#8364;</doc>')
+
     def test_xmlgen_ignorable(self):
-        result = StringIO()
+        result = self.ioclass()
         gen = XMLGenerator(result)
 
         gen.startDocument()
@@ -251,7 +285,7 @@
self.assertEqual(result.getvalue(), start + "<doc> </doc>")
def test_xmlgen_ns(self):
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
gen.startDocument()
@@ -269,7 +303,7 @@
ns_uri))
def test_1463026_1(self):
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
gen.startDocument()
@@ -280,7 +314,7 @@
self.assertEqual(result.getvalue(), start+'<a b="c"></a>')
def test_1463026_2(self):
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
gen.startDocument()
@@ -293,7 +327,7 @@
self.assertEqual(result.getvalue(), start+'<a xmlns="qux"></a>')
def test_1463026_3(self):
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
gen.startDocument()
@@ -321,7 +355,7 @@
parser = make_parser()
parser.setFeature(feature_namespaces, True)
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
parser.setContentHandler(gen)
parser.parse(test_xml)
@@ -340,7 +374,7 @@
#
# This test demonstrates the bug by direct manipulation of the
# XMLGenerator.
- result = StringIO()
+ result = self.ioclass()
gen = XMLGenerator(result)
gen.startDocument()
@@ -360,6 +394,29 @@
'<a:g2 xml:lang="en">Hello</a:g2>'
'</a:g1>'))
+ def test_no_close_file(self):
+ result = self.ioclass()
+ def func(out):
+ gen = XMLGenerator(out)
+ gen.startDocument()
+ gen.startElement("doc", {})
+ func(result)
+ self.assertFalse(result.closed)
+
+class StringXmlgenTest(XmlgenTest, unittest.TestCase):
+ ioclass = StringIO
+
+class BytesIOXmlgenTest(XmlgenTest, unittest.TestCase):
+ ioclass = io.BytesIO
+
+class WriterXmlgenTest(XmlgenTest, unittest.TestCase):
+ class ioclass(list):
+ write = list.append
+ closed = False
+
+ def getvalue(self):
+ return b''.join(self)
+
class XMLFilterBaseTest(unittest.TestCase):
def test_filter_basic(self):
@@ -804,7 +861,9 @@
def test_main():
run_unittest(MakeParserTest,
SaxutilsTest,
- XmlgenTest,
+ StringXmlgenTest,
+ BytesIOXmlgenTest,
+ WriterXmlgenTest,
ExpatReaderTest,
ErrorReportingTest,
XmlReaderTest)
diff --git a/Lib/xml/sax/saxutils.py b/Lib/xml/sax/saxutils.py
index 7989713..dad74f5 100644
--- a/Lib/xml/sax/saxutils.py
+++ b/Lib/xml/sax/saxutils.py
@@ -4,6 +4,7 @@
"""
import os, urlparse, urllib, types
+import io
import sys
import handler
import xmlreader
@@ -13,15 +14,6 @@
except AttributeError:
_StringTypes = [types.StringType]
-# See whether the xmlcharrefreplace error handler is
-# supported
-try:
- from codecs import xmlcharrefreplace_errors
- _error_handling = "xmlcharrefreplace"
- del xmlcharrefreplace_errors
-except ImportError:
- _error_handling = "strict"
-
def __dict_replace(s, d):
"""Replace substrings of a string using a dictionary."""
for key, value in d.items():
@@ -82,25 +74,72 @@
     return data
 
 
+class _UnbufferedTextIOWrapper(io.TextIOWrapper):
+    """TextIOWrapper that flushes after every write.
+
+    TextIOWrapper keeps encoded text in an internal buffer until flush().
+    Flushing on every write makes XMLGenerator output reach the underlying
+    stream immediately and in order with writes made to that stream
+    directly, as it did before the stream was wrapped.
+    """
+
+    def write(self, s):
+        n = super(_UnbufferedTextIOWrapper, self).write(s)
+        self.flush()
+        return n
+
+
+def _gettextwriter(out, encoding):
+    """Return a text stream that writes to *out* encoded as *encoding*.
+
+    *out* defaults to sys.stdout.  It is either an io.RawIOBase stream or
+    any object with a write() method that accepts byte strings.
+    Characters that *encoding* cannot represent are written as XML
+    character references.  Closing or collecting the returned wrapper
+    does not close *out*.
+    """
+    if out is None:
+        out = sys.stdout
+
+    if isinstance(out, io.RawIOBase):
+        # io.BufferedIOBase is only an abstract base class: its write()
+        # always raises UnsupportedOperation.  A concrete BufferedWriter
+        # is needed to forward the encoded bytes to the raw stream.
+        buffer = io.BufferedWriter(out)
+        # Keep the original file open when the TextIOWrapper is
+        # destroyed
+        buffer.close = lambda: None
+    else:
+        # This is to handle passed objects that aren't in the
+        # IOBase hierarchy, but just have a write method
+        buffer = io.BufferedIOBase()
+        buffer.writable = lambda: True
+        buffer.write = out.write
+        try:
+            # TextIOWrapper uses these methods to determine
+            # if BOM (for UTF-16, etc) should be added
+            buffer.seekable = out.seekable
+            buffer.tell = out.tell
+        except AttributeError:
+            pass
+    # wrap a binary writer with TextIOWrapper
+    return _UnbufferedTextIOWrapper(buffer, encoding=encoding,
+                                    errors='xmlcharrefreplace',
+                                    newline='\n')
+
+
 class XMLGenerator(handler.ContentHandler):
 
     def __init__(self, out=None, encoding="iso-8859-1"):
-        if out is None:
-            import sys
-            out = sys.stdout
         handler.ContentHandler.__init__(self)
-        self._out = out
+        out = _gettextwriter(out, encoding)
+        self._write = out.write
+        self._flush = out.flush
         self._ns_contexts = [{}] # contains uri -> prefix dicts
         self._current_context = self._ns_contexts[-1]
         self._undeclared_ns_maps = []
         self._encoding = encoding
 
-    def _write(self, text):
-        if isinstance(text, str):
-            self._out.write(text)
-        else:
-            self._out.write(text.encode(self._encoding, _error_handling))
-
     def _qname(self, name):
         """Builds a qualified name from a (ns_url, localname) pair"""
         if name[0]:
@@ -121,9 +134,12 @@
# ContentHandler methods
def startDocument(self):
- self._write('<?xml version="1.0" encoding="%s"?>\n' %
+ self._write(u'<?xml version="1.0" encoding="%s"?>\n' %
self._encoding)
+ def endDocument(self):
+ self._flush()
+
def startPrefixMapping(self, prefix, uri):
self._ns_contexts.append(self._current_context.copy())
self._current_context[uri] = prefix
@@ -134,39 +150,46 @@
         del self._ns_contexts[-1]
 
     def startElement(self, name, attrs):
-        self._write('<' + name)
+        self._write(u'<' + name)
         for (name, value) in attrs.items():
-            self._write(' %s=%s' % (name, quoteattr(value)))
-        self._write('>')
+            self._write(u' %s=%s' % (name, quoteattr(value)))
+        self._write(u'>')
 
     def endElement(self, name):
-        self._write('</%s>' % name)
+        self._write(u'</%s>' % name)
 
     def startElementNS(self, name, qname, attrs):
-        self._write('<' + self._qname(name))
+        self._write(u'<' + self._qname(name))
 
         for prefix, uri in self._undeclared_ns_maps:
             if prefix:
-                self._out.write(' xmlns:%s="%s"' % (prefix, uri))
+                self._write(u' xmlns:%s="%s"' % (prefix, uri))
             else:
-                self._out.write(' xmlns="%s"' % uri)
+                self._write(u' xmlns="%s"' % uri)
         self._undeclared_ns_maps = []
 
         for (name, value) in attrs.items():
-            self._write(' %s=%s' % (self._qname(name), quoteattr(value)))
-        self._write('>')
+            self._write(u' %s=%s' % (self._qname(name), quoteattr(value)))
+        self._write(u'>')
 
     def endElementNS(self, name, qname):
-        self._write('</%s>' % self._qname(name))
+        self._write(u'</%s>' % self._qname(name))
 
     def characters(self, content):
-        self._write(escape(content))
+        if isinstance(content, str):
+            # A byte string is taken to be in the output encoding (such
+            # strings used to be written out unchanged); decode with it.
+            content = content.decode(self._encoding)
+        self._write(escape(unicode(content)))
 
     def ignorableWhitespace(self, content):
-        self._write(content)
+        if isinstance(content, str):
+            # Bytes are in the output encoding, as in characters().
+            content = content.decode(self._encoding)
+        self._write(unicode(content))
 
     def processingInstruction(self, target, data):
-        self._write('<?%s %s?>' % (target, data))
+        self._write(u'<?%s %s?>' % (target, data))
 
 
 class XMLFilterBase(xmlreader.XMLReader):