Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve map(), struct(), array() #8666

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
plan,
)
from ibis.common.exceptions import UnboundExpressionError, UnsupportedOperationError
from ibis.formats.numpy import NumpyType
from ibis.formats.pandas import PandasData, PandasType
from ibis.util import gen_name

Expand Down Expand Up @@ -155,9 +156,10 @@ def mapper(df, cases):
return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.Array, exprs, dtype):
np_type = NumpyType.from_ibis(dtype)
return cls.rowwise(
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
lambda row: np.array(row, dtype=np_type), exprs, name=op.name, dtype=object
)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/dask/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs):

@classmethod
def asseries(cls, value, like=None):
"""Ensure that value is a pandas Series object, broadcast if necessary."""
"""Ensure that value is a dask Series object, broadcast if necessary."""

if isinstance(value, dd.Series):
return value
Expand All @@ -50,7 +50,7 @@ def asseries(cls, value, like=None):
elif isinstance(value, pd.Series):
return dd.from_pandas(value, npartitions=1)
elif like is not None:
if isinstance(value, (tuple, list, dict)):
if isinstance(value, (tuple, list, dict, np.ndarray)):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ops.Array executes to a np.ndarray (both before and after this change). We never hit this scenario in tests before, so this tweak fixes an existing bug.

fn = lambda df: pd.Series([value] * len(df), index=df.index)
else:
fn = lambda df: pd.Series(value, index=df.index)
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/exasol/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class ExasolCompiler(SQLGlotCompiler):
ops.StringSplit,
ops.StringToDate,
ops.StringToTimestamp,
ops.StructColumn,
ops.TimeDelta,
ops.TimestampAdd,
ops.TimestampBucket,
Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from ibis.common.dispatch import Dispatched
from ibis.common.exceptions import OperationNotDefinedError, UnboundExpressionError
from ibis.formats.numpy import NumpyType
from ibis.formats.pandas import PandasData, PandasType
from ibis.util import any_of, gen_name

Expand All @@ -49,6 +50,8 @@ def visit(cls, op: ops.Node, **kwargs):

@classmethod
def visit(cls, op: ops.Literal, value, dtype):
if value is None:
return None
if dtype.is_interval():
value = pd.Timedelta(value, dtype.unit.short)
elif dtype.is_array():
Expand Down Expand Up @@ -220,8 +223,9 @@ def visit(cls, op: ops.FindInSet, needle, values):
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.Array, exprs):
return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)
def visit(cls, op: ops.Array, exprs, dtype):
np_type = NumpyType.from_ibis(dtype)
return cls.rowwise(lambda row: np.array(row, dtype=np_type), exprs)

@classmethod
def visit(cls, op: ops.StructColumn, names, values):
Expand Down
20 changes: 12 additions & 8 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ def _make_duration(value, dtype):
def literal(op, **_):
value = op.value
dtype = op.dtype
if dtype.is_interval():
return _make_duration(value, dtype)

if dtype.is_array():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you avoid moving the code around here and only handle the None case? It's harder to review when these kinds of changes are mixed together.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to look at this again, but IIRC this was actually the minimum possible change, something was erroring with NULL duration handling. Will check again

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this reorg is needed. ibis's Duration.DAY and friends are not representable in polars' datatypes, so if you call PolarsType.from_ibis(dtype) on those, you get an error. So we need an early guard to look for the interval dtypes and return early.

typ = PolarsType.from_ibis(dtype)
if value is None:
return pl.lit(None, dtype=typ)
elif dtype.is_array():
value = pl.Series("", value)
typ = PolarsType.from_ibis(dtype)
val = pl.lit(value, dtype=typ)
return val.implode()
elif dtype.is_struct():
Expand All @@ -98,14 +102,11 @@ def literal(op, **_):
for k, v in value.items()
]
return pl.struct(values)
elif dtype.is_interval():
return _make_duration(value, dtype)
elif dtype.is_null():
return pl.lit(value)
elif dtype.is_binary():
return pl.lit(value)
else:
typ = PolarsType.from_ibis(dtype)
return pl.lit(op.value, dtype=typ)


Expand Down Expand Up @@ -985,9 +986,12 @@ def array_concat(op, **kw):


@translate.register(ops.Array)
def array_column(op, **kw):
cols = [translate(col, **kw) for col in op.exprs]
return pl.concat_list(cols)
def array_literal(op, **kw):
pdt = PolarsType.from_ibis(op.dtype)
if op.exprs:
return pl.concat_list([translate(col, **kw) for col in op.exprs]).cast(pdt)
else:
return pl.lit([], dtype=pdt)


@translate.register(ops.ArrayCollect)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,8 @@ def visit_InSubquery(self, op, *, rel, needle):
query = sg.select(STAR).from_(query)
return needle.isin(query=query)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)
def visit_Array(self, op, *, exprs, dtype):
return self.cast(self.f.array(*exprs), dtype)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems extraordinarily heavy-handed and overly broad.

Do we really need to cast the entire array on construction, for all SQL backends?


def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.TimestampDiff,
ops.StringToDate,
ops.StringToTimestamp,
ops.StructColumn,
ops.TimeDelta,
ops.DateDelta,
ops.TimestampDelta,
Expand Down
70 changes: 70 additions & 0 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
SnowflakeProgrammingError,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.collections import frozendict

pytestmark = [
Expand Down Expand Up @@ -73,6 +74,75 @@
# list.


def test_array_factory(con):
    """Check ``ibis.array`` on a Python list and on an existing array expression.

    Both the expression type and the executed value are verified for the
    array built from a list and for the array built by re-wrapping it.
    """
    a = ibis.array([1, 2, 3])
    assert a.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(a) == [1, 2, 3]

    # Re-wrapping an array expression: assert on ``a2`` (not ``a``) so the
    # type of the re-wrapped expression is actually checked.
    a2 = ibis.array(a)
    assert a2.type() == dt.Array(value_type=dt.Int8)
    assert con.execute(a2) == [1, 2, 3]


def test_array_factory_typed(con):
typed = ibis.array([1, 2, 3], type="array<string>")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like the implicit conversion here. Is it necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The conversion is exactly the part of the API I was trying to test here. I think this conversion is something that this API should explicitly support. What do you think?

Maybe I could parametrize the test like this:

@pytest.mark.parametrize("raw", [[1, 2, 3], ["1", "2", "3"]])
def test_array_factory_typed(con, raw):
    typed = ibis.array(raw, type="array<string>")
    ...

assert con.execute(typed) == ["1", "2", "3"]

typed2 = ibis.array(ibis.array([1, 2, 3]), type="array<string>")
assert con.execute(typed2) == ["1", "2", "3"]


@pytest.mark.notimpl("flink", raises=Py4JJavaError)
@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
def test_array_factory_empty(con):
    """Check ``ibis.array`` on an empty list, with and without an explicit type."""
    # Without a type, the empty list is expected to be rejected.
    with pytest.raises(ValidationError):
        ibis.array([])

    # With an explicit type, the expression carries that type and executes
    # to an empty list.
    typed_empty = ibis.array([], type="array<string>")
    actual_type = typed_empty.type()
    assert actual_type == dt.Array(value_type=dt.string)
    result = con.execute(typed_empty)
    assert result == []


@pytest.mark.notyet(
    "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
@pytest.mark.notyet(
    "flink", raises=Py4JJavaError, reason="Parameters must be of the same type"
)
def test_array_factory_null(con):
    """Check ``ibis.array`` with ``None`` as the value and as array elements."""
    # ``None`` with no type, and ``None`` with a non-array type, are both
    # expected to raise ValidationError.
    for extra in ({}, {"type": "int64"}):
        with pytest.raises(ValidationError):
            ibis.array(None, **extra)

    null_array = ibis.array(None, type="array<string>")
    assert null_array.type() == dt.Array(value_type=dt.string)
    assert con.execute(null_array) is None

    array_of_nulls = ibis.array([None, None], type="array<string>")
    assert array_of_nulls.type() == dt.Array(value_type=dt.string)
    assert con.execute(array_of_nulls) == [None, None]

    # Execute a real value here, so the backends that don't support arrays
    # actually xfail as we expect them to.
    # Otherwise would have to @mark.xfail every test in this file besides this one.
    real_values = ibis.array([1, 2])
    assert con.execute(real_values) == [1, 2]


@pytest.mark.broken(
    ["datafusion", "flink", "polars"],
    raises=AssertionError,
    reason="[None, 1] executes to [np.nan, 1.0]",
)
def test_array_factory_null_mixed(con):
    """Check ``ibis.array`` on a list mixing ``None`` with a real value."""
    # Untyped: the expression type is asserted to be an array of Int8.
    inferred = ibis.array([None, 1])
    inferred_type = inferred.type()
    assert inferred_type == dt.Array(value_type=dt.Int8)
    inferred_result = con.execute(inferred)
    assert inferred_result == [None, 1]

    # Explicitly typed as array<string>: the non-null element is expected
    # to come back as a string.
    as_strings = ibis.array([None, 1], type="array<string>")
    as_strings_type = as_strings.type()
    assert as_strings_type == dt.Array(value_type=dt.String)
    as_strings_result = con.execute(as_strings)
    assert as_strings_result == [None, "1"]


def test_array_column(backend, alltypes, df):
expr = ibis.array(
[alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
Expand Down
31 changes: 31 additions & 0 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
from ibis.backends.tests.errors import PsycoPg2InternalError, Py4JJavaError
from ibis.common.annotations import ValidationError

pytestmark = [
pytest.mark.never(
Expand Down Expand Up @@ -39,6 +40,31 @@
)


@mark_notimpl_risingwave_hstore
@pytest.mark.parametrize(
    "values_factory",
    [
        lambda: ({"a": "b"},),
        lambda: (["a"], ["b"]),
        lambda: (ibis.map({"a": "b"}),),
    ],
)
def test_map_factory_dict(con, values_factory):
    """Check ``ibis.map`` for each supported way of passing the same mapping.

    ``values_factory`` returns the positional arguments for ``ibis.map``:
    a dict, a keys/values pair of lists, or an existing map expression.
    """
    args = values_factory()

    # A non-map type is expected to be rejected.
    with pytest.raises(ValidationError):
        ibis.map(*args, type="array<string>")

    untyped = ibis.map(*args)
    assert con.execute(untyped) == {"a": "b"}

    typed = ibis.map(*args, type="map<string, string>")
    assert con.execute(typed) == {"a": "b"}


def test_map_null(con):
    """Check ``ibis.map`` with ``None`` as the value."""
    # ``None`` with no type, and ``None`` with a non-map type, are both
    # expected to raise ValidationError.
    for extra in ({}, {"type": "array<string>"}):
        with pytest.raises(ValidationError):
            ibis.map(None, **extra)

    # With a map type, the expression executes to ``None``.
    null_map = ibis.map(None, type="map<string, string>")
    assert con.execute(null_map) is None


@pytest.mark.notyet("clickhouse", reason="nested types can't be NULL")
@pytest.mark.broken(["pandas", "dask"], reason="TypeError: iteration over a 0-d array")
@pytest.mark.notimpl(
Expand Down Expand Up @@ -503,6 +529,11 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
marks=[
pytest.mark.notyet("clickhouse", reason="nested types can't be null"),
mark_notyet_postgres,
pytest.mark.notimpl(
"flink",
raises=Py4JJavaError,
reason="Unexpected error in type inference logic of function 'COALESCE'",
),
],
id="struct",
),
Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
marks=[
pytest.mark.never(
["impala", "mysql", "sqlite", "mssql", "exasol"],
raises=(NotImplementedError, exc.UnsupportedBackendType),
raises=(
exc.OperationNotDefinedError,
NotImplementedError,
exc.UnsupportedBackendType,
),
reason="structs not supported in the backend",
),
pytest.mark.notimpl(
Expand Down
Loading
Loading