Skip to content

Commit

Permalink
feat: support empty arrays, improve ibis.array() API
Browse files Browse the repository at this point in the history
Picking out the array stuff from ibis-project#8666
  • Loading branch information
NickCrews committed Jun 29, 2024
1 parent 33ec754 commit cbd2510
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 36 deletions.
9 changes: 8 additions & 1 deletion ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def visit(cls, op: ops.Node, **kwargs):

@classmethod
def visit(cls, op: ops.Literal, value, dtype):
if dtype.is_interval():
if value is None:
value = None
elif dtype.is_interval():
value = pd.Timedelta(value, dtype.unit.short)
elif dtype.is_array():
value = np.array(value)
Expand Down Expand Up @@ -219,6 +221,11 @@ def visit(cls, op: ops.FindInSet, needle, values):
result = np.select(condlist, choicelist, default=-1)
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.EmptyArray, dtype):
pdt = PandasType.from_ibis(dtype)
return np.array([], dtype=pdt)

@classmethod
def visit(cls, op: ops.Array, exprs):
return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)
Expand Down
15 changes: 8 additions & 7 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,7 @@ def literal(op, **_):
value = op.value
dtype = op.dtype

if dtype.is_array():
value = pl.Series("", value)
typ = PolarsType.from_ibis(dtype)
val = pl.lit(value, dtype=typ)
return val.implode()
elif dtype.is_struct():
if dtype.is_struct():
values = [
pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k)
for k, v in value.items()
Expand All @@ -106,7 +101,7 @@ def literal(op, **_):
return pl.lit(value)
else:
typ = PolarsType.from_ibis(dtype)
return pl.lit(op.value, dtype=typ)
return pl.lit(value, dtype=typ)


_TIMESTAMP_SCALE_TO_UNITS = {
Expand Down Expand Up @@ -973,6 +968,12 @@ def array_concat(op, **kw):
return result


@translate.register(ops.EmptyArray)
def empty_array(op, **kw):
pdt = PolarsType.from_ibis(op.dtype)
return pl.lit([], dtype=pdt)


@translate.register(ops.Array)
def array_column(op, **kw):
cols = [translate(col, **kw) for col in op.exprs]
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,9 @@ def visit_InSubquery(self, op, *, rel, needle):
query = sg.select(STAR).from_(query)
return needle.isin(query=query)

def visit_EmptyArray(self, op, *, dtype):
return self.cast(self.f.array(), dtype)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)

Expand Down
68 changes: 67 additions & 1 deletion ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PySparkAnalysisException,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.collections import frozendict

pytestmark = [
Expand Down Expand Up @@ -72,6 +73,71 @@
# list.


def test_array_factory(con):
a = ibis.array([1, 2, 3])
assert a.type() == dt.Array(value_type=dt.Int8)
assert con.execute(a) == [1, 2, 3]

a2 = ibis.array(a)
assert a.type() == dt.Array(value_type=dt.Int8)
assert con.execute(a2) == [1, 2, 3]


def test_array_factory_typed(con):
typed = ibis.array([1, 2, 3], type="array<string>")
assert con.execute(typed) == ["1", "2", "3"]

typed2 = ibis.array(ibis.array([1, 2, 3]), type="array<string>")
assert con.execute(typed2) == ["1", "2", "3"]


@pytest.mark.notimpl("flink", raises=Py4JJavaError)
def test_array_factory_empty(con):
with pytest.raises(ValidationError):
ibis.array([])

empty_typed = ibis.array([], type="array<string>")
assert empty_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(empty_typed) == []


@pytest.mark.notyet(
"clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
def test_array_factory_null(con):
with pytest.raises(ValidationError):
ibis.array(None)
with pytest.raises(ValidationError):
ibis.array(None, type="int64")
none_typed = ibis.array(None, type="array<string>")
assert none_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(none_typed) is None

nones = ibis.array([None, None], type="array<string>")
assert nones.type() == dt.Array(value_type=dt.string)
assert con.execute(nones) == [None, None]

# Execute a real value here, so the backends that don't support arrays
# actually xfail as we expect them to.
# Otherwise would have to @mark.xfail every test in this file besides this one.
assert con.execute(ibis.array([1, 2])) == [1, 2]


@pytest.mark.broken(
["datafusion", "flink", "polars"],
raises=AssertionError,
reason="[None, 1] executes to [np.nan, 1.0]",
)
def test_array_factory_null_mixed(con):
none_and_val = ibis.array([None, 1])
assert none_and_val.type() == dt.Array(value_type=dt.Int8)
assert con.execute(none_and_val) == [None, 1]

none_and_val_typed = ibis.array([None, 1], type="array<string>")
assert none_and_val_typed.type() == dt.Array(value_type=dt.String)
assert con.execute(none_and_val_typed) == [None, "1"]


def test_array_column(backend, alltypes, df):
expr = ibis.array(
[alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
Expand Down Expand Up @@ -1356,7 +1422,7 @@ def test_unnest_range(con):
pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
pytest.mark.broken(
["polars"],
reason="expression input not supported with nested arrays",
reason="upstream polars bug: https://github.com/pola-rs/polars/issues/17294",
raises=TypeError,
),
],
Expand Down
7 changes: 3 additions & 4 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,13 +1431,12 @@ def query(t, group_cols):
snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql")


@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.notyet(
["datafusion", "impala", "mssql", "mysql", "sqlite"],
["datafusion", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet",
raises=com.OperationNotDefinedError,
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.broken(
["trino"],
reason="invalid code generated for unnesting a struct",
Expand Down
16 changes: 13 additions & 3 deletions ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,18 @@
ibis.struct(dict(abc=432)),
marks=[
pytest.mark.never(
["impala", "mysql", "sqlite", "mssql", "exasol"],
raises=(NotImplementedError, exc.UnsupportedBackendType),
[
"exasol",
"impala",
"mysql",
"sqlite",
"mssql",
],
raises=(
exc.OperationNotDefinedError,
NotImplementedError,
exc.UnsupportedBackendType,
),
reason="structs not supported in the backend",
),
pytest.mark.notimpl(
Expand Down Expand Up @@ -104,7 +114,7 @@ def test_isin_bug(con, snapshot):
@pytest.mark.notyet(
["datafusion", "exasol", "oracle", "flink", "risingwave"],
reason="no unnest support",
raises=exc.OperationNotDefinedError,
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
)
@pytest.mark.notyet(
["sqlite", "mysql", "druid", "impala", "mssql"], reason="no unnest support upstream"
Expand Down
10 changes: 5 additions & 5 deletions ibis/backends/tests/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,18 +835,18 @@ def test_capitalize(con, inp, expected):
assert pd.isnull(result)


@pytest.mark.never(
["exasol", "impala", "mssql", "mysql", "sqlite"],
reason="Backend doesn't support arrays",
raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
)
@pytest.mark.notimpl(
[
"dask",
"pandas",
"polars",
"oracle",
"flink",
"sqlite",
"mssql",
"mysql",
"exasol",
"impala",
],
raises=com.OperationNotDefinedError,
)
Expand Down
21 changes: 15 additions & 6 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Optional
from typing import TYPE_CHECKING, Annotated, Optional

from public import public

Expand All @@ -13,16 +13,25 @@
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Unary, Value

if TYPE_CHECKING:
from ibis.common.patterns import Length


@public
class EmptyArray(Value):
"""Construct an array with 0 elements."""

dtype: dt.Array
shape = ds.scalar


@public
class Array(Value):
"""Construct an array."""
"""Construct an array with 1+ elements. Use `EmptyArray` for empty arrays."""

exprs: VarTuple[Value]
exprs: Annotated[VarTuple[Value], Length(at_least=1)]

@attribute
def shape(self):
return rlz.highest_precedence_shape(self.exprs)
shape = rlz.shape_like("exprs")

@attribute
def dtype(self):
Expand Down
62 changes: 53 additions & 9 deletions ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@

from public import public

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.common.annotations import ValidationError
from ibis.common.deferred import Deferred, deferrable
from ibis.expr.types.generic import Column, Scalar, Value

if TYPE_CHECKING:
from collections.abc import Callable, Iterable

import ibis.expr.types as ir
from ibis.expr.types.typing import V

import ibis.common.exceptions as com
Expand Down Expand Up @@ -1067,7 +1070,11 @@ def __getitem__(self, index: int | ir.IntegerValue | slice) -> ir.Column:

@public
@deferrable
def array(values: Iterable[V]) -> ArrayValue:
def array(
values: ArrayValue | Iterable[V] | ir.NullValue | None,
*,
type: str | dt.DataType | None = None,
) -> ArrayValue:
"""Create an array expression.
If any values are [column expressions](../concepts/datatypes.qmd) the
Expand All @@ -1078,6 +1085,9 @@ def array(values: Iterable[V]) -> ArrayValue:
----------
values
An iterable of Ibis expressions or Python literals
type
An instance of `ibis.expr.datatypes.DataType` or a string indicating
the Ibis type of `value`. eg `array<float>`.
Returns
-------
Expand Down Expand Up @@ -1108,15 +1118,49 @@ def array(values: Iterable[V]) -> ArrayValue:
│ [3, 42, ... +1] │
└──────────────────────┘
>>> ibis.array([t.a, 42 + ibis.literal(5)])
>>> ibis.array([t.a, 42 + ibis.literal(5)], type="array<float>")
┏━━━━━━━━━━━━━━━━━━━━━━┓
┃ Array()
┃ Array( )
┡━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64>
│ array<float64>
├──────────────────────┤
│ [1, 47]
│ [2, 47]
│ [3, 47]
│ [1.0, 47.0]
│ [2.0, 47.0]
│ [3.0, 47.0]
└──────────────────────┘
"""
return ops.Array(tuple(values)).to_expr()
type = dt.dtype(type) if type is not None else None
if type is not None and not isinstance(type, dt.Array):
raise ValidationError(f"type must be an array, got {type}")

if isinstance(values, ir.Value):
if type is not None:
return values.cast(type)
elif isinstance(values, ArrayValue):
return values
else:
raise ValidationError(
f"If no type passed, values must be an array, got {values.type()}"
)

if values is None:
if type is None:
raise ValidationError("If values is None/NULL, type must be provided")
return ir.null(type)

values = tuple(values)
if len(values) == 0:
if type is None:
raise ValidationError("If values is empty, type must be provided")
return ops.EmptyArray(type).to_expr()
else:
value_type = type.value_type if type is not None else None
values = [_value(v, value_type) for v in values]
return ops.Array(values).to_expr()


def _value(x, type) -> ir.Value:
if isinstance(x, (ir.Value, Deferred)):
return x.cast(type) if type is not None else x
else:
return ibis.literal(x, type=type)

0 comments on commit cbd2510

Please sign in to comment.