mirror of
https://github.com/CCOSTAN/Home-AssistantConfig.git
synced 2025-02-22 09:01:22 +00:00
313 lines
9.0 KiB
Python
313 lines
9.0 KiB
Python
from __future__ import (
|
|
absolute_import, print_function, division, unicode_literals
|
|
)
|
|
|
|
import inspect
|
|
import json as json_module
|
|
import re
|
|
import six
|
|
|
|
from collections import namedtuple, Sequence, Sized
|
|
from functools import update_wrapper
|
|
from cookies import Cookies
|
|
from requests.utils import cookiejar_from_dict
|
|
from requests.exceptions import ConnectionError
|
|
from requests.sessions import REDIRECT_STATI
|
|
|
|
try:
|
|
from requests.packages.urllib3.response import HTTPResponse
|
|
except ImportError:
|
|
from urllib3.response import HTTPResponse
|
|
|
|
if six.PY2:
|
|
from urlparse import urlparse, parse_qsl
|
|
else:
|
|
from urllib.parse import urlparse, parse_qsl
|
|
|
|
if six.PY2:
|
|
try:
|
|
from six import cStringIO as BufferIO
|
|
except ImportError:
|
|
from six import StringIO as BufferIO
|
|
else:
|
|
from io import BytesIO as BufferIO
|
|
|
|
|
|
# One recorded interaction: the request that was sent and what came back
# for it (a response object or the exception that was raised).
Call = namedtuple('Call', ('request', 'response'))

# Source template for the decorator wrapper built by ``get_wrapped``.
# ``%(signature)s`` and ``%(funcargs)s`` are filled in with the wrapped
# function's parameter list and the matching call arguments; the names
# ``responses`` and ``func`` are resolved from the namespace it is exec'd in.
_wrapper_template = (
    "def wrapper%(signature)s:\n"
    "    with responses:\n"
    "        return func%(funcargs)s\n"
)
|
|
|
|
|
|
def _is_string(s):
    """Return True when ``s`` is one of six's string or text types."""
    accepted_types = (six.string_types, six.text_type)
    return isinstance(s, accepted_types)
|
|
|
|
|
|
def _is_redirect(response):
|
|
try:
|
|
# 2.0.0 <= requests <= 2.2
|
|
return response.is_redirect
|
|
except AttributeError:
|
|
# requests > 2.2
|
|
return (
|
|
# use request.sessions conditional
|
|
response.status_code in REDIRECT_STATI and
|
|
'location' in response.headers
|
|
)
|
|
|
|
|
|
def get_wrapped(func, wrapper_template, evaldict):
    """Build a wrapper around ``func`` that keeps its argument names.

    The wrapper is generated from source so that its parameter list mirrors
    the one of ``func``; this lets testing tools such as pytest keep using
    their fixture injection on the decorated function.

    :param func: function or bound method to wrap.
    :param wrapper_template: source text defining a function called
        ``wrapper``; it must contain the ``%(signature)s`` and
        ``%(funcargs)s`` placeholders.
    :param evaldict: namespace the template is executed in. It is mutated:
        the generated ``wrapper`` and one private entry per default value
        are stored in it.
    :returns: the wrapper, re-bound to ``func.__self__`` when ``func`` is a
        bound method.
    """
    # ``inspect.getargspec`` no longer exists on recent Pythons; the first
    # four fields of ``getfullargspec`` carry the same information.
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    spec = getspec(func)
    args = list(spec[0])
    varargs, varkw = spec[1], spec[2]
    defaults = tuple(spec[3] or ())
    kwonlyargs = list(getattr(spec, 'kwonlyargs', None) or ())
    kwonlydefaults = getattr(spec, 'kwonlydefaults', None) or {}

    # Python 2 unbound methods expose ``__self__ = None``; only a real
    # instance/class binding counts as bound.
    is_bound_method = getattr(func, '__self__', None) is not None

    # Default values are handed to the generated code through the exec
    # namespace instead of their repr(), so arbitrary objects survive.
    first_default = len(args) - len(defaults)
    sig_parts = []
    for index, name in enumerate(args):
        if index < first_default:
            sig_parts.append(name)
            continue
        slot = '_responses_default_%d' % index
        evaldict[slot] = defaults[index - first_default]
        sig_parts.append('%s=%s' % (name, slot))

    # A bound method already carries its first argument: omit 'self'.
    call_parts = list(args[1:] if is_bound_method else args)

    if varargs:
        sig_parts.append('*' + varargs)
        call_parts.append('*' + varargs)
    elif kwonlyargs:
        # A bare ``*`` is required before keyword-only parameters.
        sig_parts.append('*')

    for name in kwonlyargs:
        if name in kwonlydefaults:
            slot = '_responses_kwdefault_%s' % name
            evaldict[slot] = kwonlydefaults[name]
            sig_parts.append('%s=%s' % (name, slot))
        else:
            sig_parts.append(name)
        call_parts.append('%s=%s' % (name, name))

    if varkw:
        sig_parts.append('**' + varkw)
        call_parts.append('**' + varkw)

    ctx = {
        'signature': '(%s)' % ', '.join(sig_parts),
        'funcargs': '(%s)' % ', '.join(call_parts),
    }
    # The call form with an explicit namespace is valid on Python 2 and 3.
    exec(wrapper_template % ctx, evaldict)

    wrapper = evaldict['wrapper']

    update_wrapper(wrapper, func)
    if is_bound_method:
        wrapper = wrapper.__get__(func.__self__, type(func.__self__))
    return wrapper
|
|
|
|
|
|
class CallList(Sequence, Sized):
    """Sequence of ``Call`` records, in the order they were added."""

    def __init__(self):
        self._calls = list()

    def __len__(self):
        """Number of recorded calls."""
        return len(self._calls)

    def __getitem__(self, idx):
        """Return the call(s) stored at ``idx`` (index or slice)."""
        return self._calls[idx]

    def __iter__(self):
        """Iterate over the recorded calls, oldest first."""
        return iter(self._calls)

    def reset(self):
        """Forget every recorded call by starting a fresh list."""
        self._calls = list()

    def add(self, request, response):
        """Append a new ``Call`` pairing ``request`` with ``response``."""
        entry = Call(request, response)
        self._calls.append(entry)
|
|
|
|
|
|
def _ensure_url_default_path(url, match_querystring):
    """Give a path-less string URL the default ``/`` path.

    Only string URLs are touched; anything else (for instance a compiled
    pattern) is returned unchanged. A URL is considered path-less when the
    part in front of the query string holds exactly the two slashes of
    ``scheme://``.

    :param url: URL string, or any other object used as a matcher.
    :param match_querystring: kept for interface compatibility. The slash is
        always inserted in front of the query string (if any), whatever this
        flag says, so ``http://host?a=1`` becomes ``http://host/?a=1`` and
        ``http://host`` becomes ``http://host/``.
    :returns: the normalised URL, or ``url`` itself when nothing applies.
    """
    if not _is_string(url):
        return url

    # Slashes inside the query string must not count as path separators.
    base, sep, query = url.partition('?')
    if base.count('/') != 2:
        return url
    return base + '/' + sep + query
|
|
|
|
|
|
class RequestsMock(object):
    """Registry of canned responses served in place of real HTTP traffic.

    While started, ``requests.adapters.HTTPAdapter.send`` is patched so that
    every outgoing request is answered from the registered URLs. Instances
    can be used as context managers or, through ``activate``, as decorators.
    """

    DELETE = 'DELETE'
    GET = 'GET'
    HEAD = 'HEAD'
    OPTIONS = 'OPTIONS'
    PATCH = 'PATCH'
    POST = 'POST'
    PUT = 'PUT'

    # Type of compiled regular expressions. Derived from an actual pattern
    # because the private ``re._pattern_type`` name is not available on
    # every Python version.
    _pattern_type = type(re.compile(''))

    def __init__(self, assert_all_requests_are_fired=True):
        """Create an empty mock.

        :param assert_all_requests_are_fired: when true, each registration
            is consumed by the first request it matches and ``stop`` raises
            ``AssertionError`` if registrations are left over.
        """
        self._calls = CallList()
        self.reset()
        self.assert_all_requests_are_fired = assert_all_requests_are_fired

    def reset(self):
        """Drop every registered URL and every recorded call."""
        self._urls = []
        self._calls.reset()

    def add(self, method, url, body='', match_querystring=False,
            status=200, adding_headers=None, stream=False,
            content_type='text/plain', json=None):
        """Register a fixed response for ``method`` + ``url``.

        :param method: HTTP method name, compared verbatim with the request.
        :param url: URL string or compiled pattern.
        :param body: response payload; text is encoded as UTF-8. An
            ``Exception`` instance is raised instead of being returned.
        :param match_querystring: compare query strings too (string URLs).
        :param status: status code of the response.
        :param adding_headers: extra headers merged into the response.
        :param stream: when false the response content is read eagerly.
        :param content_type: value of the ``Content-Type`` header.
        :param json: object serialised as the body; when given it overrides
            ``body`` and forces ``content_type`` to ``application/json``.
        """
        # if we were passed a `json` argument,
        # override the body and content_type.
        # Compared against None so that falsy payloads ({} / [] / 0) count.
        if json is not None:
            body = json_module.dumps(json)
            content_type = 'application/json'

        # ensure the url has a default path set if the url is a string
        url = _ensure_url_default_path(url, match_querystring)

        # body must be bytes
        if isinstance(body, six.text_type):
            body = body.encode('utf-8')

        self._urls.append({
            'url': url,
            'method': method,
            'body': body,
            'content_type': content_type,
            'match_querystring': match_querystring,
            'status': status,
            'adding_headers': adding_headers,
            'stream': stream,
        })

    def add_callback(self, method, url, callback, match_querystring=False,
                     content_type='text/plain'):
        """Register ``callback`` to compute the response for a URL.

        The callback receives the request and must return a
        ``(status, headers, body)`` triple.
        """
        # NOTE(review): unlike ``add``, the URL is stored exactly as given.
        # The default-path normalisation was commented out in the original
        # source; confirm the intent before enabling it here.
        self._urls.append({
            'url': url,
            'method': method,
            'callback': callback,
            'content_type': content_type,
            'match_querystring': match_querystring,
        })

    @property
    def calls(self):
        """The ``CallList`` of requests handled so far."""
        return self._calls

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *args):
        # Always clear the registrations, even when ``stop`` raises its
        # AssertionError; otherwise stale URLs leak into the next use.
        try:
            self.stop()
        finally:
            self.reset()

    def activate(self, func):
        """Decorate ``func`` so that it runs inside ``with self:``."""
        evaldict = {'responses': self, 'func': func}
        return get_wrapped(func, _wrapper_template, evaldict)

    def _find_match(self, request):
        """Return the first registration matching ``request``, or None."""
        for match in self._urls:
            if request.method != match['method']:
                continue
            if not self._has_url_match(match, request.url):
                continue

            if self.assert_all_requests_are_fired:
                # for each found match remove the url from the stack
                self._urls.remove(match)
            return match
        return None

    def _has_url_match(self, match, request_url):
        """Check ``request_url`` against the URL of one registration."""
        url = match['url']

        if _is_string(url):
            if match['match_querystring']:
                return self._has_strict_url_match(url, request_url)
            # Query string is ignored: compare everything before the '?'.
            url_without_qs = request_url.split('?', 1)[0]
            return url == url_without_qs

        if isinstance(url, self._pattern_type):
            return bool(url.match(request_url))
        return False

    def _has_strict_url_match(self, url, other):
        """Compare two URLs including their query parameters.

        Scheme, network location and path must be equal; query parameters
        are compared regardless of their order.
        """
        url_parsed = urlparse(url)
        other_parsed = urlparse(other)

        if url_parsed[:3] != other_parsed[:3]:
            return False

        url_qsl = sorted(parse_qsl(url_parsed.query))
        other_qsl = sorted(parse_qsl(other_parsed.query))
        return url_qsl == other_qsl

    def _on_request(self, adapter, request, **kwargs):
        """Answer ``request`` from the registrations.

        Every outcome is recorded in ``self._calls`` before it is returned
        or raised.

        :raises ConnectionError: when no registration matches.
        :raises Exception: the registered ``body`` when it is an exception.
        """
        match = self._find_match(request)
        # TODO(dcramer): find the correct class for this
        if match is None:
            error_msg = 'Connection refused: {0} {1}'.format(request.method,
                                                             request.url)
            response = ConnectionError(error_msg)
            response.request = request

            self._calls.add(request, response)
            raise response

        if 'body' in match and isinstance(match['body'], Exception):
            self._calls.add(request, match['body'])
            raise match['body']

        headers = {
            'Content-Type': match['content_type'],
        }

        if 'callback' in match:  # use callback
            status, r_headers, body = match['callback'](request)
            if isinstance(body, six.text_type):
                body = body.encode('utf-8')
            body = BufferIO(body)
            headers.update(r_headers)

        elif 'body' in match:
            if match['adding_headers']:
                headers.update(match['adding_headers'])
            status = match['status']
            body = BufferIO(match['body'])

        response = HTTPResponse(
            status=status,
            body=body,
            headers=headers,
            preload_content=False,
        )

        response = adapter.build_response(request, response)
        if not match.get('stream'):
            # Touching ``content`` reads the body now.
            response.content  # NOQA

        # Best effort: a missing or unusable Set-Cookie header simply
        # leaves the response cookies as they are.
        try:
            resp_cookies = Cookies.from_request(response.headers['set-cookie'])
            response.cookies = cookiejar_from_dict(dict(
                (v.name, v.value)
                for _, v
                in resp_cookies.items()
            ))
        except (KeyError, TypeError):
            pass

        self._calls.add(request, response)

        return response

    def start(self):
        """Patch ``requests.adapters.HTTPAdapter.send`` with this mock."""
        try:
            from unittest import mock
        except ImportError:
            import mock

        def unbound_on_send(adapter, request, *a, **kwargs):
            return self._on_request(adapter, request, *a, **kwargs)
        self._patcher = mock.patch('requests.adapters.HTTPAdapter.send',
                                   unbound_on_send)
        self._patcher.start()

    def stop(self):
        """Undo the patch installed by ``start``.

        :raises AssertionError: when ``assert_all_requests_are_fired`` is
            set and some registered URLs were never requested.
        """
        self._patcher.stop()
        if self.assert_all_requests_are_fired and self._urls:
            raise AssertionError(
                'Not all requests have been executed {0!r}'.format(
                    [(url['method'], url['url']) for url in self._urls]))
|
|
|
|
|
|
# Module-level convenience API: one shared RequestsMock whose public
# attributes are re-exported as module globals, so callers can use the
# module itself like a mock instance.
mock = _default_mock = RequestsMock(assert_all_requests_are_fired=False)
__all__ = []
for __attr in dir(_default_mock):
    if __attr.startswith('_'):
        # private and dunder names stay on the instance only
        continue
    __all__.append(__attr)
    globals()[__attr] = getattr(_default_mock, __attr)
|