diff --git a/asks/__init__.py b/asks/__init__.py index a51c36c..f5d485c 100644 --- a/asks/__init__.py +++ b/asks/__init__.py @@ -1,14 +1,15 @@ # pylint: disable=wildcard-import # pylint: disable=wrong-import-position -from .auth import * -from .base_funcs import * -from .sessions import * - +from typing import Any from warnings import warn +from .auth import * # noqa +from .base_funcs import * # noqa +from .sessions import * # noqa + -def init(library): +def init(library: Any) -> None: """ Unused. asks+anyio auto-detects your library. """ diff --git a/asks/auth.py b/asks/auth.py index f14139c..3a207b3 100644 --- a/asks/auth.py +++ b/asks/auth.py @@ -1,13 +1,11 @@ # pylint: disable=abstract-method -from abc import abstractmethod, ABCMeta - -import re - import base64 +import re +from abc import ABCMeta, abstractmethod from hashlib import md5 from random import choice from string import ascii_lowercase, digits - +from typing import Any __all__ = ["AuthBase", "PreResponseAuth", "PostResponseAuth", "BasicAuth", "DigestAuth"] @@ -20,7 +18,11 @@ class AuthBase(metaclass=ABCMeta): """ @abstractmethod - def __call__(self): + async def __call__( + self, + response_obj: "response_objects.Response", + req_obj: "request_object.RequestProcessor", + ) -> dict[str, str]: """Not Implemented""" @@ -37,7 +39,9 @@ class PostResponseAuth(AuthBase): Auth class for response dependant auth. """ - def __init__(self): + auth_attempted: bool + + def __init__(self) -> None: self.auth_attempted = False @@ -46,11 +50,15 @@ class BasicAuth(PreResponseAuth): Ye Olde Basic HTTP Auth. """ - def __init__(self, auth_info, encoding="utf-8"): + def __init__(self, auth_info: Any, encoding: str = "utf-8"): self.auth_info = auth_info self.encoding = encoding - async def __call__(self, _): + async def __call__( + self, + response_obj: "response_objects.Response", + req_obj: "request_object.RequestProcessor", + ) -> dict[str, str]: usrname, psword = [bytes(x, self.encoding) for x in self.auth_info] encoded_auth = str(base64.b64encode(usrname + b":" + psword), self.encoding) return {"Authorization": "Basic {}".format(encoded_auth)} @@ -69,7 +77,9 @@ class DigestAuth(PostResponseAuth): _HDR_VAL_PARSE = re.compile(r'\b(\w+)=(?:"([^\\"]+)"|(\S+))') - def __init__(self, auth_info, encoding="utf-8"): + domain_space: list[Any] + + def __init__(self, auth_info: Any, encoding: str = "utf-8") -> None: super().__init__() self.auth_info = auth_info self.encoding = encoding @@ -77,7 +87,11 @@ def __init__(self, auth_info, encoding="utf-8"): self.nonce = None self.nonce_count = 1 - async def __call__(self, response_obj, req_obj): + async def __call__( + self, + response_obj: "response_objects.Response", + req_obj: "request_object.RequestProcessor", + ) -> dict[str, str]: usrname, psword = [bytes(x, self.encoding) for x in self.auth_info] try: @@ -128,10 +142,14 @@ async def __call__(self, response_obj, req_obj): self.encoding, ) - bytes_path = bytes(req_obj.path, self.encoding) + bytes_path = bytes(req_obj.path or "", self.encoding) bytes_method = bytes(req_obj.method, self.encoding) try: if b"auth-int" in auth_dict["qop"].lower(): + + if isinstance(response_obj.raw, str): + raise ValueError("response_obj.raw is str when it shouldn't") + hashed_body = bytes( md5(response_obj.raw or b"").hexdigest(), self.encoding ) @@ -160,7 +178,7 @@ async def __call__(self, response_obj, req_obj): auth_dict["nonce"], bytes_nc, cnonce, - bytes(qop, self.encoding), + bytes(qop or "", self.encoding), ha2, ) ) @@ -190,3 +208,7 @@ async def __call__(self, 
response_obj, req_obj): ] ) return {"Authorization": "Digest {}".format(", ".join(response_items))} + + +from . import request_object # noqa +from . import response_objects # noqa diff --git a/asks/base_funcs.py b/asks/base_funcs.py index 3b599f6..df02dad 100644 --- a/asks/base_funcs.py +++ b/asks/base_funcs.py @@ -4,14 +4,15 @@ to the caller. """ from functools import partial +from typing import Any, Union +from .response_objects import Response, StreamResponse from .sessions import Session - __all__ = ["get", "head", "post", "put", "delete", "options", "patch", "request"] -async def request(method, uri, **kwargs): +async def request(method: str, uri: str, **kwargs: Any) -> Union[Response, StreamResponse]: """Base function for one time http requests. Args: diff --git a/asks/cookie_utils.py b/asks/cookie_utils.py index cfb2e64..aa28fac 100644 --- a/asks/cookie_utils.py +++ b/asks/cookie_utils.py @@ -1,7 +1,11 @@ __all__ = ["CookieTracker", "parse_cookies"] -from .response_objects import Cookie +from typing import Any, Optional, Union + +from .response_objects import BaseResponse, Cookie + +_CookiesToSend = dict[Optional[str], Optional[str]] class CookieTracker: @@ -10,21 +14,21 @@ class CookieTracker: the otherwise stateless general http method functions. """ - def __init__(self): - self.domain_dict = {} + def __init__(self) -> None: + self.domain_dict: dict[str, list[Cookie]] = {} - def get_additional_cookies(self, netloc, path): + def get_additional_cookies(self, netloc: str, path: str) -> _CookiesToSend: netloc = netloc.replace("://www.", "://", 1) return self._check_cookies(netloc + path) - def _store_cookies(self, response_obj): + def _store_cookies(self, response_obj: BaseResponse[Any]) -> None: for cookie in response_obj.cookies: try: self.domain_dict[cookie.host.lstrip()].append(cookie) except KeyError: self.domain_dict[cookie.host.lstrip()] = [cookie] - def _check_cookies(self, endpoint): + def _check_cookies(self, endpoint: str) -> _CookiesToSend: relevant_domains = [] domains = self.domain_dict.keys() @@ -37,22 +41,22 @@ def _check_cookies(self, endpoint): relevant_domains.append(check_domain) return self._get_cookies_to_send(relevant_domains) - def _get_cookies_to_send(self, domain_list): - cookies_to_go = {} + def _get_cookies_to_send(self, domain_list: list[str]) -> _CookiesToSend: + cookies_to_go: dict[Optional[str], Optional[str]] = {} for domain in domain_list: for cookie_obj in self.domain_dict[domain]: cookies_to_go[cookie_obj.name] = cookie_obj.value return cookies_to_go -def parse_cookies(response, host): +def parse_cookies(response: BaseResponse[Any], host: str) -> None: """ Sticks cookies to a response. """ cookie_pie = [] try: for cookie in response.headers["set-cookie"]: - cookie_jar = {} + cookie_jar: dict[str, Union[str, bool]] = {} name_val, *rest = cookie.split(";") name, value = name_val.split("=", 1) cookie_jar["name"] = name.strip() diff --git a/asks/errors.py b/asks/errors.py index c1439a2..d5a1bb9 100644 --- a/asks/errors.py +++ b/asks/errors.py @@ -2,6 +2,8 @@ Simple exceptions to be raised in case of errors. 
""" +from typing import Any + class AsksException(Exception): """ @@ -28,7 +30,12 @@ class BadHttpResponse(AsksException): class BadStatus(AsksException): - def __init__(self, err, response, status_code=500): + def __init__( + self, + err: Any, + response: "response_objects.BaseResponse[Any]", + status_code: int = 500, + ) -> None: super().__init__(err) self.response = response self.status_code = status_code @@ -42,3 +49,6 @@ class RequestTimeout(ConnectivityError): class ServerClosedConnectionError(ConnectivityError): pass + + +from . import response_objects # noqa diff --git a/asks/http_utils.py b/asks/http_utils.py index 9d8d56b..f060999 100644 --- a/asks/http_utils.py +++ b/asks/http_utils.py @@ -7,19 +7,23 @@ import codecs -from zlib import decompressobj, MAX_WBITS +from typing import Iterator, Optional +from zlib import MAX_WBITS, decompressobj from .utils import processor -def parse_content_encoding(content_encoding: str) -> [str]: +def parse_content_encoding(content_encoding: str) -> list[str]: compressions = [x.strip() for x in content_encoding.split(",")] return compressions @processor -def decompress(compressions, encoding=None): - data = b"" +def decompress( + compressions: list[str], encoding: Optional[str] = None +) -> Iterator[bytes]: + encoded: Optional[bytes] = b"" + decoded: bytes = b"" # https://tools.ietf.org/html/rfc7231 # "If one or more encodings have been applied to a representation, the # sender that applied the encodings MUST generate a Content-Encoding @@ -32,9 +36,11 @@ def decompress(compressions, encoding=None): if encoding: decompressors.append(make_decoder_shim(encoding)) while True: - data = yield data - for decompressor in decompressors: - data = decompressor.send(data) + encoded = yield decoded + if encoded is not None: + for decompressor in decompressors: + encoded = decompressor.send(encoded) + decoded = encoded # https://tools.ietf.org/html/rfc7230#section-4.2.1 - #section-4.2.3 @@ -47,20 +53,24 @@ def decompress(compressions, encoding=None): @processor -def decompress_one(compression): - data = b"" +def decompress_one(compression: str) -> Iterator[bytes]: + encoded: Optional[bytes] = b"" + decoded: bytes = b"" decompressor = decompressobj(wbits=DECOMPRESS_WBITS[compression]) while True: - data = yield data - data = decompressor.decompress(data) + encoded = yield decoded + if encoded is not None: + decoded = decompressor.decompress(encoded) yield decompressor.flush() @processor -def make_decoder_shim(encoding): - data = b"" +def make_decoder_shim(encoding: str) -> Iterator[str]: + encoded: Optional[bytes] = b"" + decoded: str = "" decoder = codecs.getincrementaldecoder(encoding)(errors="replace") while True: - data = yield data - data = decoder.decode(data) + encoded = yield decoded + if encoded is not None: + decoded = decoder.decode(encoded) yield decoder.decode(b"", final=True) diff --git a/asks/multipart.py b/asks/multipart.py index bfe5392..2dc4a1d 100644 --- a/asks/multipart.py +++ b/asks/multipart.py @@ -1,10 +1,8 @@ import mimetypes - -from typing import BinaryIO, NamedTuple, Union, Optional from pathlib import Path +from typing import Any, BinaryIO, NamedTuple, Optional, Union -from anyio import open_file, AsyncFile - +from anyio import AsyncFile, open_file _RAW_BYTES_MIMETYPE = "application/octet-stream" @@ -17,11 +15,11 @@ class MultipartData(NamedTuple): a field to be sent, and/or to send raw bytes as files. 
""" - binary_source: Union[Path, bytes, BinaryIO, AsyncFile] + binary_source: Union[Path, bytes, BinaryIO, AsyncFile[Any]] mime_type: Optional[str] = _RAW_BYTES_MIMETYPE basename: Optional[str] = None - async def to_bytes(self): + async def to_bytes(self) -> bytes: binary_source = self.binary_source if isinstance(binary_source, Path): @@ -39,10 +37,10 @@ async def to_bytes(self): return result # We must then assume it is a coroutine. - return await result + return bytes(await result) -def _to_multipart_file(value): +def _to_multipart_file(value: Any) -> MultipartData: """ Ensure a file-like supported type is encapsulated in a MultipartData object. @@ -61,10 +59,14 @@ def _to_multipart_file(value): else _RAW_BYTES_MIMETYPE ) - return MultipartData(binary_source=value, mime_type=mime_type, basename=basename,) + return MultipartData( + binary_source=value, + mime_type=mime_type, + basename=basename, + ) -def _to_multipart_form_data(value, encoding): +def _to_multipart_form_data(value: Any, encoding: str) -> MultipartData: """ Transform a form-data entry into a MultipartData object. @@ -86,11 +88,13 @@ def _to_multipart_form_data(value, encoding): # It's not a supported file type, so we do our best to transform it into form data. return MultipartData( - binary_source=str(value).encode(encoding), mime_type=None, basename=None, + binary_source=str(value).encode(encoding), + mime_type=None, + basename=None, ) -async def build_multipart_body(values, encoding, boundary_data): +async def build_multipart_body(values: Any, encoding: Any, boundary_data: Any) -> bytes: """ Forms a multipart request body from a dict of form fields to values. diff --git a/asks/py.typed b/asks/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/asks/req_structs.py b/asks/req_structs.py index b28834c..e155200 100644 --- a/asks/req_structs.py +++ b/asks/req_structs.py @@ -3,32 +3,48 @@ Some structures used throughout asks. """ from collections import OrderedDict, deque -from collections.abc import MutableMapping, Mapping +from collections.abc import Mapping, MutableMapping +from typing import Any, Generator, Iterable, Iterator, Optional, Protocol, cast -class SocketQ(deque): +class SocketLike(Protocol): + host: str + port: str + _active: bool + + async def receive(self, max_bytes: int = 65536) -> bytes: + ... + + async def aclose(self) -> None: + ... + + async def send(self, item: Optional[bytes]) -> None: + ... + + +class SocketQ(deque[SocketLike]): """ A funky little subclass of deque built for the session classes. Allows for connection pooling of sockets to remote hosts. """ - def index(self, host_loc): + def index(self, host_loc: object, _a: int = 0, _b: int = 0) -> int: try: return next(index for index, i in enumerate(self) if i.host == host_loc) except StopIteration: - raise ValueError("{} not in SocketQ".format(host_loc)) from None + raise ValueError("{!r} not in SocketQ".format(host_loc)) from None - def pull(self, index): + def pull(self, index: Any) -> SocketLike: x = self[index] del self[index] return x - async def free_pool(self): + async def free_pool(self) -> None: while self: sock = self.pop() await sock.aclose() - def __contains__(self, host_loc): + def __contains__(self, host_loc: object) -> bool: for i in self: if i.host == host_loc: return True @@ -47,7 +63,7 @@ def __contains__(self, host_loc): """ -class CaseInsensitiveDict(MutableMapping): +class CaseInsensitiveDict(MutableMapping[str, str]): """A case-insensitive ``dict``-like object. 
Implements all methods and operations of ``collections.MutableMapping`` as well as dict's ``copy``. Also @@ -69,34 +85,41 @@ class CaseInsensitiveDict(MutableMapping): behavior is undefined. """ - def __init__(self, data=None, **kwargs): - self._store = OrderedDict() + def __init__( + self, data: Optional[Iterable[tuple[str, str]]] = None, **kwargs: Any + ) -> None: + self._store: OrderedDict[str, tuple[str, str]] = OrderedDict() if data is None: data = {} self.update(data, **kwargs) - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: str) -> None: # Use the lowercased key for lookups, but store the actual # key alongside the value. self._store[key.lower()] = (key, value) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self._store[key.lower()][1] - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: del self._store[key.lower()] - def __iter__(self): - return (casedkey for casedkey, mappedvalue in self._store.values()) + def __iter__(self) -> Iterator[str]: + return ( + casedkey + for casedkey, mappedvalue in cast( + Iterable[tuple[str, tuple[str, str]]], self._store.values() + ) + ) - def __len__(self): + def __len__(self) -> int: return len(self._store) - def lower_items(self): + def lower_items(self) -> Generator[tuple[str, str], None, None]: """Like items(), but with all lowercase keys.""" return ((lowerkey, keyval[1]) for (lowerkey, keyval) in self._store.items()) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Mapping): other = CaseInsensitiveDict(other) else: @@ -105,8 +128,10 @@ def __eq__(self, other): return dict(self.lower_items()) == dict(other.lower_items()) # Copy is required - def copy(self): - return CaseInsensitiveDict(self._store.values()) + def copy(self) -> "CaseInsensitiveDict": + return CaseInsensitiveDict( + ((key, val[1]) for (key, val) in self._store.items()) + ) - def __repr__(self): + def __repr__(self) -> str: return str(dict(self.items())) diff --git a/asks/request_object.py b/asks/request_object.py index 79bcaa6..0a7a79b 100644 --- a/asks/request_object.py +++ b/asks/request_object.py @@ -8,37 +8,47 @@ __all__ = ["RequestProcessor"] -from numbers import Number -from os.path import basename -from urllib.parse import urljoin, urlparse, urlunparse, quote_plus import json as _json -from random import randint import mimetypes import re +from collections.abc import Awaitable, Callable +from numbers import Number +from os.path import basename +from random import randint +from typing import Any, Iterable, Optional, Protocol, Union, cast +from urllib.parse import quote_plus, urljoin, urlparse, urlunparse -from anyio import open_file, EndOfStream import h11 +from anyio import EndOfStream, open_file -from .utils import requote_uri +from .auth import PostResponseAuth, PreResponseAuth from .cookie_utils import parse_cookies -from .auth import PreResponseAuth, PostResponseAuth -from .req_structs import CaseInsensitiveDict as c_i_dict -from .response_objects import Response, StreamResponse, StreamBody from .errors import TooManyRedirects from .multipart import build_multipart_body +from .req_structs import CaseInsensitiveDict as c_i_dict +from .req_structs import SocketLike +from .response_objects import (Response, StreamBody, + StreamResponse) +from .utils import requote_uri - -_BOUNDARY = "8banana133744910kmmr13a56!102!" + str(randint(10 ** 3, 10 ** 9)) +_BOUNDARY = "8banana133744910kmmr13a56!102!" 
+ str(randint(10**3, 10**9)) _WWX_MATCH = re.compile(r"\Aww.\.") +class ResponseLike(Protocol): + status_code: int + reason: bytes + http_version: bytes + headers: list[tuple[bytes, bytes]] + + class RequestProcessor: """ Handles the building, formatting and i/o of requests once the calling session passes the required info and calls `make_request`. Args: - session (child of BaseSession): A reference to the calling session. + session (child of "session.BaseSession"): A reference to the calling session. method (str): The HTTP method to be used in the request. @@ -86,29 +96,36 @@ class RequestProcessor: socket object may be updated on `connection: close` headers. """ - def __init__(self, session, method, uri, port, **kwargs): + def __init__( + self, + session: "Optional[sessions.BaseSession]", + method: str, + uri: str, + port: Optional[str], + **kwargs: Any + ) -> None: # These are kwargsable attribs. self.session = session self.method = method.upper() self.uri = uri - self.port = port + self.port = port or "443" self.auth = None self.auth_off_domain = None - self.body = None + self.body: Optional[str] = None self.data = None self.params = None self.headers = None - self.encoding = None + self.encoding: str = "utf-8" self.json = None self.files = None self.multipart = None - self.cookies = {} - self.callback = None + self.cookies: dict[str, str] = {} + self.callback: Optional[Callable[[bytes], Awaitable[None]]] = None self.stream = None self.timeout = None self.max_redirects = 20 self.follow_redirects = True - self.sock = None + self.sock: Optional[SocketLike] = None self.persist_cookies = None self.mimetype = None @@ -117,21 +134,23 @@ def __init__(self, session, method, uri, port, **kwargs): self.__dict__.update(kwargs) # These are unkwargsable, and set by the code. - self.history_objects = [] - self.scheme = None - self.host = None - self.path = None - self.query = None - self.uri_parameters = None + self.history_objects: list[Union[Response, StreamResponse]] = [] + self.scheme: Optional[str] = None + self.host: Optional[str] = None + self.path: Optional[str] = None + self.query: Optional[str] = None + self.uri_parameters: Optional[str] = None self.target_netloc = None - self.req_url = None + self.req_url: Optional[str] = None - self.initial_scheme = None - self.initial_netloc = None + self.initial_scheme: Optional[str] = None + self.initial_netloc: Optional[str] = None self.streaming = False - async def make_request(self, redirect=False): + async def make_request( + self, redirect: bool = False + ) -> tuple[Optional[SocketLike], Union[Response, StreamResponse]]: """ Acts as the central hub for preparing requests to be sent, and returning them upon completion. Generally just pokes through @@ -183,7 +202,8 @@ async def make_request(self, redirect=False): # What the fuck is this shit. 
if self.persist_cookies is not None: self.cookies.update( - self.persist_cookies.get_additional_cookies(self.host, self.path) + self.persist_cookies.get_additional_cookies( + self.host, self.path) ) # formulate path / query and intended extra querys for use in uri @@ -191,10 +211,14 @@ async def make_request(self, redirect=False): # handle building the request body, if any body = "" - if any((self.data, self.files, self.json is not None, self.multipart is not None)): + if any( + (self.data, self.files, self.json is not None, self.multipart is not None) + ): content_type, content_len, body = await self._formulate_body() - asks_headers["Content-Type"] = content_type - asks_headers["Content-Length"] = content_len + if content_type: + asks_headers["Content-Type"] = content_type + if content_len: + asks_headers["Content-Length"] = content_len self.body = body # add custom headers, if any @@ -216,16 +240,19 @@ async def make_request(self, redirect=False): # Construct h11 body object, if any body. if body: - if not isinstance(body, bytes): - body = bytes(body, self.encoding) - asks_headers["Content-Length"] = str(len(body)) - req_body = h11.Data(data=body) + if isinstance(body, bytes): + body_bytes = body + else: + body_bytes = bytes(body, self.encoding) + asks_headers["Content-Length"] = str(len(body_bytes)) + req_body = h11.Data(data=body_bytes) else: req_body = None # Construct h11 request object. req = h11.Request( - method=self.method, target=self.path, headers=asks_headers.items() + method=self.method, target=self.path, headers=list( + asks_headers.items()) ) # call i/o handling func @@ -235,7 +262,7 @@ async def make_request(self, redirect=False): # to the calling session's connection pool. # We don't want to return sockets that are of a difference schema or # different top level domain, as they are less likely to be useful. - if redirect: + if redirect and self.sock: if not ( self.scheme == self.initial_scheme and self.host == self.initial_netloc ): @@ -245,13 +272,22 @@ async def make_request(self, redirect=False): if self.streaming: return None, response_obj - if asks_headers.get('connection', '') == 'close' and self.sock._active: + if ( + self.sock + and asks_headers.get("connection", "") == "close" + and self.sock._active + ): await self.sock.aclose() return None, response_obj return self.sock, response_obj - async def _request_io(self, h11_request, h11_body, h11_connection): + async def _request_io( + self, + h11_request: h11.Request, + h11_body: Optional[h11.Data], + h11_connection: h11.Connection, + ) -> Union[Response, StreamResponse]: """ Takes care of the i/o side of the request once it's been built, and calls a couple of cleanup functions to check for redirects / store @@ -273,7 +309,8 @@ async def _request_io(self, h11_request, h11_body, h11_connection): """ await self._send(h11_request, h11_body, h11_connection) response_obj = await self._catch_response(h11_connection) - parse_cookies(response_obj, self.host) + if self.host: + parse_cookies(response_obj, self.host) # If there's a cookie tracker object, store any cookies we # might've picked up along our travels. @@ -297,7 +334,7 @@ async def _request_io(self, h11_request, h11_body, h11_connection): return response_obj - def _build_path(self): + def _build_path(self) -> None: """ Constructs the actual request URL with accompanying query if any. 
@@ -332,7 +369,7 @@ def _build_path(self): (self.scheme, self.host, (self.path or ""), "", "", "") ) - async def _redirect(self, response_obj): + async def _redirect(self, response_obj: Union[Response, StreamResponse]) -> Union[Response, StreamResponse]: """ Calls the _check_redirect method of the supplied response object in order to determine if the http status code indicates a redirect. @@ -360,7 +397,7 @@ async def _redirect(self, response_obj): force_get = True location = response_obj.headers["Location"] - if redirect: + if redirect and location: allow_redirect = True location = urljoin(self.uri, location.strip()) if self.auth is not None: @@ -388,17 +425,19 @@ async def _redirect(self, response_obj): _, response_obj = await self.make_request() return response_obj - async def _get_new_sock(self): + async def _get_new_sock(self) -> None: """ On 'Connection: close' headers we've to create a new connection. This reaches in to the parent session and pulls a switcheroo, dunking the current connection and requesting a new one. """ - self.sock._active = False + if not self.session: + raise ValueError("session is none") self.sock = await self.session._grab_connection(self.uri) + self.sock._active = False self.port = self.sock.port - async def _formulate_body(self): + async def _formulate_body(self) -> tuple[Optional[str], str, str]: """ Takes user supplied data / files and forms it / them appropriately, returning the contents type, len, @@ -455,7 +494,9 @@ async def _formulate_body(self): return c_type, str(len(body)), body @staticmethod - def _dict_to_query(data, params=True, base_query=False): + def _dict_to_query( + data: dict[str, object], params: bool = True, base_query: bool = False + ) -> str: """ Turns python dicts in to valid body-queries or queries for use directly in the request url. Unlike the stdlib quote() and it's variations, @@ -478,6 +519,7 @@ def _dict_to_query(data, params=True, base_query=False): for key in v: query.append("=".join(quote_plus(x) for x in (k, key))) elif hasattr(v, "__iter__"): + v = cast(Iterable[object], v) for elm in v: query.append( "=".join( @@ -494,7 +536,7 @@ def _dict_to_query(data, params=True, base_query=False): return requote_uri("&".join(query)) - async def _multipart(self, files_dict): + async def _multipart(self, files_dict: dict[str, str]) -> bytes: """ Forms multipart requests from a dict with name, path k/vs. Name does not have to be the actual file name. @@ -504,10 +546,11 @@ async def _multipart(self, files_dict): as multipart files. Returns: - multip_pkg (str): The strings representation of the content body, + multip_pkg (bytes): The bytes representation of the content body, multipart formatted. 
""" boundary = bytes(_BOUNDARY, self.encoding) + boundary = bytes(_BOUNDARY, "utf-8") hder_format = 'Content-Disposition: form-data; name="{}"' hder_format_io = '; filename="{}"' @@ -525,17 +568,20 @@ async def _multipart(self, files_dict): hder_format.format(k) + hder_format_io.format(basename(v)), self.encoding, ) - mime_type = mimetypes.guess_type(basename(v)) - if not mime_type[1]: + mime_type_tuple = mimetypes.guess_type(basename(v)) + if not mime_type_tuple[1]: mime_type = "application/octet-stream" else: - mime_type = "/".join(mime_type) - multip_pkg += bytes("\r\nContent-Type: " + mime_type, self.encoding) + mime_type = "{}/{}".format( + mime_type_tuple[0], mime_type_tuple[1]) + multip_pkg += bytes("\r\nContent-Type: " + + mime_type, self.encoding) multip_pkg += b"\r\n" * 2 + pkg_body except (TypeError, FileNotFoundError): pkg_body = bytes(v, self.encoding) + b"\r\n" - multip_pkg += bytes(hder_format.format(k) + "\r\n" * 2, self.encoding) + multip_pkg += bytes(hder_format.format(k) + + "\r\n" * 2, self.encoding) multip_pkg += pkg_body if index == num_of_parts: @@ -543,11 +589,11 @@ async def _multipart(self, files_dict): return multip_pkg - async def _file_manager(self, path): + async def _file_manager(self, path: str) -> bytes: async with await open_file(path, "rb") as f: return b"".join(await f.readlines()) + b"\r\n" - async def _catch_response(self, h11_connection): + async def _catch_response(self, h11_connection: h11.Connection) -> Union[Response, StreamResponse]: """ Instantiates the parser which manages incoming data, first getting the headers, storing cookies, and then parsing the response's body, @@ -568,15 +614,15 @@ async def _catch_response(self, h11_connection): response = await self._recv_event(h11_connection) - resp_data = { + resp_data: dict[str, Any] = { "encoding": self.encoding, "method": self.method, "status_code": response.status_code, - "reason_phrase": str(response.reason, "utf-8"), - "http_version": str(response.http_version, "utf-8"), + "reason_phrase": response.reason.decode("utf-8"), + "http_version": response.http_version.decode("utf-8"), "headers": c_i_dict( [ - (str(name, "utf-8"), str(value, "utf-8")) + (name.decode("utf-8"), value.decode("utf-8")) for name, value in response.headers ] ), @@ -587,9 +633,11 @@ async def _catch_response(self, h11_connection): for header in response.headers: if header[0].lower() == b"set-cookie": try: - resp_data["headers"]["set-cookie"].append(str(header[1], "utf-8")) + resp_data["headers"]["set-cookie"].append( + str(header[1], "utf-8")) except (KeyError, AttributeError): - resp_data["headers"]["set-cookie"] = [str(header[1], "utf-8")] + resp_data["headers"]["set-cookie"] = [ + str(header[1], "utf-8")] # check whether we should receive body according to RFC 7230 # https://tools.ietf.org/html/rfc7230#section-3.3.3 @@ -602,7 +650,10 @@ async def _catch_response(self, h11_connection): if "chunked" in resp_data["headers"]["transfer-encoding"].lower(): get_body = True except KeyError: - connection_close = resp_data["headers"].get("connection", "").lower() == "close" + connection_close = ( + resp_data["headers"].get( + "connection", "").lower() == "close" + ) http_1 = response.http_version == b"1.0" if connection_close or http_1: get_body = True @@ -649,19 +700,25 @@ async def _catch_response(self, h11_connection): return Response(**resp_data) - async def _recv_event(self, h11_connection): + async def _recv_event(self, h11_connection: h11.Connection) -> ResponseLike: while True: event = h11_connection.next_event() if event 
is h11.NEED_DATA: try: - data = await self.sock.receive() + if self.sock: + data = await self.sock.receive() except EndOfStream: data = b"" h11_connection.receive_data(data) continue - return event - - async def _send(self, request_bytes, body_bytes, h11_connection): + return cast(ResponseLike, event) + + async def _send( + self, + request: h11.Request, + body: Optional[h11.Data], + h11_connection: h11.Connection, + ) -> None: """ Takes a package and body, combines then, then shoots 'em off in to the ether. @@ -670,15 +727,17 @@ async def _send(self, request_bytes, body_bytes, h11_connection): package (list of str): The header package. body (str): The str representation of the body. """ - await self.sock.send(h11_connection.send(request_bytes)) - if body_bytes is not None: - await self.sock.send(h11_connection.send(body_bytes)) + if not self.sock: + return + await self.sock.send(h11_connection.send(request)) + if body is not None: + await self.sock.send(h11_connection.send(body)) data = h11_connection.send(h11.EndOfMessage()) if data: await self.sock.send(data) - async def _auth_handler_pre(self): + async def _auth_handler_pre(self) -> dict[str, str]: """ If the user supplied auth does not rely on any response (is a PreResponseAuth object) then we call the auth's __call__ @@ -689,7 +748,7 @@ async def _auth_handler_pre(self): return await self.auth(self) return {} - async def _auth_handler_post_get_auth(self): + async def _auth_handler_post_get_auth(self) -> dict[str, str]: """ If the user supplied auth does rely on a response (is a PostResponseAuth object) then we call the auth's __call__ @@ -704,10 +763,12 @@ async def _auth_handler_post_get_auth(self): if authable_resp.status_code == 401: if not self.auth.auth_attempted: self.auth.auth_attempted = True - return await self.auth(authable_resp, self) + return await self.auth(cast(Response, authable_resp), self) return {} - async def _auth_handler_post_check_retry(self, response_obj): + async def _auth_handler_post_check_retry( + self, response_obj: Union[Response, StreamResponse] + ) -> Union[Response, StreamResponse]: """ The other half of _auth_handler_post_check_retry (what a mouthful). If auth has not yet been attempted and the most recent response @@ -728,7 +789,7 @@ async def _auth_handler_post_check_retry(self, response_obj): return response_obj return response_obj - async def _location_auth_protect(self, location): + async def _location_auth_protect(self, location: str) -> bool: """ Checks to see if the new location is 1. The same top level domain @@ -739,28 +800,38 @@ async def _location_auth_protect(self, location): and the connection type is equally or more secure. False otherwise. 
""" + if not self.host: + return False netloc_sans_port = self.host.split(":")[0] - netloc_sans_port = netloc_sans_port.replace( - (re.match(_WWX_MATCH, netloc_sans_port)[0]), "" - ) + match = re.match(_WWX_MATCH, netloc_sans_port) + + if not match: + return False + netloc_sans_port = netloc_sans_port.replace(match.groups()[0], "") base_domain = ".".join(netloc_sans_port.split(".")[-2:]) l_scheme, l_netloc, _, _, _, _ = urlparse(location) location_sans_port = l_netloc.split(":")[0] - location_sans_port = location_sans_port.replace( - (re.match(_WWX_MATCH, location_sans_port)[0]), "" - ) + match = re.match(_WWX_MATCH, location_sans_port) + + if not match: + return False + + location_sans_port = location_sans_port.replace(match.groups()[0], "") location_domain = ".".join(location_sans_port.split(".")[-2:]) if base_domain == location_domain: + if self.scheme is None: + return True if l_scheme < self.scheme: return False else: return True + return False - async def _body_callback(self, h11_connection): + async def _body_callback(self, h11_connection: h11.Connection) -> ResponseLike: """ A callback func to be supplied if the user wants to do something directly with the response body's stream. @@ -769,6 +840,10 @@ async def _body_callback(self, h11_connection): while True: next_event = await self._recv_event(h11_connection) if isinstance(next_event, h11.Data): - await self.callback(next_event.data) + if self.callback is not None: + await self.callback(next_event.data) else: return next_event + + +from . import sessions # noqa diff --git a/asks/response_objects.py b/asks/response_objects.py index 105c589..a285537 100644 --- a/asks/response_objects.py +++ b/asks/response_objects.py @@ -1,16 +1,21 @@ import codecs -from types import SimpleNamespace import json as _json +from types import SimpleNamespace +from typing import Any, Generic, Iterator, Optional, TypeVar, Union, cast -from async_generator import async_generator, yield_ import h11 +from async_generator import async_generator, yield_ +from .errors import BadStatus from .http_utils import decompress, parse_content_encoding +from .req_structs import SocketLike from .utils import timeout_manager -from .errors import BadStatus -class BaseResponse: +BodyType = TypeVar('BodyType') + + +class BaseResponse(Generic[BodyType]): """ A response object supporting a range of methods and attribs for accessing the status line, header, cookies, history and @@ -19,15 +24,15 @@ class BaseResponse: def __init__( self, - encoding, - http_version, - status_code, - reason_phrase, - headers, - body, - method, - url, - ): + encoding: Optional[str], + http_version: str, + status_code: int, + reason_phrase: str, + headers: dict[str, str], + body: BodyType, + method: str, + url: str, + ) -> None: self.encoding = encoding self.http_version = http_version self.status_code = status_code @@ -36,10 +41,10 @@ def __init__( self.body = body self.method = method self.url = url - self.history = [] - self.cookies = [] + self.history: list["Union[Response, StreamResponse]"] = [] + self.cookies: list["Cookie"] = [] - def raise_for_status(self): + def raise_for_status(self) -> None: """ Raise BadStatus if one occurred. 
""" @@ -60,12 +65,12 @@ def raise_for_status(self): self.status_code, ) - def __repr__(self): + def __repr__(self) -> str: return "<{} {} {}>".format( self.__class__.__name__, self.status_code, self.reason_phrase ) - def _guess_encoding(self): + def _guess_encoding(self) -> None: try: guess = self.headers["content-type"].split("=")[1] codecs.lookup(guess) @@ -73,29 +78,33 @@ def _guess_encoding(self): except LookupError: # IndexError/KeyError are LookupError subclasses pass - def _decompress(self, encoding=None): + def _decompress(self, encoding: Optional[str] = None) -> str: content_encoding = self.headers.get("Content-Encoding", None) if content_encoding is not None: decompressor = decompress( parse_content_encoding(content_encoding), encoding ) r = decompressor.send(self.body) - return r + return cast(str, r) else: if encoding is not None: + if not isinstance(self.body, bytes) and not isinstance(self.body, bytearray): + raise TypeError("body is not bytes when it should be") return self.body.decode(encoding, errors="replace") else: + if not isinstance(self.body, str): + raise TypeError("body is not str when it should be") return self.body - async def __aenter__(self): + async def __aenter__(self) -> "BaseResponse[BodyType]": return self - async def __aexit__(self, *exc_info): + async def __aexit__(self, *exc_info: Any) -> Any: ... -class Response(BaseResponse): - def json(self, **kwargs): +class Response(BaseResponse[Union[str, bytes, bytearray]]): + def json(self, **kwargs: Any) -> Any: """ If the response's body is valid json, we load it as a python dict and return it. @@ -104,57 +113,61 @@ def json(self, **kwargs): return _json.loads(body, **kwargs) @property - def text(self): + def text(self) -> Any: """ Returns the (maybe decompressed) decoded version of the body. """ return self._decompress(self.encoding) @property - def content(self): + def content(self) -> Any: """ Returns the content as-is after decompression, if any. """ return self._decompress() @property - def raw(self): + def raw(self) -> Union[str, bytes, bytearray]: """ Returns the response body as received. """ return self.body -class StreamResponse(BaseResponse): - ... 
- - class StreamBody: - def __init__(self, h11_connection, sock, content_encoding=None, encoding=None): + def __init__( + self, + h11_connection: h11.Connection, + sock: SocketLike, + content_encoding: Optional[str] = None, + encoding: Optional[str] = None, + ): self.h11_connection = h11_connection self.sock = sock self.content_encoding = content_encoding self.encoding = encoding # TODO: add decompress data to __call__ args self.decompress_data = True - self.timeout = None + self.timeout: Optional[float] = None self.read_size = 10000 @async_generator - async def __aiter__(self): + async def __aiter__(self) -> None: if self.content_encoding is not None: - decompressor = decompress(parse_content_encoding(self.content_encoding)) + decompressor = decompress( + parse_content_encoding(self.content_encoding)) while True: event = await self._recv_event() if isinstance(event, h11.Data): + data = event.data if self.content_encoding is not None: if self.decompress_data: - event.data = decompressor.send(event.data) - await yield_(event.data) + data = decompressor.send(event.data) + await yield_(data) elif isinstance(event, h11.EndOfMessage): break - async def _recv_event(self): + async def _recv_event(self) -> Any: while True: event = self.h11_connection.next_event() @@ -167,46 +180,49 @@ async def _recv_event(self): return event - def __call__(self, timeout=None): + def __call__(self, timeout: Optional[float] = None) -> "StreamBody": self.timeout = timeout return self - async def __aenter__(self): + async def __aenter__(self) -> "StreamBody": return self - async def close(self): + async def close(self) -> None: await self.sock.aclose() - async def __aexit__(self, *exc_info): + async def __aexit__(self, *exc_info: Any) -> None: await self.close() +class StreamResponse(BaseResponse[StreamBody]): + ... + + class Cookie(SimpleNamespace): """ A simple cookie object, for storing cookie stuff :) Needs to be made compatible with the API's cookies kw arg. 
""" - def __init__(self, host, data): - self.name = None - self.value = None - self.domain = None - self.path = None - self.secure = False - self.expires = None - self.comment = None + def __init__(self, host: str, data: dict[str, Any]) -> None: + self.name: Optional[str] = None + self.value: Optional[str] = None + self.domain: Optional[str] = None + self.path: Optional[str] = None + self.secure: bool = False + self.comment: Optional[str] = None self.__dict__.update(data) super().__init__(**self.__dict__) self.host = host - def __repr__(self): + def __repr__(self) -> str: if self.name is not None: return "".format(self.name, self.host) else: return "".format(self.value, self.host) - def __iter__(self): + def __iter__(self) -> Iterator[tuple[str, str]]: for k, v in self.__dict__.items(): yield k, v diff --git a/asks/sessions.py b/asks/sessions.py index 33765b4..b789c59 100644 --- a/asks/sessions.py +++ b/asks/sessions.py @@ -5,18 +5,20 @@ from abc import ABCMeta, abstractmethod from copy import copy from functools import partialmethod -from urllib.parse import urlparse, urlunparse, urljoin +from ssl import SSLContext +from typing import Any, Optional, Union, cast +from urllib.parse import urlparse, urlunparse +from anyio import Semaphore, connect_tcp from h11 import RemoteProtocolError -from anyio import connect_tcp, Semaphore from .cookie_utils import CookieTracker from .errors import BadHttpResponse -from .req_structs import SocketQ +from .req_structs import SocketLike, SocketQ from .request_object import RequestProcessor +from .response_objects import Response, StreamResponse from .utils import get_netloc_port, timeout_manager - __all__ = ["Session"] @@ -27,7 +29,11 @@ class BaseSession(metaclass=ABCMeta): socket to create, and all of the HTTP methods ('GET', 'POST', etc.) """ - def __init__(self, headers=None, ssl_context=None): + def __init__( + self, + headers: Optional[dict[str, str]] = None, + ssl_context: Optional[SSLContext] = None, + ): """ Args: headers (dict): Headers to be applied to all requests. @@ -41,56 +47,60 @@ def __init__(self, headers=None, ssl_context=None): self.headers = {} self.ssl_context = ssl_context - self.encoding = None - self.source_address = None - self._cookie_tracker = None + self.encoding: str = "utf-8" + self.source_address: Optional[str] = None + self._cookie_tracker: Optional[CookieTracker] = None @property @abstractmethod - def sema(self): + def sema(self) -> Semaphore: """ A semaphore-like context manager. """ ... - async def _open_connection_http(self, location): + async def _open_connection_http(self, location: tuple[str, int]) -> SocketLike: """ Creates a normal async socket, returns it. Args: location (tuple(str, int)): A tuple of net location (eg '127.0.0.1' or 'example.org') and port (eg 80 or 25000). """ - sock = await connect_tcp( - location[0], location[1], local_host=self.source_address + sock = cast( + SocketLike, + await connect_tcp(location[0], location[1], local_host=self.source_address), ) sock._active = True return sock - async def _open_connection_https(self, location): + async def _open_connection_https(self, location: tuple[str, int]) -> SocketLike: """ Creates an async SSL socket, returns it. Args: location (tuple(str, int)): A tuple of net location (eg '127.0.0.1' or 'example.org') and port (eg 80 or 25000). 
""" - sock = await connect_tcp( - location[0], - location[1], - ssl_context=self.ssl_context, - local_host=self.source_address, - tls=True, - tls_standard_compatible=False, + sock = cast( + SocketLike, + await connect_tcp( + location[0], + location[1], + ssl_context=self.ssl_context, + local_host=self.source_address, + tls=True, + tls_standard_compatible=False, + ), ) sock._active = True return sock - async def _connect(self, host_loc): + async def _connect(self, host_loc: str) -> tuple[SocketLike, str]: """ Simple enough stuff to figure out where we should connect, and creates the appropriate connection. """ parsed_hostloc = urlparse(host_loc) - scheme, host, path, parameters, query, fragment = parsed_hostloc + scheme, _, path, parameters, query, fragment = parsed_hostloc if parameters or query or fragment: raise TypeError( "Supplied info beyond scheme, host." @@ -98,15 +108,27 @@ async def _connect(self, host_loc): path, ) - host, port = get_netloc_port(parsed_hostloc) + host_opt, port = get_netloc_port(parsed_hostloc) + if host_opt is not None: + host = host_opt + else: + host = "host" + if scheme == "http": return await self._open_connection_http((host, int(port))), port else: return await self._open_connection_https((host, int(port))), port async def request( - self, method, url=None, *, path="", retries=1, connection_timeout=60, **kwargs - ): + self, + method: str, + url: Optional[str] = None, + *, + path: str = "", + retries: int = 1, + connection_timeout: int = 60, + **kwargs: Any + ) -> Union[Response, StreamResponse]: """ This is the template for all of the `http method` methods for the Session. @@ -267,7 +289,7 @@ async def request( method, url, path=path, retries=retries, headers=headers, **kwargs ) - return r + return cast(Response, r) # These be the actual http methods! # They are partial methods of `request`. See the `request` docstring @@ -280,7 +302,7 @@ async def request( options = partialmethod(request, "OPTIONS") patch = partialmethod(request, "PATCH") - async def _handle_exception(self, e, sock): + async def _handle_exception(self, e: Exception, sock: SocketLike) -> None: """ Given an exception, we want to handle it appropriately. Some exceptions we prefer to shadow with an asks exception, and some we want to raise directly. @@ -295,21 +317,21 @@ async def _handle_exception(self, e, sock): raise e @abstractmethod - def _make_url(self): + def _make_url(self, path: str) -> str: """ A method who's result is concated with a uri path. """ ... @abstractmethod - async def _grab_connection(self, url): + async def _grab_connection(self, url: str) -> SocketLike: """ A method that will return a socket-like object. """ ... @abstractmethod - async def return_to_pool(self, sock): + async def return_to_pool(self, sock: SocketLike) -> None: """ A method that will accept a socket-like object. """ @@ -326,14 +348,14 @@ class Session(BaseSession): def __init__( self, - base_location="", - endpoint="", - headers=None, - encoding="utf-8", - persist_cookies=None, - ssl_context=None, - connections=1, - ): + base_location: str = "", + endpoint: str = "", + headers: Optional[dict[str, str]] = None, + encoding: str = "utf-8", + persist_cookies: Optional[bool] = None, + ssl_context: Optional[SSLContext] = None, + connections: int = 1, + ) -> None: """ Args: encoding (str): The encoding asks'll try to use on response bodies. 
@@ -350,32 +372,33 @@ def __init__( self.endpoint = endpoint if persist_cookies is True: - self._cookie_tracker = CookieTracker() + cookie_tracker = CookieTracker() else: - self._cookie_tracker = persist_cookies + cookie_tracker = None + self._cookie_tracker: Optional[CookieTracker] = cookie_tracker self._conn_pool = SocketQ() - self._sema = None + self._sema: Optional[Semaphore] = None self._connections = connections @property - def base_location(self): + def base_location(self) -> str: return self._base_location @base_location.setter - def base_location(self, value): + def base_location(self, value: str) -> None: if not value: self._base_location = value else: self._base_location = self._normalise_last_slashes(value) @property - def endpoint(self): + def endpoint(self) -> str: return self._endpoint @endpoint.setter - def endpoint(self, value): + def endpoint(self, value: str) -> None: if not value: self._endpoint = value else: @@ -383,12 +406,12 @@ def endpoint(self, value): self._endpoint = self._normalise_last_slashes(value) @property - def sema(self): + def sema(self) -> Semaphore: if self._sema is None: self._sema = Semaphore(self._connections) return self._sema - def _checkout_connection(self, host_loc): + def _checkout_connection(self, host_loc: str) -> Optional[SocketLike]: try: index = self._conn_pool.index(host_loc) except ValueError: @@ -397,17 +420,17 @@ def _checkout_connection(self, host_loc): sock = self._conn_pool.pull(index) return sock - async def return_to_pool(self, sock): + async def return_to_pool(self, sock: SocketLike) -> None: if sock._active: self._conn_pool.appendleft(sock) - async def _make_connection(self, host_loc): + async def _make_connection(self, host_loc: str) -> SocketLike: sock, port = await self._connect(host_loc) sock.host, sock.port = host_loc, port return sock - async def _grab_connection(self, url): + async def _grab_connection(self, url: str) -> SocketLike: """ The connection pool handler. Returns a connection to the caller. If there are no connections ready, and @@ -423,7 +446,7 @@ async def _grab_connection(self, url): lying around. """ scheme, host, _, _, _, _ = urlparse(url) - host_loc = urlunparse((scheme, host, "", "", "", "")) + host_loc: str = urlunparse((scheme, host, "", "", "", "")) sock = self._checkout_connection(host_loc) @@ -432,7 +455,7 @@ async def _grab_connection(self, url): return sock - def _make_url(self, path): + def _make_url(self, path: str) -> str: """ Puts together the hostloc and current endpoint for use in request uri. 
""" @@ -446,24 +469,24 @@ def _make_url(self, path): return "".join((self.base_location, self.endpoint, path)) @staticmethod - def _normalise_last_slashes(url_segment): + def _normalise_last_slashes(url_segment: str) -> str: """ Drop any last /'s """ return url_segment if not url_segment.endswith("/") else url_segment[:-1] @staticmethod - def _normalise_head_slashes(url_segment): + def _normalise_head_slashes(url_segment: str) -> str: """ Add any missing head /'s """ return url_segment if url_segment.startswith("/") else "/" + url_segment - async def __aenter__(self): + async def __aenter__(self) -> "Session": return self - async def __aexit__(self, exc_type, exc_value, traceback): + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: await self.close() - async def close(self): + async def close(self) -> None: await self._conn_pool.free_pool() diff --git a/asks/utils.py b/asks/utils.py index 7e233c8..9d42d5d 100644 --- a/asks/utils.py +++ b/asks/utils.py @@ -1,15 +1,18 @@ __all__ = ["get_netloc_port", "requote_uri", "timeout_manager"] -from urllib.parse import quote from functools import wraps +from typing import Any, Callable, Optional +from urllib.parse import ParseResult, quote from anyio import fail_after from .errors import RequestTimeout -async def timeout_manager(timeout, coro, *args): +async def timeout_manager( + timeout: Optional[float], coro: Callable[..., Any], *args: Any +) -> Any: try: with fail_after(timeout): return await coro(*args) @@ -17,13 +20,13 @@ async def timeout_manager(timeout, coro, *args): raise RequestTimeout from e -def get_netloc_port(parsed_url): - port = parsed_url.port +def get_netloc_port(parsed_url: ParseResult) -> tuple[Optional[str], str]: + port: Optional[int] = parsed_url.port if not port: if parsed_url.scheme == "https": - port = "443" + port = 443 else: - port = "80" + port = 80 return parsed_url.hostname, str(port) @@ -33,7 +36,7 @@ def get_netloc_port(parsed_url): ) -def unquote_unreserved(uri): +def unquote_unreserved(uri: str) -> str: """Un-escape any percent-escape sequences in a URI that are unreserved characters. This leaves all reserved, illegal and non-ASCII bytes encoded. :rtype: str @@ -56,7 +59,7 @@ def unquote_unreserved(uri): return "".join(parts) -def requote_uri(uri): +def requote_uri(uri: str) -> str: """Re-quote the given URI. This function passes the given URI through an unquote/quote cycle to ensure that it is fully and consistently quoted. 
@@ -76,9 +79,9 @@ def requote_uri(uri): return quote(uri, safe=safe_without_percent) -def processor(gen): +def processor(gen: Callable[..., Any]) -> Callable[..., Any]: @wraps(gen) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: g = gen(*args, **kwargs) next(g) return g diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..157326b --- /dev/null +++ b/mypy.ini @@ -0,0 +1,3 @@ +[mypy] +strict = true +files = asks, tests diff --git a/setup.py b/setup.py index 274ea08..524df5c 100755 --- a/setup.py +++ b/setup.py @@ -16,6 +16,9 @@ author='Mark Jameson - aka theelous3', url='https://github.com/theelous3/asks', packages=['asks'], + package_data={ + 'asks': ['py.typed'], + }, python_requires='>= 3.6.2', install_requires=['h11', 'async_generator', 'anyio ~= 3.0'], tests_require=['pytest', 'curio', 'trio', 'overly'], diff --git a/tests/test_anyio.py b/tests/test_anyio.py index 9d87e7e..3157479 100644 --- a/tests/test_anyio.py +++ b/tests/test_anyio.py @@ -4,6 +4,7 @@ from os import path from functools import partial from pathlib import Path +from typing import Optional, cast import h11 import pytest @@ -30,6 +31,10 @@ import asks from asks.request_object import RequestProcessor from asks.errors import TooManyRedirects, BadStatus, RequestTimeout +from asks.response_objects import Response, StreamResponse +from asks.req_structs import SocketLike + +import _pytest pytestmark = pytest.mark.anyio @@ -38,7 +43,7 @@ @pytest.fixture -def server(request): +def server(request: _pytest.fixtures.SubRequest) -> Server: srv = Server(_TEST_LOC, **request.param) srv.daemon = True srv.start() @@ -49,7 +54,7 @@ def server(request): @pytest.mark.parametrize('server', [dict(steps=[send_200, finish])], indirect=True) -async def test_http_get(server): +async def test_http_get(server: Server) -> None: r = await asks.get(server.http_test_url) assert r.status_code == 200 @@ -60,7 +65,7 @@ async def test_http_get(server): @pytest.mark.parametrize('server', [ dict(steps=[send_200, finish], socket_wrapper=ssl_socket_wrapper) ], indirect=True) -async def test_https_get(server, caplog): +async def test_https_get(server: Server, caplog: pytest.LogCaptureFixture) -> None: import logging caplog.set_level(logging.DEBUG) # If we use ssl_context= to trust the CA, then we can successfully do a @@ -72,9 +77,9 @@ async def test_https_get(server, caplog): @pytest.mark.parametrize('server', [ dict(steps=[send_200, finish], socket_wrapper=ssl_socket_wrapper) ], indirect=True) -async def test_https_get_checks_cert(server): +async def test_https_get_checks_cert(server: Server) -> None: try: - expected_error = ssl.SSLCertVerificationError + expected_error: type = ssl.SSLCertVerificationError except AttributeError: # If we're running in Python <3.7, we won't have the specific error # that will be raised, but we can expect it to raise an SSLError @@ -93,7 +98,7 @@ async def test_https_get_checks_cert(server): @pytest.mark.parametrize('server', [dict(steps=[send_400, finish])], indirect=True) -async def test_http_get_client_error(server): +async def test_http_get_client_error(server: Server) -> None: r = await asks.get(server.http_test_url) with pytest.raises(BadStatus) as excinfo: r.raise_for_status() @@ -102,7 +107,7 @@ async def test_http_get_client_error(server): @pytest.mark.parametrize('server', [dict(steps=[send_500, finish])], indirect=True) -async def test_http_get_server_error(server): +async def test_http_get_server_error(server: Server) -> None: r = await asks.get(server.http_test_url) 
with pytest.raises(BadStatus) as excinfo: r.raise_for_status() @@ -125,12 +130,12 @@ async def test_http_get_server_error(server): ordered_steps=True, ) ], indirect=True) -async def test_http_redirect(server): +async def test_http_redirect(server: Server) -> None: r = await asks.get(server.http_test_url + "/redirect_1") assert len(r.history) == 1 # make sure history doesn't persist across responses - r.history.append("not a response obj") + r.history.append(cast(Response, "not a response obj")) r = await asks.get(server.http_test_url + "/redirect_1") assert len(r.history) == 1 @@ -152,7 +157,7 @@ async def test_http_redirect(server): ], ) ], indirect=True) -async def test_http_max_redirect_error(server): +async def test_http_max_redirect_error(server: Server) -> None: with pytest.raises(TooManyRedirects): await asks.get(server.http_test_url + "/redirect_max", max_redirects=1) @@ -170,7 +175,7 @@ async def test_http_max_redirect_error(server): ], ) ], indirect=True) -async def test_redirect_relative_url(server): +async def test_redirect_relative_url(server: Server) -> None: r = await asks.get(server.http_test_url + "/path/redirect", max_redirects=1) assert len(r.history) == 1 assert r.url == "http://{0}:{1}/foo/bar".format(*_TEST_LOC) @@ -189,7 +194,7 @@ async def test_redirect_relative_url(server): ], ) ], indirect=True) -async def test_http_under_max_redirect(server): +async def test_http_under_max_redirect(server: Server) -> None: r = await asks.get(server.http_test_url + "/redirect_once", max_redirects=2) assert r.status_code == 200 @@ -206,7 +211,7 @@ async def test_http_under_max_redirect(server): ], ) ], indirect=True) -async def test_dont_follow_redirects(server): +async def test_dont_follow_redirects(server: Server) -> None: r = await asks.get(server.http_test_url + "/redirect_once", follow_redirects=False) assert r.status_code == 303 assert r.headers["location"] == "/" @@ -215,13 +220,13 @@ async def test_dont_follow_redirects(server): @pytest.mark.parametrize('server', [dict(steps=[delay(2), send_200, finish])], indirect=True) -async def test_http_timeout_error(server): +async def test_http_timeout_error(server: Server) -> None: with pytest.raises(RequestTimeout): await asks.get(server.http_test_url, timeout=1) @pytest.mark.parametrize('server', [dict(steps=[send_200, finish])], indirect=True) -async def test_http_timeout(server): +async def test_http_timeout(server: Server) -> None: r = await asks.get(server.http_test_url, timeout=10) assert r.status_code == 200 @@ -230,8 +235,10 @@ async def test_http_timeout(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_param_dict_set(server): +async def test_param_dict_set(server: Server) -> None: r = await asks.get(server.http_test_url, params={"cheese": "the best"}) + if not isinstance(r, Response): + raise TypeError("expected Response") j = r.json() assert next(v == "the best" for k, v in j["params"] if k == "cheese") @@ -240,8 +247,10 @@ async def test_param_dict_set(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_data_dict_set(server): +async def test_data_dict_set(server: Server) -> None: r = await asks.post(server.http_test_url, data={"cheese": "please bby"}) + if not isinstance(r, Response): + raise TypeError("expected Response") j = r.json() assert next(v == "please bby" for k, v in j["form"] if k == "cheese") @@ -252,7 +261,7 @@ async def test_data_dict_set(server): 
@pytest.mark.parametrize('server', [ dict(steps=[accept_cookies_and_respond, finish]) ], indirect=True) -async def test_cookie_dict_send(server): +async def test_cookie_dict_send(server: Server) -> None: cookies = {"Test-Cookie": "Test Cookie Value", "koooookie": "pie"} @@ -260,7 +269,7 @@ async def test_cookie_dict_send(server): for cookie in r.cookies: assert cookie.name in cookies - if " " in cookie.value: + if cookie.value and " " in cookie.value: assert cookie.value == '"' + cookies[cookie.name] + '"' else: assert cookie.value == cookies[cookie.name] @@ -270,10 +279,12 @@ async def test_cookie_dict_send(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_header_set(server): +async def test_header_set(server: Server) -> None: r = await asks.get( server.http_test_url, headers={"Asks-Header": "Test Header Value"} ) + if not isinstance(r, Response): + raise TypeError("expected Response") j = r.json() assert any(k == "asks-header" for k, _ in j["headers"]) @@ -289,8 +300,11 @@ async def test_header_set(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_file_send_single(server): +async def test_file_send_single(server: Server) -> None: r = await asks.post(server.http_test_url, files={"file_1": TEST_FILE1}) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() assert any(file_data["name"] == "file_1" for file_data in j["files"]) @@ -302,10 +316,14 @@ async def test_file_send_single(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_file_send_double(server): +async def test_file_send_double(server: Server) -> None: r = await asks.post( - server.http_test_url, files={"file_1": TEST_FILE1, "file_2": TEST_FILE2} + server.http_test_url, files={ + "file_1": TEST_FILE1, "file_2": TEST_FILE2} ) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() assert any(file_data["name"] == "file_1" for file_data in j["files"]) @@ -322,11 +340,14 @@ async def test_file_send_double(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_file_send_file_and_form_data(server): +async def test_file_send_file_and_form_data(server: Server) -> None: r = await asks.post( server.http_test_url, files={"file_1": TEST_FILE1, "data_1": "watwatwatwat=yesyesyes"}, ) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() assert any(file_data["name"] == "file_1" for file_data in j["files"]) @@ -352,8 +373,11 @@ async def test_file_send_file_and_form_data(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_multipart_send_single(server): +async def test_multipart_send_single(server: Server) -> None: r = await asks.post(server.http_test_url, multipart={"file_1": Path(TEST_FILE1)}) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() assert any(file_data["name"] == "file_1" for file_data in j["files"]) @@ -365,9 +389,12 @@ async def test_multipart_send_single(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_multipart_send_single_already_open(server): +async def test_multipart_send_single_already_open(server: Server) -> None: with open(TEST_FILE1, "rb") as f: r = await 
asks.post(server.http_test_url, multipart={"file_1": f}) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() assert any(file_data["name"] == "file_1" for file_data in j["files"]) @@ -379,9 +406,12 @@ async def test_multipart_send_single_already_open(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_multipart_send_single_already_open_async(server): +async def test_multipart_send_single_already_open_async(server: Server) -> None: async with await open_file(TEST_FILE1, "rb") as f: r = await asks.post(server.http_test_url, multipart={"file_1": f}) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() assert any(file_data["name"] == "file_1" for file_data in j["files"]) @@ -393,7 +423,7 @@ async def test_multipart_send_single_already_open_async(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_multipart_send_raw_bytes(server): +async def test_multipart_send_raw_bytes(server: Server) -> None: r = await asks.post( server.http_test_url, multipart={ @@ -402,6 +432,9 @@ async def test_multipart_send_raw_bytes(server): ) }, ) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() assert any(file_data["name"] == "file_1" for file_data in j["files"]) @@ -413,11 +446,14 @@ async def test_multipart_send_raw_bytes(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_multipart_send_double(server): +async def test_multipart_send_double(server: Server) -> None: r = await asks.post( server.http_test_url, multipart={"file_1": Path(TEST_FILE1), "file_2": Path(TEST_FILE2)}, ) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() assert any(file_data["name"] == "file_1" for file_data in j["files"]) @@ -434,11 +470,15 @@ async def test_multipart_send_double(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_multipart_send_file_and_form_data(server): +async def test_multipart_send_file_and_form_data(server: Server) -> None: r = await asks.post( server.http_test_url, - multipart={"file_1": Path(TEST_FILE1), "data_1": "watwatwatwat=yesyesyes"}, + multipart={"file_1": Path(TEST_FILE1), + "data_1": "watwatwatwat=yesyesyes"}, ) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() assert any(file_data["name"] == "file_1" for file_data in j["files"]) @@ -459,10 +499,13 @@ async def test_multipart_send_file_and_form_data(server): @pytest.mark.parametrize('server', [dict(steps=[send_request_as_json, finish])], indirect=True) -async def test_json_send(server): +async def test_json_send(server: Server) -> None: r = await asks.post( server.http_test_url, json={"key_1": True, "key_2": "cheesestring"} ) + if not isinstance(r, Response): + raise TypeError("expected Response") + j = r.json() json_1 = next(iter(j["json"])) @@ -477,16 +520,22 @@ async def test_json_send(server): @pytest.mark.parametrize('server', [ dict(steps=[partial(send_gzip, data="wolowolowolo"), finish]) ], indirect=True) -async def test_gzip(server): +async def test_gzip(server: Server) -> None: r = await asks.get(server.http_test_url) + if not isinstance(r, Response): + raise TypeError("expected Response") + assert r.text == "wolowolowolo" @pytest.mark.parametrize('server', [ dict(steps=[partial(send_deflate, 
data="wolowolowolo"), finish]) ], indirect=True) -async def test_deflate(server): +async def test_deflate(server: Server) -> None: r = await asks.get(server.http_test_url) + if not isinstance(r, Response): + raise TypeError("expected Response") + assert r.text == "wolowolowolo" @@ -496,17 +545,23 @@ async def test_deflate(server): @pytest.mark.parametrize('server', [ dict(steps=[partial(send_chunked, data=["ham "] * 10), finish]) ], indirect=True) -async def test_chunked(server): +async def test_chunked(server: Server) -> None: r = await asks.get(server.http_test_url) + if not isinstance(r, Response): + raise TypeError("expected Response") + assert r.text == "ham ham ham ham ham ham ham ham ham ham " @pytest.mark.parametrize('server', [ dict(steps=[partial(send_chunked, data=["ham "] * 10), finish]) ], indirect=True) -async def test_stream(server): +async def test_stream(server: Server) -> None: data = b"" r = await asks.get(server.http_test_url, stream=True) + if not isinstance(r, StreamResponse): + raise TypeError("expected StreamResponse") + async for chunk in r.body: data += chunk assert data == b"ham ham ham ham ham ham ham ham ham ham " @@ -518,12 +573,13 @@ async def test_stream(server): @pytest.mark.parametrize('server', [ dict(steps=[partial(send_chunked, data=["ham "] * 10), finish]) ], indirect=True) -async def test_callback(server): - async def callback_example(chunk): +async def test_callback(server: Server) -> None: + callback_data = b"" + + async def callback_example(chunk: bytearray) -> None: nonlocal callback_data callback_data += chunk - callback_data = b"" await asks.get(server.http_test_url, callback=callback_example) assert callback_data == b"ham ham ham ham ham ham ham ham ham ham " @@ -533,11 +589,14 @@ async def callback_example(chunk): @pytest.mark.parametrize('server', [ dict( - steps=[partial(send_200_blank_headers, headers=[("connection", "close")]), finish], + steps=[partial(send_200_blank_headers, headers=[ + ("connection", "close")]), finish], ) ], indirect=True) -async def test_connection_close_no_content_len(server): +async def test_connection_close_no_content_len(server: Server) -> None: r = await asks.get(server.http_test_url) + if not isinstance(r, Response): + raise TypeError("expected Response") assert r.text == "200" @@ -549,12 +608,13 @@ async def test_connection_close_no_content_len(server): @pytest.mark.parametrize('server', [ dict( - steps=[partial(send_200_blank_headers, headers=[("connection", "close")]), finish], + steps=[partial(send_200_blank_headers, headers=[ + ("connection", "close")]), finish], max_requests=10, ) ], indirect=True) -async def test_session_smallpool(server): - async def worker(s): +async def test_session_smallpool(server: Server) -> None: + async def worker(s: asks.Session) -> None: r = await s.get(path="/get") assert r.status_code == 200 @@ -571,12 +631,16 @@ async def worker(s): @pytest.mark.parametrize('server', [ dict(steps=[accept_cookies_and_respond, finish]) ], indirect=True) -async def test_session_stateful(server): +async def test_session_stateful(server: Server) -> None: s = asks.Session(server.http_test_url, persist_cookies=True) await s.get(cookies={"Test-Cookie": "Test Cookie Value"}) - assert ":".join(str(x) for x in _TEST_LOC) in s._cookie_tracker.domain_dict.keys() + if not s._cookie_tracker: + raise ValueError("expected s._cookie_tracker to not be None") + assert ":".join(str(x) + for x in _TEST_LOC) in s._cookie_tracker.domain_dict.keys() assert ( - s._cookie_tracker.domain_dict[":".join(str(x) for x in 
_TEST_LOC)][0].value + s._cookie_tracker.domain_dict[":".join( + str(x) for x in _TEST_LOC)][0].value == '"Test Cookie Value"' ) @@ -584,40 +648,56 @@ async def test_session_stateful(server): # Test session instantiates outside event loop -def test_instantiate_session_outside_of_event_loop(): +def test_instantiate_session_outside_of_event_loop() -> None: try: asks.Session() except RuntimeError: pytest.fail("Could not instantiate Session outside of event loop") -async def test_session_unknown_kwargs(): +async def test_session_unknown_kwargs() -> None: with pytest.raises(TypeError, match=r"request\(\) got .*"): session = asks.Session("https://httpbin.org/get") await session.request("GET", ko=7, foo=0, bar=3, shite=3) pytest.fail("Passing unknown kwargs does not raise TypeError") -async def test_recv_event_anyio2_end_of_stream(): +async def test_recv_event_anyio2_end_of_stream() -> None: class MockH11Connection: - def __init__(self): - self.data = None - def next_event(self): + def __init__(self) -> None: + self.data: Optional[bytes] = None + + def next_event(self) -> type: if self.data == b"": return h11.PAUSED else: return h11.NEED_DATA - def receive_data(self, data): + + def receive_data(self, data: bytes) -> None: self.data = data class MockSock: - def receive(self): + + def __init__(self) -> None: + self.host = "mocksock" + self.port = "mocksock" + self._active = True + + async def aclose(self) -> None: + ... + + async def send(self, item: Optional[bytes]) -> None: + ... + + async def receive(self) -> None: raise EndOfStream req = RequestProcessor(None, "get", "toot-toot", None) - req.sock = MockSock() + # TODO: fix the leaky abstraction! + req.sock = cast(SocketLike, MockSock()) h11_connection = MockH11Connection() - event = await req._recv_event(h11_connection) - assert event is h11.PAUSED + # TODO: fix the leaky abstraction! 
+ event = await req._recv_event(cast(h11.Connection, h11_connection)) + assert cast(type, event) is h11.PAUSED assert h11_connection.data == b"" diff --git a/tests/test_http_utils.py b/tests/test_http_utils.py index 843221b..a9dbef7 100644 --- a/tests/test_http_utils.py +++ b/tests/test_http_utils.py @@ -1,5 +1,6 @@ import zlib import gzip +from typing import Callable import pytest @@ -12,34 +13,34 @@ @pytest.mark.parametrize( "compressor,name", [(zlib.compress, "deflate"), (gzip.compress, "gzip")] ) -def test_decompress_one_zlib(compressor, name): - data = zlib.compress(INPUT_DATA) - decompressor = http_utils.decompress_one("deflate") +def test_decompress_one_zlib(compressor: Callable[[bytes], bytes], name: str) -> None: + data = compressor(INPUT_DATA) + decompressor = http_utils.decompress_one(name) result = b"" for i in range(len(data)): - b = data[i : i + 1] + b = data[i: i + 1] result += decompressor.send(b) assert result == INPUT_DATA -def test_decompress(): +def test_decompress() -> None: # we don't expect to see multiple compression types in the wild # but test anyway data = zlib.compress(gzip.compress(INPUT_DATA)) decompressor = http_utils.decompress(["gzip", "deflate"]) result = b"" for i in range(len(data)): - b = data[i : i + 1] + b = data[i: i + 1] result += decompressor.send(b) assert result == INPUT_DATA -def test_decompress_decoding(): +def test_decompress_decoding() -> None: data = zlib.compress(UNICODE_INPUT_DATA.encode("utf-8")) decompressor = http_utils.decompress(["deflate"], encoding="utf-8") result = "" for i in range(len(data)): - b = data[i : i + 1] + b = data[i: i + 1] res = decompressor.send(b) result += res assert result == UNICODE_INPUT_DATA @@ -78,15 +79,15 @@ def test_decompress_decoding(): ) ] ) -def test_api_url_construction(url_segments, expected): +def test_api_url_construction(url_segments: tuple[str, str, str], expected: str) -> None: base_location, endpoint, path = url_segments session = Session(base_location=base_location, endpoint=endpoint) constructed_url = session._make_url(path) assert constructed_url == expected -def test_api_url_construction_but_no_base(): +def test_api_url_construction_but_no_base() -> None: base_location, endpoint, path = ("", "/some_endpoint", "/some_path") session = Session(base_location=base_location, endpoint=endpoint) with pytest.raises(ValueError): - constructed_url = session._make_url(path) + _ = session._make_url(path) diff --git a/tests/test_multipart.py b/tests/test_multipart.py index eaf7149..ec0f172 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -12,7 +12,7 @@ @pytest.fixture(scope="session") -def dummy_file_path(tmpdir_factory): +def dummy_file_path(tmpdir_factory: pytest.TempdirFactory) -> Path: dummy = tmpdir_factory.mktemp("multipart").join("test.txt") with open(dummy, "w") as f: @@ -21,7 +21,7 @@ def dummy_file_path(tmpdir_factory): return Path(dummy) -async def test_multipart_body_dummy_file(): +async def test_multipart_body_dummy_file() -> None: assert ( await build_multipart_body( values=OrderedDict( @@ -38,7 +38,7 @@ async def test_multipart_body_dummy_file(): ) -async def test_multipart_body_with_not_file_argument(): +async def test_multipart_body_with_not_file_argument() -> None: assert ( await build_multipart_body( values=OrderedDict( @@ -56,14 +56,14 @@ async def test_multipart_body_with_not_file_argument(): ) -async def test_multipart_body_with_file_like_argument(): +async def test_multipart_body_with_file_like_argument() -> None: # Simulate an open file with a BytesIO. 
f = BytesIO(b"dummyfile\n") f.name = "test.txt" assert ( await build_multipart_body( - values=OrderedDict({"file": f, "notfile": "abc",}), + values=OrderedDict({"file": f, "notfile": "abc", }), encoding="utf8", boundary_data="8banana133744910kmmr13a56!102!8423", ) @@ -71,10 +71,10 @@ async def test_multipart_body_with_file_like_argument(): ) -async def test_multipart_body_with_path_argument(dummy_file_path): +async def test_multipart_body_with_path_argument(dummy_file_path: Path) -> None: assert ( await build_multipart_body( - values=OrderedDict({"file": dummy_file_path, "notfile": "abc",}), + values=OrderedDict({"file": dummy_file_path, "notfile": "abc", }), encoding="utf8", boundary_data="8banana133744910kmmr13a56!102!8423", ) @@ -82,7 +82,7 @@ async def test_multipart_body_with_path_argument(dummy_file_path): ) -async def test_multipart_body_with_multiple_arguments(dummy_file_path): +async def test_multipart_body_with_multiple_arguments(dummy_file_path: Path) -> None: # Simulate an open file with a BytesIO. f = BytesIO(b"dummyfile2\n") f.name = "test.jpg" @@ -90,7 +90,8 @@ async def test_multipart_body_with_multiple_arguments(dummy_file_path): assert ( await build_multipart_body( values=OrderedDict( - {"file": dummy_file_path, "file2": f, "notfile": "abc", "integer": 3,} + {"file": dummy_file_path, "file2": f, + "notfile": "abc", "integer": 3, } ), encoding="utf8", boundary_data="8banana133744910kmmr13a56!102!8423", @@ -99,7 +100,7 @@ async def test_multipart_body_with_multiple_arguments(dummy_file_path): ) -async def test_multipart_body_with_custom_metadata(): +async def test_multipart_body_with_custom_metadata() -> None: # Simulate an open file with a BytesIO. f = BytesIO(b"dummyfile but it is a jpeg\n") f.name = "test.jpg" @@ -107,7 +108,8 @@ async def test_multipart_body_with_custom_metadata(): assert ( await build_multipart_body( values=OrderedDict( - {"file": MultipartData(f, mime_type="text/plain", basename="test.txt"),} + {"file": MultipartData( + f, mime_type="text/plain", basename="test.txt"), } ), encoding="utf8", boundary_data="8banana133744910kmmr13a56!102!5649", @@ -116,10 +118,10 @@ async def test_multipart_body_with_custom_metadata(): ) -async def test_multipart_body_with_real_test_file(dummy_file_path): +async def test_multipart_body_with_real_test_file(dummy_file_path: Path) -> None: assert ( await build_multipart_body( - values=OrderedDict({"file": dummy_file_path,}), + values=OrderedDict({"file": dummy_file_path, }), encoding="utf8", boundary_data="8banana133744910kmmr13a56!102!5649", ) @@ -127,11 +129,11 @@ async def test_multipart_body_with_real_test_file(dummy_file_path): ) -async def test_multipart_body_with_real_pre_opened_test_file(dummy_file_path): +async def test_multipart_body_with_real_pre_opened_test_file(dummy_file_path: Path) -> None: async with await open_file(dummy_file_path, "rb") as f: assert ( await build_multipart_body( - values=OrderedDict({"file": f,}), + values=OrderedDict({"file": f, }), encoding="utf8", boundary_data="8banana133744910kmmr13a56!102!5649", ) diff --git a/tests/test_request_object.py b/tests/test_request_object.py index 383bf48..bf4d76a 100644 --- a/tests/test_request_object.py +++ b/tests/test_request_object.py @@ -1,42 +1,52 @@ # pylint: disable=no-member +from typing import Any, Union, cast import h11 import pytest from asks.request_object import RequestProcessor +from asks.response_objects import Response, StreamResponse -def _catch_response(monkeypatch, headers, data, http_version=b"1.1"): +def _catch_response(monkeypatch: 
pytest.MonkeyPatch, + headers: list[tuple[str, str]], + data: bytes, + http_version: bytes = b"1.1" + ) -> Union[Response, StreamResponse]: req = RequestProcessor(None, "get", "toot-toot", None) events = [ - h11._events.Response(status_code=200, headers=headers, http_version=http_version), + h11._events.Response(status_code=200, headers=headers, + http_version=http_version), h11._events.Data(data=data), h11._events.EndOfMessage(), ] - async def _recv_event(hconn): + async def _recv_event(hconn: Any) -> h11._events.Event: return events.pop(0) monkeypatch.setattr(req, "_recv_event", _recv_event) monkeypatch.setattr(req, "host", "lol") - cr = req._catch_response(None) + cr = req._catch_response(cast(h11.Connection, None)) try: cr.send(None) except StopIteration as e: response = e.value - return response + return cast(Union[Response, StreamResponse], response) -def test_http1_1(monkeypatch): - response = _catch_response(monkeypatch, [("Content-Length", "5")], b"hello") +def test_http1_1(monkeypatch: pytest.MonkeyPatch) -> None: + response = _catch_response( + monkeypatch, [("Content-Length", "5")], b"hello") assert response.body == b"hello" -def test_http1_1_connection_close(monkeypatch): - response = _catch_response(monkeypatch, [("Connection", "close")], b"hello") +def test_http1_1_connection_close(monkeypatch: pytest.MonkeyPatch) -> None: + response = _catch_response( + monkeypatch, [("Connection", "close")], b"hello") assert response.body == b"hello" -def test_http1_0_no_content_length(monkeypatch): + +def test_http1_0_no_content_length(monkeypatch: pytest.MonkeyPatch) -> None: response = _catch_response(monkeypatch, [], b"hello", b"1.0") assert response.body == b"hello" @@ -50,5 +60,5 @@ def test_http1_0_no_content_length(monkeypatch): [{"false": False}, "?false=False"], ], ) -def test_dict_to_query(data, query_str): +def test_dict_to_query(data: dict[str, Any], query_str: str) -> None: assert RequestProcessor._dict_to_query(data) == query_str diff --git a/tests/test_response_objects.py b/tests/test_response_objects.py index 1816b54..70a37dc 100644 --- a/tests/test_response_objects.py +++ b/tests/test_response_objects.py @@ -1,12 +1,12 @@ import asks.response_objects -def test_response_repr(): +def test_response_repr() -> None: r = asks.response_objects.Response("ascii", "", 200, "Meh", {}, "", "", "") assert repr(r) == "<Response 200 Meh>" -def test_response_guess_encoding(): +def test_response_guess_encoding() -> None: r = asks.response_objects.Response( "ascii", "", 200, "", {"content-type": "text/plain; charset=utf-8"}, "", "", "" ) @@ -24,6 +24,6 @@ def test_response_guess_encoding(): assert r.encoding == "ascii" -def test_response_json(): +def test_response_json() -> None: r = asks.response_objects.Response(None, "", 200, "", {}, '{"foo":"bar"}', "", "") assert r.json() == {"foo": "bar"} diff --git a/tests/test_utils.py b/tests/test_utils.py index f66933a..43439fa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,7 +5,7 @@ from asks.utils import get_netloc_port -def test_netloc_port(): +def test_netloc_port() -> None: assert ("example.com", "80") == get_netloc_port(urlparse("http://example.com")) assert ("example.com", "443") == get_netloc_port(urlparse("http://example.com:443")) assert ("example.com", "443") == get_netloc_port(urlparse("https://example.com"))
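
The annotated tests above repeat one pattern worth calling out: asks.get()/asks.post() are now typed as returning Union[Response, StreamResponse], so callers narrow with isinstance() before touching Response-only attributes such as .json(), or before streaming via StreamResponse.body. A minimal sketch of that calling pattern, not part of the diff; the URLs and function names are illustrative:

    from typing import Any

    import asks
    from asks.response_objects import Response, StreamResponse


    async def fetch_json(url: str) -> Any:
        # Buffered request: narrow the Union so .json() is known to exist.
        r = await asks.get(url)
        if not isinstance(r, Response):
            raise TypeError("expected a buffered Response")
        return r.json()


    async def fetch_raw(url: str) -> bytes:
        # Streaming request: the narrowed StreamResponse exposes an async-iterable body.
        r = await asks.get(url, stream=True)
        if not isinstance(r, StreamResponse):
            raise TypeError("expected a StreamResponse")
        data = b""
        async for chunk in r.body:
            data += chunk
        return data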
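
The mypy.ini plus the py.typed marker added to package_data make the distribution PEP 561-compliant, so strict checking extends to code that imports asks, not just the test suite. A hedged sketch of a downstream module that such a check would cover; the module and function names are made up:

    # consumer.py - illustrative only, checked by pointing mypy at it
    import asks
    from asks.response_objects import Response


    async def status_of(url: str) -> int:
        r = await asks.get(url)
        # status_code lives on the shared base response class, so no narrowing is needed.
        return r.status_code

    # By contrast, calling r.json() here without first narrowing to Response would be
    # expected to fail strict checking, since the streaming response type has no json().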
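
The two "TODO: fix the leaky abstraction!" notes mark the places where a hand-rolled test double has to be cast to the real type (SocketLike, h11.Connection) to satisfy strict mypy. The general shape of that workaround, sketched with a made-up DummySock; cast() only informs the type checker and validates nothing at runtime:

    from typing import cast

    from asks.req_structs import SocketLike


    class DummySock:
        """Implements only what a given test actually touches."""

        async def receive(self) -> bytes:
            return b""

        async def aclose(self) -> None:
            ...


    def as_socketlike(mock: DummySock) -> SocketLike:
        # The cast is a promise to mypy, not a check; if SocketLike grows a new
        # required method the mock silently falls out of date, hence the TODOs.
        return cast(SocketLike, mock)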
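
For readers unfamiliar with the generators exercised in tests/test_http_utils.py: the processor decorator primes a generator with next() so callers can push bytes straight into send(). A small usage sketch of decompress_one, feeding compressed data one byte at a time the way the tests do; the payload is arbitrary:

    import gzip

    from asks import http_utils

    payload = b"hello world"
    compressed = gzip.compress(payload)

    # decompress_one("gzip") returns a generator already primed by @processor,
    # so send() can be called immediately.
    decompressor = http_utils.decompress_one("gzip")

    result = b""
    for i in range(len(compressed)):
        result += decompressor.send(compressed[i:i + 1])

    assert result == payload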