blob: b00f31b400c3f3afc17658951168618dc15df8b1 [file] [log] [blame]
Antoine Pitrou2dbc6e62015-04-11 00:31:01 +02001"""Internal classes used by the gzip, lzma and bz2 modules"""
2
3import io
4
5
# Number of bytes of *compressed* data pulled from the underlying file
# object per read() call.
BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE
7
8
class BaseStream(io.BufferedIOBase):
    """Mode-checking guards shared by the compressed-file wrappers.

    Every ``_check_*`` method returns None when the requested operation
    is allowed and raises otherwise, so callers can invoke it as a
    one-line precondition.
    """

    def _check_not_closed(self):
        """Raise ValueError if this stream has been closed."""
        if not self.closed:
            return
        raise ValueError("I/O operation on closed file")

    def _check_can_read(self):
        """Raise io.UnsupportedOperation unless the stream is readable."""
        if self.readable():
            return
        raise io.UnsupportedOperation("File not open for reading")

    def _check_can_write(self):
        """Raise io.UnsupportedOperation unless the stream is writable."""
        if self.writable():
            return
        raise io.UnsupportedOperation("File not open for writing")

    def _check_can_seek(self):
        """Raise io.UnsupportedOperation unless seeking is possible.

        Seeking needs the stream to be open for reading *and* the
        underlying file to be seekable; readability is checked first.
        """
        if not self.readable():
            reason = ("Seeking is only supported "
                      "on files open for reading")
        elif not self.seekable():
            reason = ("The underlying file object "
                      "does not support seeking")
        else:
            return
        raise io.UnsupportedOperation(reason)
31
32
class DecompressReader(io.RawIOBase):
    """Adapts the decompressor API to a RawIOBase reader API.

    *fp* supplies the compressed bytes.  *decomp_factory* is called with
    ``**decomp_args`` to build a decompressor object; a fresh one is built
    for each concatenated compressed stream and for every backwards seek.
    *trailing_error* is the exception type (or tuple of types) which, when
    raised while starting a new stream, marks the remaining input as
    trailing garbage to be ignored instead of an error.

    The decompressor objects are used through ``decompress(data,
    max_length)`` and the ``eof``, ``unused_data`` and ``needs_input``
    attributes.
    """

    def readable(self):
        return True

    def __init__(self, fp, decomp_factory, trailing_error=(), **decomp_args):
        self._fp = fp
        self._eof = False
        self._pos = 0  # Current offset in decompressed stream

        # Set to size of decompressed stream once it is known, for SEEK_END
        self._size = -1

        # Save the decompressor factory and arguments.
        # If the file contains multiple compressed streams, each
        # stream will need a separate decompressor object. A new decompressor
        # object is also needed when implementing a backwards seek().
        self._decomp_factory = decomp_factory
        self._decomp_args = decomp_args
        self._decompressor = self._decomp_factory(**self._decomp_args)

        # Exception class to catch from decompressor signifying invalid
        # trailing data to ignore
        self._trailing_error = trailing_error

    def close(self):
        """Drop the decompressor and mark the reader closed.

        The underlying *fp* is not closed here.
        """
        self._decompressor = None
        return super().close()

    def seekable(self):
        """Seekability follows the underlying compressed file object."""
        return self._fp.seekable()

    def readinto(self, b):
        """Decompress into the writable buffer *b*; return bytes written."""
        # Cast to unsigned bytes so len() counts bytes even when *b* is a
        # buffer with a wider item size.
        with memoryview(b) as view, view.cast("B") as byte_view:
            data = self.read(len(byte_view))
            byte_view[:len(data)] = data
        return len(data)

    def read(self, size=-1):
        """Return up to *size* decompressed bytes; ``b""`` means EOF.

        A negative *size* reads everything that remains (see readall()).

        Raises EOFError if the compressed input ends before the
        decompressor reports end-of-stream.
        """
        if size < 0:
            return self.readall()

        if not size or self._eof:
            return b""
        data = None  # Default if EOF is encountered
        # Depending on the input data, our call to the decompressor may not
        # return any data. In this case, try again after reading another block.
        while True:
            if self._decompressor.eof:
                rawblock = (self._decompressor.unused_data or
                            self._fp.read(BUFFER_SIZE))
                if not rawblock:
                    break
                # Continue to next stream.
                self._decompressor = self._decomp_factory(
                    **self._decomp_args)
                try:
                    data = self._decompressor.decompress(rawblock, size)
                except self._trailing_error:
                    # Trailing data isn't a valid compressed stream; ignore it.
                    break
            else:
                if self._decompressor.needs_input:
                    rawblock = self._fp.read(BUFFER_SIZE)
                    if not rawblock:
                        raise EOFError("Compressed file ended before the "
                                       "end-of-stream marker was reached")
                else:
                    rawblock = b""
                data = self._decompressor.decompress(rawblock, size)
            if data:
                break
        if not data:
            # Reaching EOF also tells us the total decompressed size.
            self._eof = True
            self._size = self._pos
            return b""
        self._pos += len(data)
        return data

    def readall(self):
        """Read and return all remaining decompressed data until EOF."""
        # The inherited RawIOBase.readall() calls read(DEFAULT_BUFFER_SIZE)
        # in a loop, which caps the output of every decompress() call at
        # one small buffer.  Asking for sys.maxsize lets each call
        # decompress all of the compressed input it has been fed.
        import sys

        chunks = []
        while True:
            data = self.read(sys.maxsize)
            if not data:
                break
            chunks.append(data)
        return b"".join(chunks)

    # Rewind the file to the beginning of the data stream.
    def _rewind(self):
        # Assumes the compressed data starts at offset 0 of self._fp.
        self._fp.seek(0)
        self._eof = False
        self._pos = 0
        self._decompressor = self._decomp_factory(**self._decomp_args)

    def seek(self, offset, whence=io.SEEK_SET):
        """Move to *offset* relative to *whence*; return the new position.

        Seeking is emulated by decompressing.  A backwards seek rewinds
        *fp* to 0 and decompresses forward again.  The first SEEK_END
        decompresses the rest of the data to learn the total size.
        Targets past the end stop at the end, and a negative target
        leaves the position at 0.
        """
        # Recalculate offset as an absolute file position.
        if whence == io.SEEK_SET:
            pass
        elif whence == io.SEEK_CUR:
            offset = self._pos + offset
        elif whence == io.SEEK_END:
            # Seeking relative to EOF - we need to know the file's size.
            if self._size < 0:
                while self.read(io.DEFAULT_BUFFER_SIZE):
                    pass
            offset = self._size + offset
        else:
            raise ValueError("Invalid value for whence: {}".format(whence))

        # Make it so that offset is the number of bytes to skip forward.
        if offset < self._pos:
            self._rewind()
        else:
            offset -= self._pos

        # Read and discard data until we reach the desired position.
        while offset > 0:
            data = self.read(min(io.DEFAULT_BUFFER_SIZE, offset))
            if not data:
                break
            offset -= len(data)

        return self._pos

    def tell(self):
        """Return the current file position."""
        return self._pos