Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ Using UV (recommended):
1. Install UV if you haven't already
`curl -LsSf https://astral.sh/uv/install.sh | sh`
2. Install the client and test dependencies
`uv pip install -e ".[all]" "pytest<8" "requests-mock>=1.8.0" "pytest-asyncio>0.21"`
`uv pip install -e ".[all]" "pytest<8" "pytest-mock>=3.15.0" "pytest-asyncio>0.21"`
3. Run tests
`pytest -sv --host <indico_host> tests/`
_ Only run unit tests `pytest -sv --host <indico_host> tests/unit/`
Expand All @@ -139,7 +139,7 @@ Or using pip:
3. Install the client
`pip3 install --editable .[all]`
4. Install test deps
`pip3 install "pytest<8" "requests-mock>=1.8.0" "pytest-asyncio>0.21"`
`pip3 install "pytest<8" "pytest-mock>=3.15.0" "pytest-asyncio>0.21"`
5. Run tests
`pytest -sv --host <indico_host> tests/`
_ Only run unit tests `pytest -sv --host <indico_host> tests/unit/`
Expand Down
2 changes: 1 addition & 1 deletion indico/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ async def create(self) -> "Self":
return self

async def cleanup(self) -> None:
await self._http.request_session.close()
await self._http.request_session.aclose()

async def _handle_request_chain(
self,
Expand Down
4 changes: 3 additions & 1 deletion indico/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ class IndicoConfig:
api_token_path= (str, optional): Path to the Indico API token file indico_api_token.txt. Defaults to user's home directory. Ignored if api_token is provided.
api_token= (str, optional): The actual text of the API Token. Takes precedence over api_token_path
verify_ssl= (bool, optional): Whether to verify the host's SSL certificate. Default=True
requests_params= (dict, optional): Dictionary of requests. Session parameters to set
requests_params= (dict, optional): Dictionary of httpx Client parameters to set
enable_http2= (bool, optional): Enable HTTP/2 for async client. Default=False (HTTP/1.1).

Returns:
IndicoConfig object
Expand All @@ -41,6 +42,7 @@ def __init__(self, **kwargs: "Any"):
self.verify_ssl: bool = True
self.requests_params: "Optional[AnyDict]" = None
self._disable_cookie_domain: bool = False
self.enable_http2: bool = False

for key, value in kwargs.items():
if hasattr(self, key):
Expand Down
178 changes: 67 additions & 111 deletions indico/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
import logging
from contextlib import contextmanager
from copy import deepcopy
from http.cookiejar import DefaultCookiePolicy
from pathlib import Path
from typing import TYPE_CHECKING, cast

import aiohttp
import requests
import httpx

from indico.config import IndicoConfig
from indico.errors import (
Expand All @@ -20,10 +18,8 @@
from .retry import aioretry

if TYPE_CHECKING: # pragma: no cover
from http.cookiejar import Cookie
from io import IOBase
from typing import Any, Dict, Iterator, List, Optional, Union
from urllib.request import Request

from indico.client.request import HTTPRequest, ResponseType
from indico.typing import AnyDict
Expand All @@ -32,30 +28,18 @@
logger = logging.getLogger(__file__)


class CookiePolicyOverride(DefaultCookiePolicy):
def set_ok(self, cookie: "Cookie", request: "Request") -> bool:
return True

def return_ok(self, cookie: "Cookie", request: "Request") -> bool:
return True

def path_return_ok(self, path: str, request: "Request") -> bool:
return True

def domain_return_ok(self, domain: str, request: "Request") -> bool:
return True


class HTTPClient:
def __init__(self, config: "Optional[IndicoConfig]" = None):
self.config = config or IndicoConfig()
self.base_url = f"{self.config.protocol}://{self.config.host}"

self.request_session = requests.Session()
self.request_session = httpx.Client(
http2=self.config.enable_http2,
verify=self.config.verify_ssl,
)
if isinstance(self.config.requests_params, dict):
for param in self.config.requests_params.keys():
setattr(self.request_session, param, self.config.requests_params[param])
self.request_session.cookies.set_policy(CookiePolicyOverride())

self.get_short_lived_access_token()

Expand All @@ -73,28 +57,21 @@ def get(
return self._make_request("post", *args, params=params, **kwargs)

def get_short_lived_access_token(self) -> "AnyDict":
# If the cookie here is already due to _disable_cookie_domain set and we try to
# pop it later it will error out because we have two cookies with the same
# name. We just remove the old one here as we are about to refresh it.
if "auth_token" in self.request_session.cookies:
self.request_session.cookies.pop("auth_token")
self.request_session.cookies.delete("auth_token")

r = self.post(
"/auth/users/refresh_token",
headers={"Authorization": f"Bearer {self.config.api_token}"},
_refreshed=True,
)

# Disable cookie domain in cases where the domain won't match due to using short name domains
if self.config._disable_cookie_domain:
value = self.request_session.cookies.get("auth_token", None)
value = self.request_session.cookies.get("auth_token")
if not value:
raise IndicoAuthenticationFailed()
self.request_session.cookies.pop("auth_token")
self.request_session.cookies.set_cookie(
# must ignore because untyped in typeshed
requests.cookies.create_cookie(name="auth_token", value=value) # type: ignore
)
self.request_session.cookies.delete("auth_token")
self.request_session.cookies.set("auth_token", value)

return cast("AnyDict", r)

Expand Down Expand Up @@ -165,13 +142,9 @@ def _make_request(
f"[{method}] {path}\n\t Headers: {headers}\n\tRequest Args:{request_kwargs}"
)
with self._handle_files(request_kwargs) as new_kwargs:
response = getattr(self.request_session, method)(
response: httpx.Response = getattr(self.request_session, method)(
f"{self.base_url}{path}",
headers=headers,
stream=True,
verify=False
if not self.config.verify_ssl or not self.request_session.verify
else True,
**new_kwargs,
)

Expand All @@ -193,7 +166,7 @@ def _make_request(
if response.status_code >= 500:
raise IndicoRequestError(
code=response.status_code,
error=response.reason,
error=response.reason_phrase or "",
extras=repr(response.content),
)

Expand Down Expand Up @@ -221,19 +194,17 @@ def _make_request(

class AIOHTTPClient:
"""
Beta client with a minimal set of features that can execute
requests using the aiohttp library
Async client using httpx. Supports HTTP/1.1 and HTTP/2 (toggle via config.use_http2).
"""

def __init__(self, config: "Optional[IndicoConfig]" = None):
"""
Config options specific to aiohttp
unsafe - allows interacting with IP urls
"""
self.config = config or IndicoConfig()
self.base_url = f"{self.config.protocol}://{self.config.host}"

self.request_session = aiohttp.ClientSession()
self.request_session = httpx.AsyncClient(
http2=self.config.enable_http2,
verify=self.config.verify_ssl,
)
if isinstance(self.config.requests_params, dict):
for param in self.config.requests_params.keys():
setattr(self.request_session, param, self.config.requests_params[param])
Expand All @@ -252,12 +223,14 @@ async def get(
return await self._make_request("post", *args, params=params, **kwargs)

async def get_short_lived_access_token(self) -> "AnyDict":
r = await self.post(
"/auth/users/refresh_token",
headers={"Authorization": f"Bearer {self.config.api_token}"},
_refreshed=True,
return cast(
"AnyDict",
await self.post(
"/auth/users/refresh_token",
headers={"Authorization": f"Bearer {self.config.api_token}"},
_refreshed=True,
),
)
return cast("AnyDict", r)

async def execute_request(
self, request: "HTTPRequest[ResponseType]"
Expand All @@ -269,54 +242,40 @@ async def execute_request(
)

@contextmanager
def _handle_files(
self, req_kwargs: "AnyDict"
) -> "Iterator[List[aiohttp.FormData]]":
files = []
file_args = []
def _handle_files(self, req_kwargs: "AnyDict") -> "Iterator[List[Dict[str, Any]]]":
files: "List[Any]" = []
file_args: "List[Dict[str, Any]]" = []
dup_counts: "Dict[str, int]" = {}
for filepath in req_kwargs.pop("files", []) or []:
data = aiohttp.FormData()
path = Path(filepath)
fd = path.open("rb")
files.append(fd)
# follow the convention of adding (n) after a duplicate filename
_add_suffix = f".{path.suffix}" if path.suffix else ""
if path.stem in dup_counts:
data.add_field(
"file",
fd,
filename=path.stem + f"({dup_counts[path.stem]})" + _add_suffix,
)
name = path.stem + f"({dup_counts[path.stem]})" + _add_suffix
dup_counts[path.stem] += 1
else:
data.add_field("file", fd, filename=path.name)
name = path.name
dup_counts[path.stem] = 1
file_args.append(data)
file_args.append({"files": {"file": (name, fd)}})

for filename, stream in (req_kwargs.pop("streams", {}) or {}).items():
# similar operation as above.
files.append(stream)
data = aiohttp.FormData()
if filename in dup_counts:
data.add_field(
"file",
stream,
filename=filename + f"({dup_counts[filename]})",
)
name = filename + f"({dup_counts[filename]})"
dup_counts[filename] += 1
else:
data.add_field("file", stream, filename=filename)
name = filename
dup_counts[filename] = 1
file_args.append(data)
file_args.append({"files": {"file": (name, stream)}})

try:
yield file_args
finally:
for f in files:
f.close()

@aioretry(aiohttp.ClientConnectionError, aiohttp.ServerDisconnectedError)
@aioretry(httpx.ConnectError, httpx.RemoteProtocolError, httpx.ReadError)
async def _make_request(
self,
method: str,
Expand All @@ -336,57 +295,54 @@ async def _make_request(
resps = await asyncio.gather(
*(
self._make_request(
method, path, headers, **request_kwargs, data=data
method, path, headers, **request_kwargs, **data
)
for data in file_args
)
)
return [resp for resp_set in resps for resp in resp_set]

async with getattr(self.request_session, method)(
response: httpx.Response = await getattr(self.request_session, method)(
f"{self.base_url}{path}",
headers=headers,
verify_ssl=self.config.verify_ssl,
**request_kwargs,
) as response:
# If auth expired refresh
if response.status == 401 and not _refreshed:
await self.get_short_lived_access_token()
return await self._make_request(
method, path, headers, _refreshed=True, **request_kwargs
)
elif response.status == 401 and _refreshed:
raise IndicoAuthenticationFailed()
)

if response.status == 503 and "Retry-After" in response.headers:
raise IndicoHibernationError(
after=response.headers.get("Retry-After")
)
if response.status_code == 401 and not _refreshed:
await self.get_short_lived_access_token()
return await self._make_request(
method, path, headers, _refreshed=True, **request_kwargs
)
if response.status_code == 401 and _refreshed:
raise IndicoAuthenticationFailed()

if response.status >= 500:
raise IndicoRequestError(
code=response.status,
error=response.reason,
extras=repr(response.content),
)
if response.status_code == 503 and "Retry-After" in response.headers:
raise IndicoHibernationError(after=response.headers.get("Retry-After"))

content: "Any" = await aio_deserialize(
response, force_json=json, force_decompress=decompress
if response.status_code >= 500:
raise IndicoRequestError(
code=response.status_code,
error=response.reason_phrase or "",
extras=repr(response.content),
)

if response.status >= 400:
if isinstance(content, dict):
error = (
f"{content.pop('error_type', 'Unknown Error')}, "
f"{content.pop('message', '')}"
)
extras = content
else:
error = content
extras = None
content: "Any" = await aio_deserialize(
response, force_json=json, force_decompress=decompress
)

raise IndicoRequestError(
error=error, code=response.status, extras=extras
if response.status_code >= 400:
if isinstance(content, dict):
error = (
f"{content.pop('error_type', 'Unknown Error')}, "
f"{content.pop('message', '')}"
)
extras = content
else:
error = content
extras = None

raise IndicoRequestError(
error=error, code=response.status_code, extras=extras
)

return content
return content
Loading