Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions google/genai/_interactions/_qs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,16 @@

from __future__ import annotations

from typing import Any, List, Tuple, Union, Mapping, TypeVar
from typing import Any, List, Mapping, Tuple, TypeVar, Union
from urllib.parse import parse_qs, urlencode
from typing_extensions import Literal, get_args

from ._types import NotGiven, not_given
from typing_extensions import get_args

from ._types import ArrayFormat, NestedFormat, NotGiven, not_given
from ._utils import flatten

_T = TypeVar("_T")


ArrayFormat = Literal["comma", "repeat", "indices", "brackets"]
NestedFormat = Literal["dots", "brackets"]

PrimitiveData = Union[str, int, float, bool, None]
# this should be Data = Union[PrimitiveData, "List[Data]", "Tuple[Data]", "Mapping[str, Data]"]
# https://github.com/microsoft/pyright/issues/3555
Expand Down
3 changes: 3 additions & 0 deletions google/genai/_interactions/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")

ArrayFormat = Literal["comma", "repeat", "indices", "brackets"]
NestedFormat = Literal["dots", "brackets"]


# Approximates httpx internal ProxiesTypes and RequestFiles types
# while adding support for `PathLike` instances
Expand Down
188 changes: 112 additions & 76 deletions google/genai/_interactions/_utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,28 @@
# mypy: ignore-errors
from __future__ import annotations

from datetime import date, datetime
import functools
import inspect
import os
from pathlib import Path
import re
import inspect
import functools
from typing import (
Any,
Tuple,
Mapping,
TypeVar,
Callable,
Iterable,
Mapping,
Sequence,
Tuple,
TypeVar,
cast,
overload,
)
from pathlib import Path
from datetime import date, datetime
from typing_extensions import TypeGuard

import sniffio
from typing_extensions import TypeGuard, get_args

from .._types import Omit, NotGiven, FileTypes, HeadersLike
from .._types import ArrayFormat, FileTypes, HeadersLike, NotGiven, Omit

_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
Expand All @@ -56,17 +56,41 @@ def extract_files(
query: Mapping[str, object],
*,
paths: Sequence[Sequence[str]],
array_format: ArrayFormat = "brackets",
) -> list[tuple[str, FileTypes]]:
"""Recursively extract files from the given dictionary based on specified paths.
"""Recursively extract files from the given dictionary based on specified paths.

A path may look like this ['foo', 'files', '<array>', 'data'].
A path may look like this ['foo', 'files', '<array>', 'data'].

Note: this mutates the given dictionary.
"""
files: list[tuple[str, FileTypes]] = []
for path in paths:
files.extend(_extract_items(query, path, index=0, flattened_key=None))
return files
``array_format`` controls how ``<array>`` segments contribute to the emitted
field name. Supported values: ``"brackets"`` (``foo[]``), ``"repeat"`` and
``"comma"`` (``foo``), ``"indices"`` (``foo[0]``, ``foo[1]``).

Note: this mutates the given dictionary.
"""
files: list[tuple[str, FileTypes]] = []
for path in paths:
files.extend(
_extract_items(
query, path, index=0, flattened_key=None, array_format=array_format
)
)
return files


def _array_suffix(array_format: ArrayFormat, array_index: int) -> str:
if array_format == "brackets":
return "[]"
if array_format == "indices":
return f"[{array_index}]"
if array_format == "repeat" or array_format == "comma":
# Both repeat the bare field name for each file part; there is no
# meaningful way to comma-join binary parts.
return ""
raise NotImplementedError(
f"Unknown array_format value: {array_format}, choose from"
f" {', '.join(get_args(ArrayFormat))}"
)


def _extract_items(
Expand All @@ -75,72 +99,78 @@ def _extract_items(
*,
index: int,
flattened_key: str | None,
array_format: ArrayFormat,
) -> list[tuple[str, FileTypes]]:
try:
key = path[index]
except IndexError:
if not is_given(obj):
# no value was provided - we can safely ignore
return []

# cyclical import
from .._files import assert_is_file_content

# We have exhausted the path, return the entry we found.
assert flattened_key is not None

if is_list(obj):
files: list[tuple[str, FileTypes]] = []
for array_index, entry in enumerate(obj):
suffix = _array_suffix(array_format, array_index)
emitted_key = (flattened_key + suffix) if flattened_key else suffix
assert_is_file_content(entry, key=emitted_key)
files.append((emitted_key, cast(FileTypes, entry)))
return files

assert_is_file_content(obj, key=flattened_key)
return [(flattened_key, cast(FileTypes, obj))]

index += 1
if is_dict(obj):
try:
key = path[index]
except IndexError:
if not is_given(obj):
# no value was provided - we can safely ignore
return []

# cyclical import
from .._files import assert_is_file_content

# We have exhausted the path, return the entry we found.
assert flattened_key is not None

if is_list(obj):
files: list[tuple[str, FileTypes]] = []
for entry in obj:
assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
files.append((flattened_key + "[]", cast(FileTypes, entry)))
return files

assert_is_file_content(obj, key=flattened_key)
return [(flattened_key, cast(FileTypes, obj))]

index += 1
if is_dict(obj):
try:
# Remove the field if there are no more dict keys in the path,
# only "<array>" traversal markers or end.
if all(p == "<array>" for p in path[index:]):
item = obj.pop(key)
else:
item = obj[key]
except KeyError:
# Key was not present in the dictionary, this is not indicative of an error
# as the given path may not point to a required field. We also do not want
# to enforce required fields as the API may differ from the spec in some cases.
return []
if flattened_key is None:
flattened_key = key
else:
flattened_key += f"[{key}]"
return _extract_items(
# Remove the field if there are no more dict keys in the path,
# only "<array>" traversal markers or end.
if all(p == "<array>" for p in path[index:]):
item = obj.pop(key)
else:
item = obj[key]
except KeyError:
# Key was not present in the dictionary, this is not indicative of an error
# as the given path may not point to a required field. We also do not want
# to enforce required fields as the API may differ from the spec in some cases.
return []
if flattened_key is None:
flattened_key = key
else:
flattened_key += f"[{key}]"
return _extract_items(
item,
path,
index=index,
flattened_key=flattened_key,
array_format=array_format,
)
elif is_list(obj):
if key != "<array>":
return []

return flatten([
_extract_items(
item,
path,
index=index,
flattened_key=flattened_key,
)
elif is_list(obj):
if key != "<array>":
return []

return flatten(
[
_extract_items(
item,
path,
index=index,
flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
)
for item in obj
]
flattened_key=(
(flattened_key if flattened_key is not None else "")
+ _array_suffix(array_format, array_index)
),
array_format=array_format,
)
for array_index, item in enumerate(obj)
])

# Something unexpected was passed, just ignore it.
return []
# Something unexpected was passed, just ignore it.
return []


def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
Expand Down Expand Up @@ -293,14 +323,20 @@ def wrapper(*args: object, **kwargs: object) -> object:


@overload


def strip_not_given(obj: None) -> None: ...


@overload


def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...


@overload


def strip_not_given(obj: object) -> object: ...


Expand Down
Loading