This repository has been archived by the owner on Apr 14, 2022. It is now read-only.

Add support for async files as request body
sethmlarson authored and pquentin committed May 11, 2020
1 parent d2268e5 commit 23ce1f6
Showing 12 changed files with 354 additions and 173 deletions.
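
In practical terms, the commit lets a caller hand an async file object straight to Hip as a request body; its read() coroutines are awaited chunk by chunk while the request is streamed. A rough usage sketch under assumptions not shown in this diff (a trio-backed ahip.PoolManager, an async urlopen matching the signatures touched below, and default backend selection):

```python
import trio
import ahip


async def upload(path, url):
    pool = ahip.PoolManager()  # backend negotiation elided / assumed to default sensibly
    async with await trio.open_file(path, "rb") as body:
        # body.read() returns coroutines; the connection layer now awaits
        # them (via await_if_coro) while serialising the request.
        response = await pool.urlopen("PUT", url, body=body)
    return response.status


trio.run(upload, "payload.bin", "https://example.com/upload")
```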
13 changes: 12 additions & 1 deletion setup.py
@@ -15,6 +15,17 @@
setup(
version=version,
cmdclass={
"build_py": unasync.cmdclass_build_py(rules=[unasync.Rule("/ahip/", "/hip/")])
"build_py": unasync.cmdclass_build_py(
rules=[
unasync.Rule(
"/ahip/",
"/hip/",
additional_replacements={
"anext": "next",
"await_if_coro": "return_non_coro",
},
)
]
)
},
)
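
The new additional_replacements teach the build-time unasync pass, which generates the sync hip package from the async ahip sources, to also map the helper names introduced in this commit onto their sync spellings. Roughly, and as an assumption about the generated output (the generated files are not part of this commit):

```python
# Written once in the async tree (src/ahip/), e.g. in connection.py below:
datablock = await await_if_coro(readable.read(blocksize))
piece = await anext(remaining_pieces)

# What the generated sync tree (src/hip/) would contain after unasync strips
# await/async and applies the renames configured above:
datablock = return_non_coro(readable.read(blocksize))
piece = next(remaining_pieces)
```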
2 changes: 1 addition & 1 deletion src/ahip/_backends/anyio_backend.py
@@ -44,7 +44,7 @@ async def start_tls(self, server_hostname, ssl_context: SSLContext):
def getpeercert(self, binary_form=False):
return self._stream.getpeercert(binary_form=binary_form)

async def receive_some(self):
async def receive_some(self, read_timeout):
return await self._stream.receive_some(BUFSIZE)

async def send_and_receive_for_a_while(
147 changes: 74 additions & 73 deletions src/ahip/connection.py
@@ -16,7 +16,6 @@

import collections
import datetime
import itertools
import socket
import warnings

@@ -35,6 +34,7 @@
)
from .packages import six
from .util import ssl_ as ssl_util
from .util.unasync import await_if_coro, anext, ASYNC_MODE
from ._backends._common import LoopAbort
from ._backends._loader import load_backend, normalize_backend

@@ -44,24 +44,6 @@
ssl = None


def is_async_mode():
"""Tests if we're in the async part of the code or not"""

async def f():
"""Unasync transforms async functions in sync functions"""
return None

obj = f()
if obj is None:
return False
else:
obj.close() # prevent unawaited coroutine warning
return True


_ASYNC_MODE = is_async_mode()


# When it comes time to update this value as a part of regular maintenance
# (ie test_recent_date is failing) update it to ~6 months before the current date.
RECENT_DATE = datetime.date(2019, 1, 1)
@@ -106,17 +88,16 @@ def _stringify_headers(headers):
yield (name, value)


def _read_readable(readable):
async def _read_readable(readable):
# TODO: reconsider this block size
blocksize = 8192
while True:
datablock = readable.read(blocksize)
datablock = await await_if_coro(readable.read(blocksize))
if not datablock:
break
yield datablock


# XX this should return an async iterator
def _make_body_iterable(body):
"""
This function turns all possible body types that Hip supports into an
@@ -134,63 +115,83 @@ def _make_body_iterable(body):
is deliberate: users must make choices about the encoding of the data they
use.
"""
if body is None:
return []
elif isinstance(body, bytes):
return [body]
elif hasattr(body, "read"):
return _read_readable(body)
elif isinstance(body, collections.Iterable) and not isinstance(body, six.text_type):
return body
else:
raise InvalidBodyError("Unacceptable body type: %s" % type(body))

async def generator():
if body is None:
return
elif isinstance(body, bytes):
yield body
elif hasattr(body, "read"):
async for chunk in _read_readable(body):
yield chunk
elif isinstance(body, collections.Iterable) and not isinstance(
body, six.text_type
):
for chunk in body:
yield chunk
else:
raise InvalidBodyError("Unacceptable body type: %s" % type(body))

return generator().__aiter__()


# XX this should return an async iterator
def _request_bytes_iterable(request, state_machine):
"""
An iterable that serialises a set of bytes for the body.
"""

def all_pieces_iter():
h11_request = h11.Request(
method=request.method,
target=request.target,
headers=_stringify_headers(request.headers.items()),
async def generator():
h11_request = h11.Request(
method=request.method,
target=request.target,
headers=_stringify_headers(request.headers.items()),
)
yield state_machine.send(h11_request)

async for chunk in _make_body_iterable(request.body):
yield state_machine.send(h11.Data(data=chunk))

yield state_machine.send(h11.EndOfMessage())

return generator().__aiter__()

async def generator():

# Try to combine the header bytes + (first set of body bytes or end of
# message bytes) into one packet.
# As long as all_pieces_iter() yields at least two messages, this should
# never raise StopIteration.
remaining_pieces = all_pieces_iter()
first_packet_bytes = (await anext(remaining_pieces)) + (
await anext(remaining_pieces)
)
yield state_machine.send(h11_request)

for chunk in _make_body_iterable(request.body):
yield state_machine.send(h11.Data(data=chunk))

yield state_machine.send(h11.EndOfMessage())

# Try to combine the header bytes + (first set of body bytes or end of
# message bytes) into one packet.
# As long as all_pieces_iter() yields at least two messages, this should
# never raise StopIteration.
remaining_pieces = all_pieces_iter()
first_packet_bytes = next(remaining_pieces) + next(remaining_pieces)
all_pieces_combined_iter = itertools.chain([first_packet_bytes], remaining_pieces)

# We filter out any empty strings, because we don't want to call
# send(b""). You might think this is a no-op, so it shouldn't matter
# either way. But this isn't true. For example, if we're sending a request
# with Content-Length framing, we could have this sequence:
#
# - We send the last Data event.
# - The peer immediately sends its response and closes the socket.
# - We attempt to send the EndOfMessage event, which (b/c this request has
# Content-Length framing) is encoded as b"".
# - We call send(b"").
# - This triggers the kernel / SSL layer to discover that the socket is
# closed, so it raises an exception.
#
# It's easier to fix this once here instead of worrying about it in all
# the different backends.
for piece in all_pieces_combined_iter:
if piece:
yield piece

async def all_pieces_combined_iter():
yield first_packet_bytes
async for piece in remaining_pieces:
yield piece

# We filter out any empty strings, because we don't want to call
# send(b""). You might think this is a no-op, so it shouldn't matter
# either way. But this isn't true. For example, if we're sending a request
# with Content-Length framing, we could have this sequence:
#
# - We send the last Data event.
# - The peer immediately sends its response and closes the socket.
# - We attempt to send the EndOfMessage event, which (b/c this request has
# Content-Length framing) is encoded as b"".
# - We call send(b"").
# - This triggers the kernel / SSL layer to discover that the socket is
# closed, so it raises an exception.
#
# It's easier to fix this once here instead of worrying about it in all
# the different backends.
async for piece in all_pieces_combined_iter():
if piece:
yield piece

return generator().__aiter__()


def _response_from_h11(h11_response, body_object):
@@ -259,8 +260,8 @@ async def _start_http_request(request, state_machine, sock, read_timeout=None):

async def produce_bytes():
try:
return next(request_bytes_iterable)
except StopIteration:
return await anext(request_bytes_iterable)
except StopAsyncIteration:
# We successfully sent the whole body!
context["send_aborted"] = False
return None
@@ -346,7 +347,7 @@ def __init__(
):
self.is_verified = False
self.read_timeout = None
self._backend = load_backend(normalize_backend(backend, _ASYNC_MODE))
self._backend = load_backend(normalize_backend(backend, ASYNC_MODE))
self._host = host
self._port = port
self._socket_options = (
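
The net effect of the connection.py changes is that every body source (None, bytes, a sync iterable, a sync file, or an async file) reaches the send path in a single shape: an async iterator of byte chunks, drained with anext until StopAsyncIteration. A simplified sketch of that consumption pattern, with a placeholder send_bytes callable standing in for the backend socket write that the real produce_bytes loop drives:

```python
async def drain_request_bytes(request, state_machine, send_bytes):
    # _request_bytes_iterable (above) now always returns an async iterator.
    pieces = _request_bytes_iterable(request, state_machine)
    while True:
        try:
            piece = await anext(pieces)
        except StopAsyncIteration:
            break  # whole request, including EndOfMessage, has been serialised
        await send_bytes(piece)  # placeholder, not a real backend method
```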
2 changes: 1 addition & 1 deletion src/ahip/connectionpool.py
@@ -620,7 +620,7 @@ async def urlopen(

# Rewind body position, if needed. Record current position
# for future rewinds in the event of a redirect/retry.
body_pos = set_file_position(body, body_pos)
body_pos = await set_file_position(body, body_pos)

if body is not None:
_add_transport_headers(headers)
2 changes: 1 addition & 1 deletion src/ahip/poolmanager.py
@@ -323,7 +323,7 @@ async def urlopen(self, method, url, redirect=True, **kw):
# for future rewinds in the event of a redirect/retry.
body = kw.get("body")
body_pos = kw.get("body_pos")
kw["body_pos"] = set_file_position(body, body_pos)
kw["body_pos"] = await set_file_position(body, body_pos)

if "headers" not in kw:
kw["headers"] = self.headers.copy()
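
Both call sites now await set_file_position because recording and restoring the body position can itself require awaiting an async file's tell()/seek(). A hedged sketch of the rewind flow they rely on, with the module path assumed from this diff's file layout and send standing in for the actual request attempt:

```python
from ahip.util.request import rewind_body, set_file_position


async def send_with_one_retry(send, body):
    # Record where the body starts before the first attempt.
    body_pos = await set_file_position(body, None)
    try:
        return await send(body)
    except ConnectionError:
        # Rewind to the recorded position before re-sending after a
        # redirect/retry; works for sync and async file objects alike.
        await rewind_body(body, body_pos)
        return await send(body)
```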
11 changes: 6 additions & 5 deletions src/ahip/util/request.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from base64 import b64encode

from .unasync import await_if_coro
from ..packages.six import b, integer_types
from ..exceptions import UnrewindableBodyError

@@ -87,16 +88,16 @@ def make_headers(
return headers


def set_file_position(body, pos):
async def set_file_position(body, pos):
"""
If a position is provided, move file to that point.
Otherwise, we'll attempt to record a position for future use.
"""
if pos is not None:
rewind_body(body, pos)
await rewind_body(body, pos)
elif getattr(body, "tell", None) is not None:
try:
pos = body.tell()
pos = await await_if_coro(body.tell())
except (IOError, OSError):
# This differentiates from None, allowing us to catch
# a failed `tell()` later when trying to rewind the body.
@@ -105,7 +106,7 @@ def set_file_position(body, pos):
return pos


def rewind_body(body, body_pos):
async def rewind_body(body, body_pos):
"""
Attempt to rewind body to a certain position.
Primarily used for request redirects and retries.
Expand All @@ -119,7 +120,7 @@ def rewind_body(body, body_pos):
body_seek = getattr(body, "seek", None)
if body_seek is not None and isinstance(body_pos, integer_types):
try:
body_seek(body_pos)
await await_if_coro(body_seek(body_pos))
except (IOError, OSError):
raise UnrewindableBodyError(
"An error occurred when rewinding request body for redirect/retry."
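
The reason one implementation can serve both kinds of file object is await_if_coro, used above and defined in the new util/unasync.py below: a plain file's tell()/seek()/read() results pass straight through, while an async file's coroutines get awaited. A small self-contained illustration of the same idea (not from the repository):

```python
import inspect
import io

import trio


async def await_if_coro(x):
    # Same idea as the helper this commit adds to ahip.util.unasync.
    if inspect.iscoroutine(x):
        return await x
    return x


async def read_first_bytes(fileobj, n=5):
    # fileobj.read(n) is bytes for a sync file and a coroutine for an async
    # file; await_if_coro normalises both cases to bytes.
    return await await_if_coro(fileobj.read(n))


async def main():
    print(await read_first_bytes(io.BytesIO(b"hello world")))  # b'hello'
    await trio.Path("demo.bin").write_bytes(b"async data")
    async with await trio.open_file("demo.bin", "rb") as f:
        print(await read_first_bytes(f))                        # b'async'
    await trio.Path("demo.bin").unlink()


trio.run(main)
```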
40 changes: 40 additions & 0 deletions src/ahip/util/unasync.py
@@ -0,0 +1,40 @@
"""Set of utility functions for unasync that transform into sync counterparts cleanly"""

import inspect

_original_next = next


def is_async_mode():
"""Tests if we're in the async part of the code or not"""

async def f():
"""Unasync transforms async functions in sync functions"""
return None

obj = f()
if obj is None:
return False
else:
obj.close() # prevent unawaited coroutine warning
return True


ASYNC_MODE = is_async_mode()


async def anext(x):
return await x.__anext__()


async def await_if_coro(x):
if inspect.iscoroutine(x):
return await x
return x


next = _original_next


def return_non_coro(x):
return x
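
is_async_mode() works because unasync rewrites async def f() into a plain def f() when generating the sync tree: calling the original returns a coroutine object, while calling the rewritten version returns None. A standalone illustration of that detection trick (not from the repository):

```python
async def _async_probe():
    return None


def _sync_probe():
    # What the probe looks like after unasync strips the `async` keyword.
    return None


coro = _async_probe()
print(coro is None)           # False: calling an async def returns a coroutine
coro.close()                  # prevent the "never awaited" RuntimeWarning
print(_sync_probe() is None)  # True: the sync rewrite really returns None
```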
2 changes: 1 addition & 1 deletion test/conftest.py
@@ -14,7 +14,7 @@

# We support Python 3.6+ for async code
if sys.version_info[:2] < (3, 6):
collect_ignore_glob = ["async/*.py", "with_dummyserver/async/*.py"]
collect_ignore_glob = ["async/*.py", "with_dummyserver/async*/*.py"]

# The Python 3.8+ default loop on Windows breaks Tornado
@pytest.fixture(scope="session", autouse=True)