Home-AssistantConfig/deps/responses.py
2016-10-11 16:42:06 +00:00

313 lines
9.0 KiB
Python

from __future__ import (
absolute_import, print_function, division, unicode_literals
)
import inspect
import json as json_module
import re
import six
from collections import namedtuple, Sequence, Sized
from functools import update_wrapper
from cookies import Cookies
from requests.utils import cookiejar_from_dict
from requests.exceptions import ConnectionError
from requests.sessions import REDIRECT_STATI
try:
from requests.packages.urllib3.response import HTTPResponse
except ImportError:
from urllib3.response import HTTPResponse
if six.PY2:
from urlparse import urlparse, parse_qsl
else:
from urllib.parse import urlparse, parse_qsl
if six.PY2:
try:
from six import cStringIO as BufferIO
except ImportError:
from six import StringIO as BufferIO
else:
from io import BytesIO as BufferIO
Call = namedtuple('Call', ['request', 'response'])
_wrapper_template = """\
def wrapper%(signature)s:
with responses:
return func%(funcargs)s
"""
def _is_string(s):
return isinstance(s, (six.string_types, six.text_type))
def _is_redirect(response):
try:
# 2.0.0 <= requests <= 2.2
return response.is_redirect
except AttributeError:
# requests > 2.2
return (
# use request.sessions conditional
response.status_code in REDIRECT_STATI and
'location' in response.headers
)
def get_wrapped(func, wrapper_template, evaldict):
# Preserve the argspec for the wrapped function so that testing
# tools such as pytest can continue to use their fixture injection.
args, a, kw, defaults = inspect.getargspec(func)
signature = inspect.formatargspec(args, a, kw, defaults)
is_bound_method = hasattr(func, '__self__')
if is_bound_method:
args = args[1:] # Omit 'self'
callargs = inspect.formatargspec(args, a, kw, None)
ctx = {'signature': signature, 'funcargs': callargs}
six.exec_(wrapper_template % ctx, evaldict)
wrapper = evaldict['wrapper']
update_wrapper(wrapper, func)
if is_bound_method:
wrapper = wrapper.__get__(func.__self__, type(func.__self__))
return wrapper
class CallList(Sequence, Sized):
def __init__(self):
self._calls = []
def __iter__(self):
return iter(self._calls)
def __len__(self):
return len(self._calls)
def __getitem__(self, idx):
return self._calls[idx]
def add(self, request, response):
self._calls.append(Call(request, response))
def reset(self):
self._calls = []
def _ensure_url_default_path(url, match_querystring):
if _is_string(url) and url.count('/') == 2:
if match_querystring:
return url.replace('?', '/?', 1)
else:
return url + '/'
return url
class RequestsMock(object):
DELETE = 'DELETE'
GET = 'GET'
HEAD = 'HEAD'
OPTIONS = 'OPTIONS'
PATCH = 'PATCH'
POST = 'POST'
PUT = 'PUT'
def __init__(self, assert_all_requests_are_fired=True):
self._calls = CallList()
self.reset()
self.assert_all_requests_are_fired = assert_all_requests_are_fired
def reset(self):
self._urls = []
self._calls.reset()
def add(self, method, url, body='', match_querystring=False,
status=200, adding_headers=None, stream=False,
content_type='text/plain', json=None):
# if we were passed a `json` argument,
# override the body and content_type
if json:
body = json_module.dumps(json)
content_type = 'application/json'
# ensure the url has a default path set if the url is a string
url = _ensure_url_default_path(url, match_querystring)
# body must be bytes
if isinstance(body, six.text_type):
body = body.encode('utf-8')
self._urls.append({
'url': url,
'method': method,
'body': body,
'content_type': content_type,
'match_querystring': match_querystring,
'status': status,
'adding_headers': adding_headers,
'stream': stream,
})
def add_callback(self, method, url, callback, match_querystring=False,
content_type='text/plain'):
# ensure the url has a default path set if the url is a string
# url = _ensure_url_default_path(url, match_querystring)
self._urls.append({
'url': url,
'method': method,
'callback': callback,
'content_type': content_type,
'match_querystring': match_querystring,
})
@property
def calls(self):
return self._calls
def __enter__(self):
self.start()
return self
def __exit__(self, *args):
self.stop()
self.reset()
def activate(self, func):
evaldict = {'responses': self, 'func': func}
return get_wrapped(func, _wrapper_template, evaldict)
def _find_match(self, request):
for match in self._urls:
if request.method != match['method']:
continue
if not self._has_url_match(match, request.url):
continue
break
else:
return None
if self.assert_all_requests_are_fired:
# for each found match remove the url from the stack
self._urls.remove(match)
return match
def _has_url_match(self, match, request_url):
url = match['url']
if _is_string(url):
if match['match_querystring']:
return self._has_strict_url_match(url, request_url)
else:
url_without_qs = request_url.split('?', 1)[0]
return url == url_without_qs
elif isinstance(url, re._pattern_type) and url.match(request_url):
return True
else:
return False
def _has_strict_url_match(self, url, other):
url_parsed = urlparse(url)
other_parsed = urlparse(other)
if url_parsed[:3] != other_parsed[:3]:
return False
url_qsl = sorted(parse_qsl(url_parsed.query))
other_qsl = sorted(parse_qsl(other_parsed.query))
return url_qsl == other_qsl
def _on_request(self, adapter, request, **kwargs):
match = self._find_match(request)
# TODO(dcramer): find the correct class for this
if match is None:
error_msg = 'Connection refused: {0} {1}'.format(request.method,
request.url)
response = ConnectionError(error_msg)
response.request = request
self._calls.add(request, response)
raise response
if 'body' in match and isinstance(match['body'], Exception):
self._calls.add(request, match['body'])
raise match['body']
headers = {
'Content-Type': match['content_type'],
}
if 'callback' in match: # use callback
status, r_headers, body = match['callback'](request)
if isinstance(body, six.text_type):
body = body.encode('utf-8')
body = BufferIO(body)
headers.update(r_headers)
elif 'body' in match:
if match['adding_headers']:
headers.update(match['adding_headers'])
status = match['status']
body = BufferIO(match['body'])
response = HTTPResponse(
status=status,
body=body,
headers=headers,
preload_content=False,
)
response = adapter.build_response(request, response)
if not match.get('stream'):
response.content # NOQA
try:
resp_cookies = Cookies.from_request(response.headers['set-cookie'])
response.cookies = cookiejar_from_dict(dict(
(v.name, v.value)
for _, v
in resp_cookies.items()
))
except (KeyError, TypeError):
pass
self._calls.add(request, response)
return response
def start(self):
try:
from unittest import mock
except ImportError:
import mock
def unbound_on_send(adapter, request, *a, **kwargs):
return self._on_request(adapter, request, *a, **kwargs)
self._patcher = mock.patch('requests.adapters.HTTPAdapter.send',
unbound_on_send)
self._patcher.start()
def stop(self):
self._patcher.stop()
if self.assert_all_requests_are_fired and self._urls:
raise AssertionError(
'Not all requests have been executed {0!r}'.format(
[(url['method'], url['url']) for url in self._urls]))
# expose default mock namespace
mock = _default_mock = RequestsMock(assert_all_requests_are_fired=False)
__all__ = []
for __attr in (a for a in dir(_default_mock) if not a.startswith('_')):
__all__.append(__attr)
globals()[__attr] = getattr(_default_mock, __attr)