"""Internal classes used by the gzip, lzma and bz2 modules"""

import io
import sys

BUFFER_SIZE = io.DEFAULT_BUFFER_SIZE  # Compressed data read chunk size
class BaseStream(io.BufferedIOBase):
    """Mode-checking helper functions.

    Shared base for the gzip/bz2/lzma file classes: each helper checks
    one aspect of the stream's state and raises if the requested
    operation is not currently permitted.
    """

    def _check_not_closed(self):
        # Raise ValueError if the stream has already been closed.
        if self.closed:
            raise ValueError("I/O operation on closed file")

    def _check_can_read(self):
        # Raise UnsupportedOperation unless the stream is readable.
        if not self.readable():
            raise io.UnsupportedOperation("File not open for reading")

    def _check_can_write(self):
        # Raise UnsupportedOperation unless the stream is writable.
        if not self.writable():
            raise io.UnsupportedOperation("File not open for writing")

    def _check_can_seek(self):
        # Seeking a compressed stream is emulated by (re-)reading, so it
        # requires both a readable stream and a seekable underlying file.
        if not self.readable():
            raise io.UnsupportedOperation("Seeking is only supported "
                                          "on files open for reading")
        if not self.seekable():
            raise io.UnsupportedOperation("The underlying file object "
                                          "does not support seeking")
31
32
class DecompressReader(io.RawIOBase):
    """Adapts the decompressor API to a RawIOBase reader API.

    Wraps a file object containing compressed data and a decompressor
    factory (e.g. bz2.BZ2Decompressor, lzma.LZMADecompressor) and exposes
    the decompressed byte stream through the usual read/seek/tell API.
    """

    def readable(self):
        # This adapter always produces a readable stream.
        return True

    def __init__(self, fp, decomp_factory, trailing_error=(), **decomp_args):
        """Create a reader over *fp*.

        fp: underlying file object containing the compressed data.
        decomp_factory: zero-state callable returning a new decompressor;
            called with **decomp_args.
        trailing_error: exception type(s) raised by the decompressor for
            invalid trailing data, which should be silently ignored.
        """
        self._fp = fp
        self._eof = False
        self._pos = 0  # Current offset in decompressed stream

        # Set to size of decompressed stream once it is known, for SEEK_END
        self._size = -1

        # Save the decompressor factory and arguments.
        # If the file contains multiple compressed streams, each
        # stream will need a separate decompressor object. A new decompressor
        # object is also needed when implementing a backwards seek().
        self._decomp_factory = decomp_factory
        self._decomp_args = decomp_args
        self._decompressor = self._decomp_factory(**self._decomp_args)

        # Exception class to catch from decompressor signifying invalid
        # trailing data to ignore
        self._trailing_error = trailing_error

    def close(self):
        # Drop the decompressor (and its internal buffers) eagerly.
        self._decompressor = None
        return super().close()

    def seekable(self):
        # We can seek only if the underlying file can.
        return self._fp.seekable()

    def readinto(self, b):
        """Read decompressed bytes into *b*; return the number written."""
        with memoryview(b) as view, view.cast("B") as byte_view:
            data = self.read(len(byte_view))
            byte_view[:len(data)] = data
        return len(data)

    def read(self, size=-1):
        """Read and return up to *size* decompressed bytes.

        A negative *size* reads to EOF.  Returns b"" at end of stream.
        Raises EOFError if the compressed data ends before the
        end-of-stream marker.
        """
        if size < 0:
            return self.readall()

        if not size or self._eof:
            return b""
        data = None  # Default if EOF is encountered
        # Depending on the input data, our call to the decompressor may not
        # return any data. In this case, try again after reading another block.
        while True:
            if self._decompressor.eof:
                # Current stream finished; leftover bytes may start a new
                # concatenated stream.
                rawblock = (self._decompressor.unused_data or
                            self._fp.read(io.DEFAULT_BUFFER_SIZE))
                if not rawblock:
                    break
                # Continue to next stream.
                self._decompressor = self._decomp_factory(
                    **self._decomp_args)
                try:
                    data = self._decompressor.decompress(rawblock, size)
                except self._trailing_error:
                    # Trailing data isn't a valid compressed stream; ignore it.
                    break
            else:
                if self._decompressor.needs_input:
                    rawblock = self._fp.read(io.DEFAULT_BUFFER_SIZE)
                    if not rawblock:
                        raise EOFError("Compressed file ended before the "
                                       "end-of-stream marker was reached")
                else:
                    rawblock = b""
                data = self._decompressor.decompress(rawblock, size)
            if data:
                break
        if not data:
            # Genuine end of stream: remember the total size for SEEK_END.
            self._eof = True
            self._size = self._pos
            return b""
        self._pos += len(data)
        return data

    def readall(self):
        """Read and return all remaining decompressed bytes."""
        chunks = []
        # sys.maxsize means the max length of output buffer is unlimited,
        # so that the whole input buffer can be decompressed within one
        # .decompress() call.
        while data := self.read(sys.maxsize):
            chunks.append(data)

        return b"".join(chunks)

    # Rewind the file to the beginning of the data stream.
    def _rewind(self):
        self._fp.seek(0)
        self._eof = False
        self._pos = 0
        self._decompressor = self._decomp_factory(**self._decomp_args)

    def seek(self, offset, whence=io.SEEK_SET):
        """Seek to *offset* relative to *whence*; return the new position.

        Backwards seeks are emulated by rewinding to the start and
        re-decompressing; forward seeks read and discard data.
        """
        # Recalculate offset as an absolute file position.
        if whence == io.SEEK_SET:
            pass
        elif whence == io.SEEK_CUR:
            offset = self._pos + offset
        elif whence == io.SEEK_END:
            # Seeking relative to EOF - we need to know the file's size.
            if self._size < 0:
                while self.read(io.DEFAULT_BUFFER_SIZE):
                    pass
            offset = self._size + offset
        else:
            raise ValueError("Invalid value for whence: {}".format(whence))

        # Make it so that offset is the number of bytes to skip forward.
        if offset < self._pos:
            self._rewind()
        else:
            offset -= self._pos

        # Read and discard data until we reach the desired position.
        while offset > 0:
            data = self.read(min(io.DEFAULT_BUFFER_SIZE, offset))
            if not data:
                break
            offset -= len(data)

        return self._pos

    def tell(self):
        """Return the current file position."""
        return self._pos