Skip to content

Commit

Permalink
feat: Improve array(), map(), and struct
Browse files Browse the repository at this point in the history
fixes #8289

This does a lot of changes. It was hard for me to separate them out as I implemented them. But now that it's all hashed out, I can try to split this up into separate commits if you want. But that might be sorta hard in
some cases.

One this is adding support for passing in None to all these constructors.
These use the new `ibis.null(<type>)` API to return `op.Literal(None, <type>)`s

Make these constructors idempotent: you can
pass in existing Expressions into array(), etc.
The type argument for all of these now always has an effect, not just when passing in python literals. So basically it acts like a cast.

A big structural change is that now ops.Array has an optional
attribute "dtype", so if you pass in a 0-length sequence
of values the op still knows what dtype it is.

Several of the backends were always broken here, they just weren't getting caught. I marked them as broken, we can fix them in a followup.

You can test this locally with eg
`pytest -m <backend> -k factory ibis/backends/tests/test_array.py  ibis/backends/tests/test_map.py ibis/backends/tests/test_struct.py`

Also, fix executing Literal(None) on pandas and polars, 0-length arrays on polars

Also, fixing converting dtypes on clickhouse, Structs should be converted to nonnullable dtypes.
  • Loading branch information
NickCrews committed Jun 4, 2024
1 parent 91b0500 commit 25c3e7b
Show file tree
Hide file tree
Showing 18 changed files with 413 additions and 168 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def mapper(df, cases):
return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.Array, exprs, dtype):
return cls.rowwise(
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/dask/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs):

@classmethod
def asseries(cls, value, like=None):
"""Ensure that value is a pandas Series object, broadcast if necessary."""
"""Ensure that value is a dask Series object, broadcast if necessary."""

if isinstance(value, dd.Series):
return value
Expand All @@ -50,7 +50,7 @@ def asseries(cls, value, like=None):
elif isinstance(value, pd.Series):
return dd.from_pandas(value, npartitions=1)
elif like is not None:
if isinstance(value, (tuple, list, dict)):
if isinstance(value, (tuple, list, dict, np.ndarray)):
fn = lambda df: pd.Series([value] * len(df), index=df.index)
else:
fn = lambda df: pd.Series(value, index=df.index)
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/exasol/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class ExasolCompiler(SQLGlotCompiler):
ops.StringSplit,
ops.StringToDate,
ops.StringToTimestamp,
ops.StructColumn,
ops.TimeDelta,
ops.TimestampAdd,
ops.TimestampBucket,
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/pandas/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def visit(cls, op: ops.Node, **kwargs):

@classmethod
def visit(cls, op: ops.Literal, value, dtype):
if value is None:
return None
if dtype.is_interval():
value = pd.Timedelta(value, dtype.unit.short)
elif dtype.is_array():
Expand Down Expand Up @@ -220,7 +222,7 @@ def visit(cls, op: ops.FindInSet, needle, values):
return pd.Series(result, name=op.name)

@classmethod
def visit(cls, op: ops.Array, exprs):
def visit(cls, op: ops.Array, exprs, dtype):
return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)

@classmethod
Expand Down
19 changes: 11 additions & 8 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ def _make_duration(value, dtype):
def literal(op, **_):
value = op.value
dtype = op.dtype
if dtype.is_interval():
return _make_duration(value, dtype)

if dtype.is_array():
typ = PolarsType.from_ibis(dtype)
if value is None:
return pl.lit(None, dtype=typ)
elif dtype.is_array():
value = pl.Series("", value)
typ = PolarsType.from_ibis(dtype)
val = pl.lit(value, dtype=typ)
return val.implode()
elif dtype.is_struct():
Expand All @@ -98,14 +102,11 @@ def literal(op, **_):
for k, v in value.items()
]
return pl.struct(values)
elif dtype.is_interval():
return _make_duration(value, dtype)
elif dtype.is_null():
return pl.lit(value)
elif dtype.is_binary():
return pl.lit(value)
else:
typ = PolarsType.from_ibis(dtype)
return pl.lit(op.value, dtype=typ)


Expand Down Expand Up @@ -982,9 +983,11 @@ def array_concat(op, **kw):


@translate.register(ops.Array)
def array_column(op, **kw):
cols = [translate(col, **kw) for col in op.exprs]
return pl.concat_list(cols)
def array_literal(op, **kw):
if op.exprs:
return pl.concat_list([translate(col, **kw) for col in op.exprs])
else:
return pl.lit([], dtype=PolarsType.from_ibis(op.dtype))


@translate.register(ops.ArrayCollect)
Expand Down
7 changes: 5 additions & 2 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,8 +1019,11 @@ def visit_InSubquery(self, op, *, rel, needle):
query = sg.select(STAR).from_(query)
return needle.isin(query=query)

def visit_Array(self, op, *, exprs):
return self.f.array(*exprs)
def visit_Array(self, op, *, exprs, dtype):
result = self.f.array(*exprs)
if not exprs:
return self.cast(result, dtype)
return result

def visit_StructColumn(self, op, *, names, values):
return sge.Struct.from_arg_list(
Expand Down
6 changes: 4 additions & 2 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,8 +1018,10 @@ class ClickHouseType(SqlglotType):
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
"""Convert a sqlglot type to an ibis type."""
typ = super().from_ibis(dtype)
if dtype.nullable and not (dtype.is_map() or dtype.is_array()):
# map cannot be nullable in clickhouse
# nested types cannot be nullable in clickhouse
if dtype.nullable and not (
dtype.is_map() or dtype.is_array() or dtype.is_struct()
):
return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
else:
return typ
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class SQLiteCompiler(SQLGlotCompiler):
ops.TimestampDiff,
ops.StringToDate,
ops.StringToTimestamp,
ops.StructColumn,
ops.TimeDelta,
ops.DateDelta,
ops.TimestampDelta,
Expand Down
75 changes: 70 additions & 5 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
SnowflakeProgrammingError,
TrinoUserError,
)
from ibis.common.annotations import ValidationError
from ibis.common.collections import frozendict

pytestmark = [
Expand Down Expand Up @@ -73,6 +74,75 @@
# list.


def test_array_factory(con):
a = ibis.array([1, 2, 3])
assert a.type() == dt.Array(value_type=dt.Int8)
assert con.execute(a) == [1, 2, 3]

a2 = ibis.array(a)
assert a.type() == dt.Array(value_type=dt.Int8)
assert con.execute(a2) == [1, 2, 3]


def test_array_factory_typed(con):
typed = ibis.array([1, 2, 3], type="array<string>")
assert con.execute(typed) == ["1", "2", "3"]

typed2 = ibis.array(ibis.array([1, 2, 3]), type="array<string>")
assert con.execute(typed2) == ["1", "2", "3"]


@pytest.mark.notimpl("flink", raises=Py4JJavaError)
@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
def test_array_factory_empty(con):
with pytest.raises(ValidationError):
ibis.array([])

empty_typed = ibis.array([], type="array<string>")
assert empty_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(empty_typed) == []


@pytest.mark.notyet(
"clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
)
@pytest.mark.notyet(
"flink", raises=Py4JJavaError, reason="Parameters must be of the same type"
)
def test_array_factory_null(con):
with pytest.raises(ValidationError):
ibis.array(None)
with pytest.raises(ValidationError):
ibis.array(None, type="int64")
none_typed = ibis.array(None, type="array<string>")
assert none_typed.type() == dt.Array(value_type=dt.string)
assert con.execute(none_typed) is None

nones = ibis.array([None, None], type="array<string>")
assert nones.type() == dt.Array(value_type=dt.string)
assert con.execute(nones) == [None, None]

# Execute a real value here, so the backends that don't support arrays
# actually xfail as we expect them to.
# Otherwise would have to @mark.xfail every test in this file besides this one.
assert con.execute(ibis.array([1, 2])) == [1, 2]


@pytest.mark.broken(
["datafusion", "flink", "polars"],
raises=AssertionError,
reason="[None, 1] executes to [np.nan, 1.0]",
)
def test_array_factory_null_mixed(con):
none_and_val = ibis.array([None, 1])
assert none_and_val.type() == dt.Array(value_type=dt.Int8)
assert con.execute(none_and_val) == [None, 1]

none_and_val_typed = ibis.array([None, 1], type="array<string>")
assert none_and_val_typed.type() == dt.Array(value_type=dt.String)
assert con.execute(none_and_val_typed) == [None, "1"]


def test_array_column(backend, alltypes, df):
expr = ibis.array(
[alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
Expand Down Expand Up @@ -923,11 +993,6 @@ def test_zip_null(con, fn):


@builtin_array
@pytest.mark.notyet(
["clickhouse"],
raises=ClickHouseDatabaseError,
reason="https://github.com/ClickHouse/ClickHouse/issues/41112",
)
@pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError)
@pytest.mark.notimpl(["risingwave"], raises=PsycoPg2ProgrammingError)
@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
Expand Down
37 changes: 37 additions & 0 deletions ibis/backends/tests/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
from ibis.backends.tests.errors import PsycoPg2InternalError, Py4JJavaError
from ibis.common.annotations import ValidationError

pytestmark = [
pytest.mark.never(
Expand Down Expand Up @@ -39,6 +40,31 @@
)


@mark_notimpl_risingwave_hstore
def test_map_factory_dict(con):
assert con.execute(ibis.map({"a": "b"})) == {"a": "b"}
assert con.execute(ibis.map({"a": "b"}, type="map<string, string>")) == {"a": "b"}
with pytest.raises(ValidationError):
ibis.map({1: 2}, type="array<string>")


@mark_notimpl_risingwave_hstore
def test_map_factory_keys_vals(con):
assert con.execute(ibis.map(["a"], ["b"])) == {"a": "b"}
assert con.execute(ibis.map(["a"], ["b"], type="map<string, string>")) == {"a": "b"}
with pytest.raises(ValidationError):
ibis.map(["a"], ["b"], type="array<string>")


@mark_notimpl_risingwave_hstore
def test_map_factory_expr(con):
m = ibis.map({"a": "b"})
assert con.execute(ibis.map(m)) == {"a": "b"}
assert con.execute(ibis.map(m, type="map<string, string>")) == {"a": "b"}
with pytest.raises(ValidationError):
ibis.map(m, type="array<string>")


@pytest.mark.notyet("clickhouse", reason="nested types can't be NULL")
@pytest.mark.broken(["pandas", "dask"], reason="TypeError: iteration over a 0-d array")
@pytest.mark.notimpl(
Expand All @@ -60,6 +86,12 @@ def test_map_nulls(con, k, v):
m = ibis.map(k, v)
assert con.execute(m) is None

assert con.execute(ibis.map(None, type="map<string, string>")) is None
with pytest.raises(ValidationError):
ibis.map(None)
with pytest.raises(ValidationError):
ibis.map(None, type="array<string>")


@pytest.mark.notyet("clickhouse", reason="nested types can't be NULL")
@pytest.mark.broken(["pandas", "dask"], reason="TypeError: iteration over a 0-d array")
Expand Down Expand Up @@ -503,6 +535,11 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
marks=[
pytest.mark.notyet("clickhouse", reason="nested types can't be null"),
mark_notyet_postgres,
pytest.mark.notimpl(
"flink",
raises=Py4JJavaError,
reason="Unexpected error in type inference logic of function 'COALESCE'",
),
],
id="struct",
),
Expand Down
76 changes: 38 additions & 38 deletions ibis/backends/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,52 +10,52 @@

sg = pytest.importorskip("sqlglot")

simple_literal = param(ibis.literal(1), id="simple_literal")
array_literal = param(
ibis.array([1]),
marks=[
pytest.mark.never(
["mysql", "mssql", "oracle", "impala", "sqlite"],
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
reason="arrays not supported in the backend",
),
],
id="array_literal",
)
no_structs = pytest.mark.never(
["impala", "mysql", "sqlite", "mssql", "exasol"],
raises=(NotImplementedError, exc.UnsupportedBackendType),
reason="structs not supported in the backend",
)
no_struct_literals = pytest.mark.notimpl(
["mssql"], reason="struct literals are not yet implemented"
)
not_sql = pytest.mark.never(
["pandas", "dask"],
raises=(exc.IbisError, NotImplementedError, ValueError),
reason="Not a SQL backend",
)
no_sql_extraction = pytest.mark.notimpl(
["polars"], reason="Not clear how to extract SQL from the backend"
)


@pytest.mark.parametrize(
"expr",
"expr,contains",
[
simple_literal,
array_literal,
param(ibis.literal(432), "432", id="simple_literal"),
param(
ibis.struct(dict(a=1)),
marks=[no_structs, no_struct_literals],
ibis.array([432]),
"432",
marks=[
pytest.mark.never(
["mysql", "mssql", "oracle", "impala", "sqlite"],
raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
reason="arrays not supported in the backend",
),
],
id="array_literal",
),
param(
ibis.struct(dict(abc=432)),
"432",
marks=[
pytest.mark.never(
["impala", "mysql", "sqlite", "mssql", "exasol"],
raises=(
exc.OperationNotDefinedError,
NotImplementedError,
exc.UnsupportedBackendType,
),
reason="structs not supported in the backend",
),
pytest.mark.notimpl(
["mssql"], reason="struct literals are not yet implemented"
),
],
id="struct_literal",
),
],
)
@not_sql
@no_sql_extraction
def test_literal(backend, expr):
assert ibis.to_sql(expr, dialect=backend.name())
@pytest.mark.never(
["pandas", "dask"],
raises=(exc.IbisError, NotImplementedError, ValueError),
reason="Not a SQL backend",
)
@pytest.mark.notimpl(["polars"], reason="Not clear how to extract SQL from the backend")
def test_literal(backend, expr, contains):
assert contains in ibis.to_sql(expr, dialect=backend.name())


@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL")
Expand Down
Loading

0 comments on commit 25c3e7b

Please sign in to comment.