Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion packages/bigframes/bigframes/core/compile/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
import bigframes.operations.generic_ops as gen_ops
import bigframes.operations.json_ops as json_ops
import bigframes.operations.numeric_ops as num_ops
import bigframes.operations.remote_function_ops as remote_function_ops
import bigframes.operations.string_ops as string_ops
import bigframes.operations.struct_ops as struct_ops
from bigframes.core import agg_expressions, identifiers, nodes, ordering, window_spec
from bigframes.core.compile.polars import lowering

Expand Down Expand Up @@ -122,7 +124,7 @@ def _bigframes_dtype_to_polars_dtype(
]
)
if bigframes.dtypes.is_array_like(dtype):
return pl.Array(
return pl.List(
inner=_bigframes_dtype_to_polars_dtype(
bigframes.dtypes.get_array_inner_type(dtype)
)
Expand Down Expand Up @@ -502,6 +504,50 @@ def _(self, op: json_ops.ToJSON, input: pl.Expr) -> pl.Expr:
else:
return input.cast(pl.String())

@compile_op.register(json_ops.ToJSONString)
def _(self, op: json_ops.ToJSONString, input: pl.Expr) -> pl.Expr:
from_type = self._expr_types.get(id(input))

def preprocess_binary(
expr: pl.Expr, dtype: bigframes.dtypes.ExpressionType
) -> pl.Expr:
if dtype == bigframes.dtypes.BYTES_DTYPE:
return expr.bin.encode("base64")
if bigframes.dtypes.is_struct_like(dtype):
fields = bigframes.dtypes.get_struct_fields(dtype)
return pl.struct(
*[
preprocess_binary(
expr.struct.field(name), field_type
).alias(name)
for name, field_type in fields.items()
]
)
if bigframes.dtypes.is_array_like(dtype):
inner_type = bigframes.dtypes.get_array_inner_type(dtype)
return expr.list.eval(preprocess_binary(pl.element(), inner_type))
return expr

preprocessed = preprocess_binary(input, from_type)

if bigframes.dtypes.is_struct_like(from_type):
result = preprocessed.struct.json_encode()
elif from_type == bigframes.dtypes.INT_DTYPE:
result = preprocessed.cast(pl.String)
elif from_type == bigframes.dtypes.BOOL_DTYPE:
result = (
pl.when(preprocessed)
.then(pl.lit("true"))
.otherwise(pl.lit("false"))
)
elif from_type == bigframes.dtypes.BYTES_DTYPE:
result = pl.lit('"') + preprocessed + pl.lit('"')
else:
wrapped = pl.struct(value=preprocessed).struct.json_encode()
result = wrapped.str.slice(9, wrapped.str.len_chars() - 10)

return pl.when(input.is_null()).then(pl.lit("null")).otherwise(result)

@compile_op.register(arr_ops.ToArrayOp)
def _(self, op: ops.ToArrayOp, *inputs: pl.Expr) -> pl.Expr:
return pl.concat_list(*inputs)
Expand Down Expand Up @@ -532,6 +578,36 @@ def _(self, op: ops.ArrayReduceOp, input: pl.Expr) -> pl.Expr:
f"Haven't implemented array aggregation: {op.aggregation}"
)

@compile_op.register(struct_ops.StructOp)
def _(self, op: struct_ops.StructOp, *inputs: pl.Expr) -> pl.Expr:
return pl.struct(**{col: inp for col, inp in zip(op.column_names, inputs)}) # type: ignore

@compile_op.register(struct_ops.StructFieldOp)
def _(self, op: struct_ops.StructFieldOp, *inputs: pl.Expr) -> pl.Expr:
return inputs[0].struct[op.name_or_index]

@compile_op.register(remote_function_ops.PythonUdfOp)
def _(self, op: ops.PythonUdfOp, *inputs: pl.Expr) -> pl.Expr:
from bigframes.functions import function_template

code = op.function_def.code.to_callable()
if op.function_def.signature.is_row_processor:

def handler(py_struct):
args = list(py_struct.values())
series_arg = function_template.get_pd_series(args[0])
return code(series_arg, *args[1:])
else:

def handler(py_struct):
return code(*(field for field in py_struct.values()))

return pl.struct(*inputs).map_elements(
handler,
return_dtype=_bigframes_dtype_to_polars_dtype(op.output_type()),
skip_nulls=False,
)
Comment thread
TrevorBergeron marked this conversation as resolved.

@dataclasses.dataclass(frozen=True)
class PolarsAggregateCompiler:
scalar_compiler = PolarsExpressionCompiler()
Expand Down
69 changes: 5 additions & 64 deletions packages/bigframes/bigframes/functions/_function_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,16 @@

from __future__ import annotations

import collections.abc
import functools
import inspect
import logging
import random
import string
import sys
import threading
import time
import warnings
from typing import (
TYPE_CHECKING,
Any,
Literal,
Mapping,
Optional,
Sequence,
Union,
Expand Down Expand Up @@ -512,22 +507,10 @@ def wrapper(func):
TypeError, f"func must be a callable, got {func}"
)

if sys.version_info >= (3, 10):
# Add `eval_str = True` so that deferred annotations are turned into their
# corresponding type objects. Need Python 3.10 for eval_str parameter.
# https://docs.python.org/3/library/inspect.html#inspect.signature
signature_kwargs: Mapping[str, Any] = {"eval_str": True}
else:
signature_kwargs = {} # type: ignore

py_sig = _resolve_signature(
inspect.signature(func, **signature_kwargs),
udf_sig = _utils.get_func_signature(
func,
input_types,
output_type,
)

udf_sig = udf_def.UdfSignature.from_py_signature(
py_sig
).to_remote_function_compatible()

full_package_requirements = _utils.get_updated_package_requirements(
Expand Down Expand Up @@ -786,23 +769,11 @@ def wrapper(func):
TypeError, f"func must be a callable, got {func}"
)

if sys.version_info >= (3, 10):
# Add `eval_str = True` so that deferred annotations are turned into their
# corresponding type objects. Need Python 3.10 for eval_str parameter.
# https://docs.python.org/3/library/inspect.html#inspect.signature
signature_kwargs: Mapping[str, Any] = {"eval_str": True}
else:
signature_kwargs = {} # type: ignore

py_sig = inspect.signature(
udf_sig = _utils.get_func_signature(
func,
**signature_kwargs,
input_types,
output_type,
)
py_sig = _resolve_signature(py_sig, input_types, output_type)

# The function will actually be receiving a pandas Series, but allow
# both BigQuery DataFrames and pandas object types for compatibility.
udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)

code_def = udf_def.CodeDef.from_func(func, package_requirements=packages)
requirements = udf_def.RuntimeRequirements(
Expand Down Expand Up @@ -878,36 +849,6 @@ def deploy_udf(
return self.udf(_force_deploy=True, **kwargs)(func)


def _resolve_signature(
py_sig: inspect.Signature,
input_types: Union[None, type, Sequence[type]] = None,
output_type: Optional[type] = None,
) -> inspect.Signature:
if input_types is not None:
if not isinstance(input_types, collections.abc.Sequence):
input_types = [input_types]
if _utils.has_conflict_input_type(py_sig, input_types):
msg = bfe.format_message(
"Conflicting input types detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(
parameters=[
par.replace(annotation=itype)
for par, itype in zip(py_sig.parameters.values(), input_types)
]
)
if output_type:
if _utils.has_conflict_output_type(py_sig, output_type):
msg = bfe.format_message(
"Conflicting return type detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(return_annotation=output_type)

return py_sig


def get_cloud_function_name(
function_def: udf_def.CloudRunFunctionConfig, session_id=None, uniq_suffix=False
):
Expand Down
56 changes: 54 additions & 2 deletions packages/bigframes/bigframes/functions/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.


import collections
Comment thread
TrevorBergeron marked this conversation as resolved.
import hashlib
import inspect
import json
import sys
import typing
import warnings
from typing import Any, Optional, Sequence, Set, cast
from typing import Any, Mapping, Optional, Sequence, Set, cast

import cloudpickle
import google.api_core.exceptions
Expand All @@ -31,7 +32,7 @@

import bigframes.exceptions as bfe
import bigframes.formatting_helpers as bf_formatting
from bigframes.functions import function_typing
from bigframes.functions import function_typing, udf_def

# Naming convention for the function artifacts
_BIGFRAMES_FUNCTION_PREFIX = "bigframes"
Expand Down Expand Up @@ -304,3 +305,54 @@ def has_conflict_output_type(
return False

return return_annotation != output_type


def get_func_signature(
func,
input_types: type | Sequence[type] | None = None,
output_type: type | None = None,
) -> udf_def.UdfSignature:
if sys.version_info >= (3, 10):
# Add `eval_str = True` so that deferred annotations are turned into their
# corresponding type objects. Need Python 3.10 for eval_str parameter.
# https://docs.python.org/3/library/inspect.html#inspect.signature
signature_kwargs: Mapping[str, Any] = {"eval_str": True}
else:
signature_kwargs = {} # type: ignore

py_sig = resolve_signature(
inspect.signature(func, **signature_kwargs),
input_types,
output_type,
)
return udf_def.UdfSignature.from_py_signature(py_sig)


def resolve_signature(
py_sig: inspect.Signature,
input_types: type | Sequence[type] | None = None,
output_type: type | None = None,
) -> inspect.Signature:
if input_types is not None:
if not isinstance(input_types, collections.abc.Sequence):
input_types = [input_types]
if has_conflict_input_type(py_sig, input_types):
msg = bfe.format_message(
"Conflicting input types detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(
parameters=[
par.replace(annotation=itype)
for par, itype in zip(py_sig.parameters.values(), input_types)
]
)
if output_type:
if has_conflict_output_type(py_sig, output_type):
msg = bfe.format_message(
"Conflicting return type detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(return_annotation=output_type)

return py_sig
24 changes: 24 additions & 0 deletions packages/bigframes/bigframes/testing/polars_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import bigframes.session.execution_spec
import bigframes.session.executor
import bigframes.session.metrics
from bigframes.functions import _utils, function, udf_def


# Does not support to_sql, dry_run, peek, cached
Expand Down Expand Up @@ -111,6 +112,29 @@ def read_pandas(self, pandas_dataframe, write_engine="default"):

return bf_df

def udf(
self,
*,
input_types=None,
output_type=None,
**kwargs,
):
def wrapper(func):
udf_sig = _utils.get_func_signature(
func,
input_types,
output_type,
)

code_def = udf_def.CodeDef.from_func(func)
udf_definition = udf_def.PythonUdf(
signature=udf_sig,
code=code_def,
)
return function.UdfRoutine(func=func, _udf_def=udf_definition)

return wrapper

@property
def bqclient(self):
# prevents logger from trying to call bq upon any errors
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,25 @@ def test_engines_astype_to_json(scalars_array_value: array_value.ArrayValue, eng
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_to_json_string(scalars_array_value: array_value.ArrayValue, engine):
exprs = [
ops.ToJSONString().as_expr(expression.deref("int64_col")),
ops.ToJSONString().as_expr(
# Use a const since float to json has precision issues
expression.const(5.2, bigframes.dtypes.FLOAT_DTYPE)
),
ops.ToJSONString().as_expr(expression.deref("bool_col")),
ops.ToJSONString().as_expr(
# Use a const since "str_col" has special chars.
expression.const('"hello world"', bigframes.dtypes.STRING_DTYPE)
),
]
arr, _ = scalars_array_value.compute_values(exprs)

assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_astype_timedelta(scalars_array_value: array_value.ArrayValue, engine):
arr = apply_op(
Expand Down
Loading
Loading