]> jfr.im git - yt-dlp.git/blob - youtube_dl/utils.py
Refactor IDParser to search for elements by any attribute not just ID
[yt-dlp.git] / youtube_dl / utils.py
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3
4 import gzip
5 import io
6 import json
7 import locale
8 import os
9 import re
10 import sys
11 import zlib
12 import email.utils
13 import json
14
15 try:
16 import urllib.request as compat_urllib_request
17 except ImportError: # Python 2
18 import urllib2 as compat_urllib_request
19
20 try:
21 import urllib.error as compat_urllib_error
22 except ImportError: # Python 2
23 import urllib2 as compat_urllib_error
24
25 try:
26 import urllib.parse as compat_urllib_parse
27 except ImportError: # Python 2
28 import urllib as compat_urllib_parse
29
30 try:
31 from urllib.parse import urlparse as compat_urllib_parse_urlparse
32 except ImportError: # Python 2
33 from urlparse import urlparse as compat_urllib_parse_urlparse
34
35 try:
36 import http.cookiejar as compat_cookiejar
37 except ImportError: # Python 2
38 import cookielib as compat_cookiejar
39
40 try:
41 import html.entities as compat_html_entities
42 except ImportError: # Python 2
43 import htmlentitydefs as compat_html_entities
44
45 try:
46 import html.parser as compat_html_parser
47 except ImportError: # Python 2
48 import HTMLParser as compat_html_parser
49
50 try:
51 import http.client as compat_http_client
52 except ImportError: # Python 2
53 import httplib as compat_http_client
54
55 try:
56 from subprocess import DEVNULL
57 compat_subprocess_get_DEVNULL = lambda: DEVNULL
58 except ImportError:
59 compat_subprocess_get_DEVNULL = lambda: open(os.path.devnull, 'w')
60
try:
    from urllib.parse import parse_qs as compat_parse_qs
except ImportError: # Python 2
    # HACK: The following is the correct parse_qs implementation from cpython 3's stdlib.
    # Python 2's version is apparently totally broken
    def _unquote(string, encoding='utf-8', errors='replace'):
        # Percent-decode `string`, decoding runs of %XX bytes with the given
        # `encoding`/`errors` (mirrors Python 3's urllib.parse.unquote).
        if string == '':
            return string
        res = string.split('%')
        if len(res) == 1:
            # No '%' present -- nothing to decode.
            return string
        if encoding is None:
            encoding = 'utf-8'
        if errors is None:
            errors = 'replace'
        # pct_sequence: contiguous sequence of percent-encoded bytes, decoded
        pct_sequence = b''
        string = res[0]
        for item in res[1:]:
            try:
                if not item:
                    raise ValueError
                # Python 2 only: str.decode('hex') converts the two hex
                # digits that followed '%' back into one raw byte.
                pct_sequence += item[:2].decode('hex')
                rest = item[2:]
                if not rest:
                    # This segment was just a single percent-encoded character.
                    # May be part of a sequence of code units, so delay decoding.
                    # (Stored in pct_sequence).
                    continue
            except ValueError:
                rest = '%' + item
            # Encountered non-percent-encoded characters. Flush the current
            # pct_sequence.
            string += pct_sequence.decode(encoding, errors) + rest
            pct_sequence = b''
        if pct_sequence:
            # Flush the final pct_sequence
            string += pct_sequence.decode(encoding, errors)
        return string

    def _parse_qsl(qs, keep_blank_values=False, strict_parsing=False,
                   encoding='utf-8', errors='replace'):
        # Split a query string into an ordered list of (name, value) pairs.
        # Mirrors Python 3's urllib.parse.parse_qsl.
        qs, _coerce_result = qs, unicode
        # Fields are historically separated by either '&' or ';'.
        pairs = [s2 for s1 in qs.split('&') for s2 in s1.split(';')]
        r = []
        for name_value in pairs:
            if not name_value and not strict_parsing:
                continue
            nv = name_value.split('=', 1)
            if len(nv) != 2:
                if strict_parsing:
                    raise ValueError("bad query field: %r" % (name_value,))
                # Handle case of a control-name with no equal sign
                if keep_blank_values:
                    nv.append('')
                else:
                    continue
            if len(nv[1]) or keep_blank_values:
                # '+' encodes a space in application/x-www-form-urlencoded data.
                name = nv[0].replace('+', ' ')
                name = _unquote(name, encoding=encoding, errors=errors)
                name = _coerce_result(name)
                value = nv[1].replace('+', ' ')
                value = _unquote(value, encoding=encoding, errors=errors)
                value = _coerce_result(value)
                r.append((name, value))
        return r

    def compat_parse_qs(qs, keep_blank_values=False, strict_parsing=False,
                        encoding='utf-8', errors='replace'):
        # Parse a query string into a dict mapping each field name to the
        # list of all values supplied for it, in order of appearance.
        parsed_result = {}
        pairs = _parse_qsl(qs, keep_blank_values, strict_parsing,
                           encoding=encoding, errors=errors)
        for name, value in pairs:
            if name in parsed_result:
                parsed_result[name].append(value)
            else:
                parsed_result[name] = [value]
        return parsed_result
139
# Unified text type: `unicode` on Python 2, `str` on Python 3.
try:
    compat_str = unicode # Python 2
except NameError:
    compat_str = str

# Code point -> one-character string: `unichr` on Python 2, `chr` on Python 3.
try:
    compat_chr = unichr # Python 2
except NameError:
    compat_chr = chr
149
# Default HTTP headers attached to every outgoing request (see
# YoutubeDLHandler.http_request); they imitate a desktop Firefox browser so
# sites serve the same content they would to a regular user.
std_headers = {
    'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64; rv:10.0) Gecko/20100101 Firefox/10.0',
    'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7',
    'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8',
    'Accept-Encoding': 'gzip, deflate',
    'Accept-Language': 'en-us,en;q=0.5',
}
def preferredencoding():
    """Get preferred encoding.

    Returns the best encoding scheme for the system, based on
    locale.getpreferredencoding() and some further tweaks.
    Falls back to 'UTF-8' when the reported encoding is unusable.
    """
    try:
        pref = locale.getpreferredencoding()
        # Sanity-check that the reported encoding can actually encode text;
        # some platforms report encodings that fail here.
        u'TEST'.encode(pref)
    except Exception:
        # Was a bare `except:`; narrowed so SystemExit/KeyboardInterrupt
        # are no longer swallowed.
        pref = 'UTF-8'

    return pref
170
if sys.version_info < (3,0):
    def compat_print(s):
        # Python 2: stdout expects bytes, so encode with the locale's
        # preferred encoding, turning unencodable characters into XML
        # character references instead of failing.
        print(s.encode(preferredencoding(), 'xmlcharrefreplace'))
else:
    def compat_print(s):
        # Python 3: print handles unicode natively; insist on text input.
        assert type(s) == type(u'')
        print(s)
178
# json.dump targets a byte stream on Python 2 but a character stream on
# Python 3, so the file open mode has to differ between the two.
if sys.version_info >= (3, 0):
    def write_json_file(obj, fn):
        """Serialize obj as JSON into the file named fn (text mode, UTF-8)."""
        with open(fn, 'w', encoding='utf-8') as out:
            json.dump(obj, out)
else:
    def write_json_file(obj, fn):
        """Serialize obj as JSON into the file named fn (binary mode)."""
        with open(fn, 'wb') as out:
            json.dump(obj, out)
189
190
def htmlentity_transform(matchobj):
    """Transforms an HTML entity to a character.

    This function receives a match object (group 1 = entity body without
    the '&'/';' delimiters) and is intended to be used with the re.sub()
    function. Unknown entities are returned in their literal '&name;' form.
    """
    entity = matchobj.group(1)

    # Known non-numeric HTML entity
    if entity in compat_html_entities.name2codepoint:
        return compat_chr(compat_html_entities.name2codepoint[entity])

    # Numeric character reference: "#123" (decimal) or "#x7F" (hex).
    # BUGFIX: the previous pattern was '#(x?\d+)', whose \d cannot match the
    # hex digits a-f, so references like "&#x2F;" were truncated or missed.
    mobj = re.match(u'(?u)#(x[0-9a-fA-F]+|[0-9]+)', entity)
    if mobj is not None:
        numstr = mobj.group(1)
        if numstr.startswith(u'x'):
            base = 16
            numstr = u'0%s' % numstr  # 'x2F' -> '0x2F', accepted by int()
        else:
            base = 10
        return compat_chr(int(numstr, base))

    # Unknown entity in name, return its literal representation
    return (u'&%s;' % entity)
215
# NOTE(review): monkeypatches HTMLParser's start-tag-end regex with a version
# backported from a newer CPython (the original comment only says "backport
# bugfix"; the exact upstream defect is not visible here). AttrParser below
# relies on this parser, so the patch must run before it is used.
compat_html_parser.locatestarttagend = re.compile(r"""<[a-zA-Z][-.a-zA-Z0-9:_]*(?:\s+(?:(?<=['"\s])[^\s/>][^\s/=>]*(?:\s*=+\s*(?:'[^']*'|"[^"]*"|(?!['"])[^>\s]*))?\s*)*)?\s*""", re.VERBOSE) # backport bugfix
class AttrParser(compat_html_parser.HTMLParser):
    """Modified HTMLParser that isolates a tag with the specified attribute"""
    def __init__(self, attribute, value):
        # Attribute name/value pair identifying the element to extract.
        self.attribute = attribute
        self.value = value
        self.result = None  # becomes [tag, startpos, endpos] once located
        self.started = False  # True while parsing inside the wanted element
        self.depth = {}  # per-tag-name count of currently open tags
        self.html = None  # full source document, needed by get_result()
        self.watch_startpos = False  # record content start on next event
        self.error_count = 0
        compat_html_parser.HTMLParser.__init__(self)

    def error(self, message):
        # Recover from malformed markup: unless we are already inside the
        # target element (or have failed repeatedly), drop the input up to
        # the current line and resume parsing from there.
        if self.error_count > 10 or self.started:
            raise compat_html_parser.HTMLParseError(message, self.getpos())
        self.rawdata = '\n'.join(self.html.split('\n')[self.getpos()[0]:]) # skip one line
        self.error_count += 1
        self.goahead(1)

    def loads(self, html):
        # Parse the whole document, keeping the source text for slicing later.
        self.html = html
        self.feed(html)
        self.close()

    def handle_starttag(self, tag, attrs):
        attrs = dict(attrs)
        if self.started:
            # Any event after the opening tag pins the content start position.
            self.find_startpos(None)
        if self.attribute in attrs and attrs[self.attribute] == self.value:
            self.result = [tag]
            self.started = True
            self.watch_startpos = True
        if self.started:
            if not tag in self.depth: self.depth[tag] = 0
            self.depth[tag] += 1

    def handle_endtag(self, tag):
        if self.started:
            if tag in self.depth: self.depth[tag] -= 1
            # The element is closed once the nesting count for its own tag
            # name drops back to zero.
            if self.depth[self.result[0]] == 0:
                self.started = False
                self.result.append(self.getpos())

    def find_startpos(self, x):
        """Needed to put the start position of the result (self.result[1])
        after the opening tag with the requested id"""
        if self.watch_startpos:
            self.watch_startpos = False
            self.result.append(self.getpos())
    # Any parser event may mark the start of the element's content.
    handle_entityref = handle_charref = handle_data = handle_comment = \
    handle_decl = handle_pi = unknown_decl = find_startpos

    def get_result(self):
        # Slice the element's inner text out of the saved source, or return
        # None if the element was never found or never closed.
        if self.result is None:
            return None
        if len(self.result) != 3:
            return None
        lines = self.html.split('\n')
        lines = lines[self.result[1][0]-1:self.result[2][0]]
        lines[0] = lines[0][self.result[1][1]:]
        if len(lines) == 1:
            # NOTE(review): for a single-line element both column offsets are
            # relative to the original line; after this slice the second,
            # unconditional slice below appears to be a no-op -- confirm.
            lines[-1] = lines[-1][:self.result[2][1]-self.result[1][1]]
        lines[-1] = lines[-1][:self.result[2][1]]
        return '\n'.join(lines).strip()
282
def get_element_by_id(id, html):
    """Convenience wrapper around get_element_by_attribute(): return the
    content of the tag whose id attribute equals *id* in *html*."""
    return get_element_by_attribute("id", id, html)
286
def get_element_by_attribute(attribute, value, html):
    """Return the content of the tag with the specified attribute in the passed HTML document"""
    finder = AttrParser(attribute, value)
    try:
        finder.loads(html)
    except compat_html_parser.HTMLParseError:
        # Parse failures past the point of interest don't matter; whatever
        # the parser captured before failing is still returned below.
        pass
    return finder.get_result()
295
296
def clean_html(html):
    """Clean an HTML snippet into a readable string"""
    # Newline vs <br />
    html = html.replace('\n', ' ')
    # Raw strings: '\s' inside a plain string literal is an invalid escape
    # sequence (DeprecationWarning, later SyntaxWarning, on Python 3);
    # the pattern text itself is unchanged.
    html = re.sub(r'\s*<\s*br\s*/?\s*>\s*', '\n', html)
    # Strip html tags
    html = re.sub(r'<.*?>', '', html)
    # Replace html entities
    html = unescapeHTML(html)
    return html
307
308
def sanitize_open(filename, open_mode):
    """Try to open the given filename, and slightly tweak it if this fails.

    Attempts to open the given filename. If this fails, it tries to change
    the filename slightly, step by step, until it's either able to open it
    or it fails and raises a final exception, like the standard open()
    function.

    It returns the tuple (stream, definitive_file_name).
    """
    try:
        if filename == u'-':
            # '-' means: write to standard output.
            if sys.platform == 'win32':
                import msvcrt
                # Switch stdout to binary mode so the data is not mangled
                # by CRLF newline translation on Windows.
                msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
            return (sys.stdout, filename)
        stream = open(encodeFilename(filename), open_mode)
        return (stream, filename)
    except (IOError, OSError) as err:
        # In case of error, try to remove win32 forbidden chars
        filename = re.sub(u'[/<>:"\\|\\\\?\\*]', u'#', filename)

        # An exception here should be caught in the caller
        stream = open(encodeFilename(filename), open_mode)
        return (stream, filename)
334
335
def timeconvert(timestr):
    """Convert an RFC 2822 date string into a Unix timestamp.

    Returns None when the string cannot be parsed.
    """
    parsed = email.utils.parsedate_tz(timestr)
    if parsed is None:
        return None
    return email.utils.mktime_tz(parsed)
343
def sanitize_filename(s, restricted=False, is_id=False):
    """Sanitizes a string so it could be used as part of a filename.
    If restricted is set, use a stricter subset of allowed characters.
    Set is_id if this is not an arbitrary string, but an ID that should be kept if possible
    """
    def replace_insane(char):
        code = ord(char)
        # Control characters, DEL and '?' are always dropped outright.
        if char == '?' or code < 32 or code == 127:
            return ''
        if char == '"':
            return '' if restricted else '\''
        if char == ':':
            return '_-' if restricted else ' -'
        if char in '\\/|*<>':
            return '_'
        # Restricted mode additionally replaces shell metacharacters,
        # whitespace and all non-ASCII characters.
        if restricted and (char.isspace() or char in '!&\'()[]{}$;`^,#' or code > 127):
            return '_'
        return char

    result = u''.join(replace_insane(c) for c in s)
    if not is_id:
        # Collapse runs of underscores and trim stray ones at the edges.
        while '__' in result:
            result = result.replace('__', '_')
        result = result.strip('_')
        # Common case of "Foreign band name - English song title"
        if restricted and result.startswith('-_'):
            result = result[2:]
        if not result:
            result = '_'
    return result
375
def orderedSet(iterable):
    """Return a list of the iterable's elements with duplicates removed,
    preserving first-seen order.

    Membership is tested with ==, so elements need not be hashable.
    """
    unique = []
    for item in iterable:
        if item in unique:
            continue
        unique.append(item)
    return unique
383
def unescapeHTML(s):
    """Apply htmlentity_transform to every '&...;' sequence in the text.

    @param s a string
    """
    assert type(s) == type(u'')

    return re.sub(u'(?u)&(.+?);', htmlentity_transform, s)
392
def encodeFilename(s):
    """Encode a unicode filename for the current platform's file APIs.

    @param s The name of the file
    """
    assert type(s) == type(u'')

    # Python 3 has a Unicode API
    if sys.version_info >= (3, 0):
        return s

    if sys.platform == 'win32' and sys.getwindowsversion()[0] >= 5:
        # Pass u'' directly to use Unicode APIs on Windows 2000 and up
        # (Detecting Windows NT 4 is tricky because 'major >= 4' would
        # match Windows 9x series as well. Besides, NT 4 is obsolete.)
        return s

    # Python 2 on other platforms: encode to the filesystem encoding.
    return s.encode(sys.getfilesystemencoding(), 'ignore')
411
class DownloadError(Exception):
    """Download Error exception.

    Thrown by FileDownloader objects that are not configured to continue
    on errors; carries the appropriate error message.
    """
420
421
class SameFileError(Exception):
    """Same File exception.

    Thrown by FileDownloader objects when they detect that multiple files
    would have to be downloaded to the same file on disk.
    """
429
430
class PostProcessingError(Exception):
    """Post Processing exception.

    Raised by a PostProcessor's .run() method to indicate an error in the
    postprocessing task.
    """
438
class MaxDownloadsReached(Exception):
    """Raised once the --max-downloads limit has been reached."""
442
443
class UnavailableVideoError(Exception):
    """Unavailable Format exception.

    Thrown when a video is requested in a format that is not available
    for that video.
    """
451
452
class ContentTooShortError(Exception):
    """Content Too Short exception.

    Raised by FileDownloader objects when a downloaded file is smaller
    than the size the server announced first, indicating the connection
    was probably interrupted.
    """
    # Both sizes are byte counts.
    downloaded = None
    expected = None

    def __init__(self, downloaded, expected):
        self.downloaded = downloaded
        self.expected = expected
467
468
class Trouble(Exception):
    """Trouble helper exception.

    An exception meant to be handled with FileDownloader.trouble.
    """
475
class YoutubeDLHandler(compat_urllib_request.HTTPHandler):
    """Handler for HTTP requests and responses.

    This class, when installed with an OpenerDirector, automatically adds
    the standard headers to every HTTP request and handles gzipped and
    deflated responses from web servers. If compression is to be avoided in
    a particular request, the original request in the program code only has
    to include the HTTP header "Youtubedl-No-Compression", which will be
    removed before making the real request.

    Part of this code was copied from:

    http://techknack.net/python-urllib2-handlers/

    Andrew Rowls, the author of that code, agreed to release it to the
    public domain.
    """

    @staticmethod
    def deflate(data):
        # Some servers send raw deflate data with no zlib header; try that
        # form first, then fall back to standard zlib-wrapped decompression.
        try:
            return zlib.decompress(data, -zlib.MAX_WBITS)
        except zlib.error:
            return zlib.decompress(data)

    @staticmethod
    def addinfourl_wrapper(stream, headers, url, code):
        # Build an addinfourl carrying the HTTP status code. Versions whose
        # addinfourl has getcode() accept the code as a constructor argument;
        # older ones need it assigned afterwards.
        if hasattr(compat_urllib_request.addinfourl, 'getcode'):
            return compat_urllib_request.addinfourl(stream, headers, url, code)
        ret = compat_urllib_request.addinfourl(stream, headers, url)
        ret.code = code
        return ret

    def http_request(self, req):
        # Force std_headers onto the outgoing request, replacing any
        # caller-set duplicates of the same header names.
        for h in std_headers:
            if h in req.headers:
                del req.headers[h]
            req.add_header(h, std_headers[h])
        # The marker header suppresses compression for this one request and
        # is stripped before the request actually goes out.
        if 'Youtubedl-no-compression' in req.headers:
            if 'Accept-encoding' in req.headers:
                del req.headers['Accept-encoding']
            del req.headers['Youtubedl-no-compression']
        return req

    def http_response(self, req, resp):
        # Transparently decompress the response body, preserving the
        # original headers, URL, status code and message.
        old_resp = resp
        # gzip
        if resp.headers.get('Content-encoding', '') == 'gzip':
            gz = gzip.GzipFile(fileobj=io.BytesIO(resp.read()), mode='r')
            resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
            resp.msg = old_resp.msg
        # deflate
        if resp.headers.get('Content-encoding', '') == 'deflate':
            gz = io.BytesIO(self.deflate(resp.read()))
            resp = self.addinfourl_wrapper(gz, old_resp.headers, old_resp.url, old_resp.code)
            resp.msg = old_resp.msg
        return resp

    # urllib routes HTTPS traffic through the same hooks.
    https_request = http_request
    https_response = http_response