Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to asks module #191

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions asks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# pylint: disable=wildcard-import
# pylint: disable=wrong-import-position

from .auth import *
from .base_funcs import *
from .sessions import *

from typing import Any
from warnings import warn

from .auth import * # noqa
from .base_funcs import * # noqa
from .sessions import * # noqa


def init(library):
def init(library: Any) -> None:
"""
Unused. asks+anyio auto-detects your library.
"""
Expand Down
48 changes: 35 additions & 13 deletions asks/auth.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
# pylint: disable=abstract-method
from abc import abstractmethod, ABCMeta

import re

import base64
import re
from abc import ABCMeta, abstractmethod
from hashlib import md5
from random import choice
from string import ascii_lowercase, digits

from typing import Any

__all__ = ["AuthBase", "PreResponseAuth", "PostResponseAuth", "BasicAuth", "DigestAuth"]

Expand All @@ -20,7 +18,11 @@ class AuthBase(metaclass=ABCMeta):
"""

@abstractmethod
def __call__(self):
async def __call__(
self,
response_obj: "response_objects.Response",
req_obj: "request_object.RequestProcessor",
) -> dict[str, str]:
"""Not Implemented"""


Expand All @@ -37,7 +39,9 @@ class PostResponseAuth(AuthBase):
Auth class for response dependant auth.
"""

def __init__(self):
auth_attempted: bool

def __init__(self) -> None:
self.auth_attempted = False


Expand All @@ -46,11 +50,15 @@ class BasicAuth(PreResponseAuth):
Ye Olde Basic HTTP Auth.
"""

def __init__(self, auth_info, encoding="utf-8"):
def __init__(self, auth_info: Any, encoding: str = "utf-8"):
self.auth_info = auth_info
self.encoding = encoding

async def __call__(self, _):
async def __call__(
self,
response_obj: "response_objects.Response",
req_obj: "request_object.RequestProcessor",
) -> dict[str, str]:
usrname, psword = [bytes(x, self.encoding) for x in self.auth_info]
encoded_auth = str(base64.b64encode(usrname + b":" + psword), self.encoding)
return {"Authorization": "Basic {}".format(encoded_auth)}
Expand All @@ -69,15 +77,21 @@ class DigestAuth(PostResponseAuth):

_HDR_VAL_PARSE = re.compile(r'\b(\w+)=(?:"([^\\"]+)"|(\S+))')

def __init__(self, auth_info, encoding="utf-8"):
domain_space: list[Any]

def __init__(self, auth_info: Any, encoding: str = "utf-8") -> None:
super().__init__()
self.auth_info = auth_info
self.encoding = encoding
self.domain_space = []
self.nonce = None
self.nonce_count = 1

async def __call__(self, response_obj, req_obj):
async def __call__(
self,
response_obj: "response_objects.Response",
req_obj: "request_object.RequestProcessor",
) -> dict[str, str]:

usrname, psword = [bytes(x, self.encoding) for x in self.auth_info]
try:
Expand Down Expand Up @@ -128,10 +142,14 @@ async def __call__(self, response_obj, req_obj):
self.encoding,
)

bytes_path = bytes(req_obj.path, self.encoding)
bytes_path = bytes(req_obj.path or "", self.encoding)
bytes_method = bytes(req_obj.method, self.encoding)
try:
if b"auth-int" in auth_dict["qop"].lower():

if isinstance(response_obj.raw, str):
raise ValueError("response_obj.raw is str when it shouldn't")

hashed_body = bytes(
md5(response_obj.raw or b"").hexdigest(), self.encoding
)
Expand Down Expand Up @@ -160,7 +178,7 @@ async def __call__(self, response_obj, req_obj):
auth_dict["nonce"],
bytes_nc,
cnonce,
bytes(qop, self.encoding),
bytes(qop or "", self.encoding),
ha2,
)
)
Expand Down Expand Up @@ -190,3 +208,7 @@ async def __call__(self, response_obj, req_obj):
]
)
return {"Authorization": "Digest {}".format(", ".join(response_items))}


from . import request_object # noqa
from . import response_objects # noqa
5 changes: 3 additions & 2 deletions asks/base_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
to the caller.
"""
from functools import partial
from typing import Any, Union

from .response_objects import Response, StreamResponse
from .sessions import Session


__all__ = ["get", "head", "post", "put", "delete", "options", "patch", "request"]


async def request(method, uri, **kwargs):
async def request(method: str, uri: str, **kwargs: Any) -> Union[Response, StreamResponse]:
"""Base function for one time http requests.

Args:
Expand Down
24 changes: 14 additions & 10 deletions asks/cookie_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
__all__ = ["CookieTracker", "parse_cookies"]


from .response_objects import Cookie
from typing import Any, Optional, Union

from .response_objects import BaseResponse, Cookie

_CookiesToSend = dict[Optional[str], Optional[str]]


class CookieTracker:
Expand All @@ -10,21 +14,21 @@ class CookieTracker:
the otherwise stateless general http method functions.
"""

def __init__(self):
self.domain_dict = {}
def __init__(self) -> None:
self.domain_dict: dict[str, list[Cookie]] = {}

def get_additional_cookies(self, netloc, path):
def get_additional_cookies(self, netloc: str, path: str) -> _CookiesToSend:
netloc = netloc.replace("://www.", "://", 1)
return self._check_cookies(netloc + path)

def _store_cookies(self, response_obj):
def _store_cookies(self, response_obj: BaseResponse[Any]) -> None:
for cookie in response_obj.cookies:
try:
self.domain_dict[cookie.host.lstrip()].append(cookie)
except KeyError:
self.domain_dict[cookie.host.lstrip()] = [cookie]

def _check_cookies(self, endpoint):
def _check_cookies(self, endpoint: str) -> _CookiesToSend:
relevant_domains = []
domains = self.domain_dict.keys()

Expand All @@ -37,22 +41,22 @@ def _check_cookies(self, endpoint):
relevant_domains.append(check_domain)
return self._get_cookies_to_send(relevant_domains)

def _get_cookies_to_send(self, domain_list):
cookies_to_go = {}
def _get_cookies_to_send(self, domain_list: list[str]) -> _CookiesToSend:
cookies_to_go: dict[Optional[str], Optional[str]] = {}
for domain in domain_list:
for cookie_obj in self.domain_dict[domain]:
cookies_to_go[cookie_obj.name] = cookie_obj.value
return cookies_to_go


def parse_cookies(response, host):
def parse_cookies(response: BaseResponse[Any], host: str) -> None:
"""
Sticks cookies to a response.
"""
cookie_pie = []
try:
for cookie in response.headers["set-cookie"]:
cookie_jar = {}
cookie_jar: dict[str, Union[str, bool]] = {}
name_val, *rest = cookie.split(";")
name, value = name_val.split("=", 1)
cookie_jar["name"] = name.strip()
Expand Down
12 changes: 11 additions & 1 deletion asks/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
Simple exceptions to be raised in case of errors.
"""

from typing import Any


class AsksException(Exception):
"""
Expand All @@ -28,7 +30,12 @@ class BadHttpResponse(AsksException):


class BadStatus(AsksException):
def __init__(self, err, response, status_code=500):
def __init__(
self,
err: Any,
response: "response_objects.BaseResponse[Any]",
status_code: int = 500,
) -> None:
super().__init__(err)
self.response = response
self.status_code = status_code
Expand All @@ -42,3 +49,6 @@ class RequestTimeout(ConnectivityError):

class ServerClosedConnectionError(ConnectivityError):
pass


from . import response_objects # noqa
40 changes: 25 additions & 15 deletions asks/http_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,23 @@


import codecs
from zlib import decompressobj, MAX_WBITS
from typing import Iterator, Optional
from zlib import MAX_WBITS, decompressobj

from .utils import processor


def parse_content_encoding(content_encoding: str) -> [str]:
def parse_content_encoding(content_encoding: str) -> list[str]:
compressions = [x.strip() for x in content_encoding.split(",")]
return compressions


@processor
def decompress(compressions, encoding=None):
data = b""
def decompress(
compressions: list[str], encoding: Optional[str] = None
) -> Iterator[bytes]:
encoded: Optional[bytes] = b""
decoded: bytes = b""
# https://tools.ietf.org/html/rfc7231
# "If one or more encodings have been applied to a representation, the
# sender that applied the encodings MUST generate a Content-Encoding
Expand All @@ -32,9 +36,11 @@ def decompress(compressions, encoding=None):
if encoding:
decompressors.append(make_decoder_shim(encoding))
while True:
data = yield data
for decompressor in decompressors:
data = decompressor.send(data)
encoded = yield decoded
if encoded is not None:
for decompressor in decompressors:
encoded = decompressor.send(encoded)
decoded = encoded


# https://tools.ietf.org/html/rfc7230#section-4.2.1 - #section-4.2.3
Expand All @@ -47,20 +53,24 @@ def decompress(compressions, encoding=None):


@processor
def decompress_one(compression):
data = b""
def decompress_one(compression: str) -> Iterator[bytes]:
encoded: Optional[bytes] = b""
decoded: bytes = b""
decompressor = decompressobj(wbits=DECOMPRESS_WBITS[compression])
while True:
data = yield data
data = decompressor.decompress(data)
encoded = yield decoded
if encoded is not None:
decoded = decompressor.decompress(encoded)
yield decompressor.flush()


@processor
def make_decoder_shim(encoding):
data = b""
def make_decoder_shim(encoding: str) -> Iterator[str]:
encoded: Optional[bytes] = b""
decoded: str = ""
decoder = codecs.getincrementaldecoder(encoding)(errors="replace")
while True:
data = yield data
data = decoder.decode(data)
encoded = yield decoded
if encoded is not None:
decoded = decoder.decode(encoded)
yield decoder.decode(b"", final=True)
Loading