Skip to content

Commit

Permalink
Merge pull request #2256 from mabel-dev/#2252
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer authored Jan 16, 2025
2 parents bc50535 + 11654fd commit 0f2f81a
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 62 deletions.
2 changes: 1 addition & 1 deletion opteryx/__version__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__build__ = 987
__build__ = 988

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
15 changes: 1 addition & 14 deletions opteryx/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,19 +243,6 @@ def _inner(array, literal):
return _inner


def _iterate_double_parameter_field_second(func):
"""
for functions called FUNCTION(LITERAL, FIELD)
"""

def _inner(literal, array):
if isinstance(array, str):
array = [array]
return [func(literal, item) for item in array]

return _inner


def get_len(obj):
"""len, but nullsafe"""
if hasattr(obj, "__len__"):
Expand Down Expand Up @@ -467,7 +454,7 @@ def sleep(x):
"LOG": (compute.logb, OrsoTypes.DOUBLE, 1.0),

# DATES & TIMES
"DATE_TRUNC": (_iterate_double_parameter_field_second(dates.date_trunc), OrsoTypes.TIMESTAMP, 1.0),
"DATE_TRUNC": (dates.date_trunc, OrsoTypes.TIMESTAMP, 1.0),
"TIME_BUCKET": (date_functions.date_floor, OrsoTypes.TIMESTAMP, 1.0),
"DATEDIFF": (date_functions.date_diff, OrsoTypes.INTEGER, 1.0),
"TIMEDIFF": (date_functions.time_diff, OrsoTypes.INTEGER, 1.0),
Expand Down
3 changes: 1 addition & 2 deletions opteryx/functions/date_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def date_part(part, arr):
if not hasattr(arr, "__iter__"):
arr = numpy.array([arr])

if arr.dtype == numpy.int64():
arr = convert_int64_array_to_pyarrow_datetime(arr)
arr = numpy.array(arr, dtype="datetime64[us]")

part = part[0].lower() # [#325]
if part in extractors:
Expand Down
44 changes: 26 additions & 18 deletions opteryx/utils/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,38 +188,46 @@ def parse_iso(value):
return None


def date_trunc(truncate_to, date_value):
def date_trunc(truncate_to, date_values) -> numpy.ndarray:
"""
Truncate a datetime to a specified unit
Truncate an array of datetimes to a specified unit
"""
date_value = parse_iso(date_value)

# numpy.datetime64(int(date_values), 's').astype(datetime.datetime)

date_values = numpy.array(date_values, dtype="datetime64")

if not isinstance(truncate_to, str):
truncate_to = truncate_to[0] # [#325]

# [#711]
truncate_to = str(truncate_to).lower()

# fmt:off
if truncate_to == "year":
return datetime.datetime(date_value.year, 1, 1, tzinfo=date_value.tzinfo)
return date_values.astype("datetime64[Y]").astype("datetime64[s]")
elif truncate_to == "quarter":
quarter = (date_value.month - 1) // 3 + 1
return datetime.datetime(date_value.year, 3 * (quarter - 1) + 1, 1, tzinfo=date_value.tzinfo)
months = date_values.astype("datetime64[M]").astype(int) // 3 * 3
return numpy.array(months,
dtype="datetime64[M]",
).astype("datetime64[s]")
elif truncate_to == "month":
return datetime.datetime(date_value.year, date_value.month, 1, tzinfo=date_value.tzinfo)
return date_values.astype("datetime64[M]").astype("datetime64[s]")
elif truncate_to == "week":
days_since_monday = date_value.weekday()
monday = date_value - datetime.timedelta(days=days_since_monday)
return date_trunc("day", monday)
return (
date_values
- ((date_values.astype("datetime64[D]").astype(int) - 4) % 7).astype("timedelta64[D]")
).astype("datetime64[D]").astype("datetime64[s]")
elif truncate_to == "day":
return datetime.datetime(date_value.year, date_value.month, date_value.day, tzinfo=date_value.tzinfo)
return date_values.astype("datetime64[D]").astype("datetime64[s]")
elif truncate_to == "hour":
return datetime.datetime(date_value.year, date_value.month, date_value.day, date_value.hour, tzinfo=date_value.tzinfo)
timestamps = date_values.astype("datetime64[s]").astype("int64")
truncated = (timestamps // 3600) * 3600
return truncated.astype("datetime64[s]")
elif truncate_to == "minute":
return datetime.datetime(date_value.year, date_value.month, date_value.day, date_value.hour, date_value.minute, tzinfo=date_value.tzinfo)
timestamps = date_values.astype("datetime64[s]").astype("int64")
truncated = (timestamps // 60) * 60
return truncated.astype("datetime64[s]")
elif truncate_to == "second":
return datetime.datetime(date_value.year, date_value.month, date_value.day, date_value.hour, date_value.minute, date_value.second, tzinfo=date_value.tzinfo)
else: # pragma: no cover
return date_values.astype("datetime64[s]")

else:
raise ValueError("Invalid unit: {}".format(truncate_to))
# fmt:on
52 changes: 26 additions & 26 deletions tests/functions/test_date_trunc.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,95 @@
import os
import sys
import numpy

sys.path.insert(1, os.path.join(sys.path[0], "../.."))

from datetime import datetime, timezone

from opteryx.utils.dates import date_trunc

DEFAULT_DT = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)


def test_truncate_to_second():
dt = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)
expected = dt.replace(microsecond=0)
actual = date_trunc("second", dt)
expected = numpy.datetime64(dt.replace(microsecond=0))
actual = date_trunc("second", [dt])
assert actual == expected


def test_truncate_to_minute():
dt = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)
expected = dt.replace(second=0, microsecond=0)
actual = date_trunc("minute", dt)
expected = numpy.datetime64(dt.replace(second=0, microsecond=0))
actual = date_trunc("minute", [dt])
assert actual == expected


def test_truncate_to_hour():
dt = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)
expected = dt.replace(minute=0, second=0, microsecond=0)
actual = date_trunc("hour", dt)
expected = numpy.datetime64(dt.replace(minute=0, second=0, microsecond=0))
actual = date_trunc("hour", [dt])
assert actual == expected


def test_truncate_to_day():
dt = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)
expected = dt.replace(hour=0, minute=0, second=0, microsecond=0)
actual = date_trunc("day", dt)
expected = numpy.datetime64(dt.replace(hour=0, minute=0, second=0, microsecond=0))
actual = date_trunc("day", [dt])
assert actual == expected, f"{actual}, {expected}"


def test_truncate_to_month():
dt = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)
expected = dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
actual = date_trunc("month", dt)
expected = numpy.datetime64(dt.replace(day=1, hour=0, minute=0, second=0, microsecond=0))
actual = date_trunc("month", [dt])
assert actual == expected


def test_truncate_to_year():
dt = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)
expected = dt.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0)
actual = date_trunc("year", dt)
expected = numpy.datetime64(dt.replace(month=1, day=1, hour=0, minute=0, second=0, microsecond=0))
actual = date_trunc("year", [dt])
assert actual == expected


def test_truncate_to_week():
dt = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)
expected = dt.replace(day=9, hour=0, minute=0, second=0, microsecond=0)
actual = date_trunc("week", dt)
expected = numpy.datetime64(dt.replace(day=9, hour=0, minute=0, second=0, microsecond=0))
actual = date_trunc("week", [dt])
assert actual == expected, f"{actual}, {expected}"

dt = datetime(2012, 7, 9, 12, 14, 14, 342, timezone.utc)
expected = dt.replace(hour=0, minute=0, second=0, microsecond=0)
actual = date_trunc("week", dt)
expected = numpy.datetime64(dt.replace(hour=0, minute=0, second=0, microsecond=0))
actual = date_trunc("week", [dt])
assert actual == expected


def test_truncate_to_quarter():
dt = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)
expected = datetime(2012, 7, 1, 0, 0, 0, 0, timezone.utc)
actual = date_trunc("quarter", dt)
expected = numpy.datetime64(datetime(2012, 7, 1, 0, 0, 0, 0, timezone.utc))
actual = date_trunc("quarter", [dt])
assert actual == expected, f"{actual}, {expected}"

dt = datetime(2012, 1, 15, 10, 30, 45, 123, timezone.utc)
expected = datetime(2012, 1, 1, 0, 0, 0, 0, timezone.utc)
actual = date_trunc("quarter", dt)
expected = numpy.datetime64(datetime(2012, 1, 1, 0, 0, 0, 0, timezone.utc))
actual = date_trunc("quarter", [dt])
assert actual == expected, f"{actual}, {expected}"

dt = datetime(2012, 6, 25, 5, 20, 30, 456, timezone.utc)
expected = datetime(2012, 4, 1, 0, 0, 0, 0, timezone.utc)
actual = date_trunc("quarter", dt)
expected = numpy.datetime64(datetime(2012, 4, 1, 0, 0, 0, 0, timezone.utc))
actual = date_trunc("quarter", [dt])
assert actual == expected, f"{actual}, {expected}"

dt = datetime(2012, 11, 5, 23, 59, 59, 999, timezone.utc)
expected = datetime(2012, 10, 1, 0, 0, 0, 0, timezone.utc)
actual = date_trunc("quarter", dt)
expected = numpy.datetime64(datetime(2012, 10, 1, 0, 0, 0, 0, timezone.utc))
actual = date_trunc("quarter", [dt])
assert actual == expected, f"{actual}, {expected}"


def test_truncate_to_decade():
dt = datetime(2012, 7, 12, 12, 14, 14, 342, timezone.utc)
try:
date_trunc("decade", dt)
date_trunc("decade", [dt])
except ValueError:
pass
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ sqlalchemy-cockroachdb
sqlalchemy-bigquery
pyodbc
cassandra-driver
pyiceberg[sql-sqlite, pyarrow]
pyiceberg[sql-sqlite]
psycopg2-binary
openpyxl

setuptools_rust
Expand Down

0 comments on commit 0f2f81a

Please sign in to comment.