Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
joocer committed Jan 24, 2025
1 parent 411cee7 commit 7569798
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 7 deletions.
43 changes: 39 additions & 4 deletions opteryx/connectors/iceberg_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,71 @@
from typing import Union

import pyarrow
import pyarrow.compute  # explicit: pyarrow does not always auto-import the compute submodule
import pyiceberg.typedef
import pyiceberg.types
from orso.schema import FlatColumn
from orso.schema import RelationSchema

from opteryx.connectors import DiskConnector
from opteryx.connectors.base.base_connector import BaseConnector
from opteryx.connectors.capabilities import LimitPushable
from opteryx.connectors.capabilities import Statistics
from opteryx.models import RelationStatistics


class IcebergConnector(BaseConnector, LimitPushable):
class IcebergConnector(BaseConnector, LimitPushable, Statistics):
__mode__ = "Blob"
__type__ = "ICEBERG"

def __init__(self, *args, catalog=None, io=DiskConnector, **kwargs):
    """
    Bind the connector to an Iceberg table resolved through a catalog.

    Parameters:
        catalog: a PyIceberg catalog used to resolve the dataset name to a
            table (no default — a catalog must be supplied in practice).
        io: connector class used to read the underlying data files;
            defaults to DiskConnector.
        kwargs: forwarded to every base class and to the io connector.
    """
    # Each base/mixin class is initialised explicitly (not via super())
    # so every one receives the kwargs regardless of MRO.
    BaseConnector.__init__(self, **kwargs)
    LimitPushable.__init__(self, **kwargs)
    Statistics.__init__(self, **kwargs)

    # self.dataset is presumably set by BaseConnector.__init__ from kwargs
    # — TODO confirm. It is lower-cased before the catalog lookup.
    self.dataset = self.dataset.lower()
    self.table = catalog.load_table(self.dataset)
    self.io_connector = io(**kwargs)

def get_dataset_schema(self) -> RelationSchema:
    """
    Build the Orso schema for the dataset and collect table statistics.

    Converts the Iceberg schema to a RelationSchema, then harvests
    per-column statistics (record count, cardinality estimates, value
    counts, lower/upper bounds) from the Iceberg `files` metadata table
    into self.relation_statistics.

    Returns:
        The RelationSchema describing the dataset's columns.
    """
    iceberg_schema = self.table.schema()
    arrow_schema = iceberg_schema.as_arrow()

    self.schema = RelationSchema(
        name=self.dataset,
        columns=[FlatColumn.from_arrow(field) for field in arrow_schema],
    )

    # Gather statistics from the Iceberg file-level metadata.
    relation_statistics = RelationStatistics()

    # Iceberg statistics are keyed by field id; map ids to names/types once.
    column_names = {col.field_id: col.name for col in iceberg_schema.columns}
    column_types = {col.field_id: col.field_type for col in iceberg_schema.columns}

    files = self.table.inspect.files()
    relation_statistics.record_count = pyarrow.compute.sum(files.column("record_count")).as_py()

    # BUG FIX: pyarrow.Table.columns is a list of ChunkedArrays, so a string
    # membership test against it never matched and these statistics were
    # silently skipped — test against the column NAMES instead.
    available = set(files.column_names)

    if "distinct_counts" in available:
        for file in files.column("distinct_counts"):
            for k, v in file:
                relation_statistics.set_cardinality_estimate(column_names[k], v)

    if "value_counts" in available:
        for file in files.column("value_counts"):
            for k, v in file:
                relation_statistics.add_count(column_names[k], v)

    if "lower_bounds" in available:
        for file in files.column("lower_bounds"):
            for k, v in file:
                relation_statistics.update_lower(
                    column_names[k], IcebergConnector.decode_iceberg_value(v, column_types[k])
                )

    if "upper_bounds" in available:
        for file in files.column("upper_bounds"):
            for k, v in file:
                relation_statistics.update_upper(
                    column_names[k], IcebergConnector.decode_iceberg_value(v, column_types[k])
                )

    self.relation_statistics = relation_statistics

    return self.schema

def read_dataset(self, columns: list = None, **kwargs) -> pyarrow.Table:
Expand Down
121 changes: 118 additions & 3 deletions tests/catalog/test_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,16 @@ def set_up_iceberg():
)

# Step 2: Get the data (so we can get the schema)
data = opteryx.query_to_arrow("SELECT * FROM testdata.flat.formats.parquet")
data = opteryx.query_to_arrow("SELECT tweet_id, text, timestamp, user_id, user_verified, user_name, hash_tags, followers, following, tweets_by_user, is_quoting, is_reply_to, is_retweeting FROM testdata.flat.formats.parquet")

# Step 3: Create an Iceberg table
catalog.create_namespace("iceberg")
table = catalog.create_table("iceberg.tweets", schema=data.schema)

# Step 4: Copy the Parquet files into the warehouse
table.overwrite(data)
table.append(data.slice(0, 50000))
table.append(data.slice(50000, 50000))

print(f"Iceberg table set up at {BASE_PATH}")
return BASE_PATH

@skip_if(is_arm() or is_windows() or is_mac())
Expand Down Expand Up @@ -99,6 +99,121 @@ def test_iceberg_get_schema():
table = catalog.load_table("iceberg.tweets")
table.schema().as_arrow()


@skip_if(is_arm() or is_windows() or is_mac())
def test_iceberg_get_statistics_manual():
    """
    Recreate the connector's statistics collection by hand against the test
    Iceberg table and verify the values match known ground truth.
    """
    import pyarrow
    from pyiceberg.catalog import load_catalog
    from opteryx.models.relation_statistics import RelationStatistics

    set_up_iceberg()

    catalog = load_catalog(
        "default",
        **{
            "uri": f"sqlite:///{BASE_PATH}/pyiceberg_catalog.db",
            "warehouse": f"file://{BASE_PATH}",
        },
    )

    opteryx.register_store("iceberg", IcebergConnector, catalog=catalog, io=DiskConnector)

    table = catalog.load_table("iceberg.tweets")
    table.schema().as_arrow()

    stats = RelationStatistics()

    column_names = {col.field_id: col.name for col in table.schema().columns}
    column_types = {col.field_id: col.field_type for col in table.schema().columns}

    files = table.inspect.files()
    stats.record_count = pyarrow.compute.sum(files.column("record_count")).as_py()

    # pyarrow.Table.columns is a list of ChunkedArrays, so string membership
    # tests against it never match — test against column_names instead.
    if "distinct_counts" in files.column_names:
        for file in files.column("distinct_counts"):
            for k, v in file:
                # BUG FIX: set_cardinality_estimate is a method; the original
                # subscripted it (`...estimate[name] += v`) which is a TypeError.
                stats.set_cardinality_estimate(column_names[k], v)

    if "value_counts" in files.column_names:
        for file in files.column("value_counts"):
            for k, v in file:
                stats.add_count(column_names[k], v)

    for file in files.column("lower_bounds"):
        for k, v in file:
            stats.update_lower(column_names[k], IcebergConnector.decode_iceberg_value(v, column_types[k]))

    for file in files.column("upper_bounds"):
        for k, v in file:
            stats.update_upper(column_names[k], IcebergConnector.decode_iceberg_value(v, column_types[k]))

    assert stats.record_count == 100000
    assert stats.lower_bounds["followers"] == 0
    assert stats.upper_bounds["followers"] == 8266250
    assert stats.lower_bounds["user_name"] == ""
    assert stats.upper_bounds["user_name"] == "🫖🔫"
    assert stats.lower_bounds["tweet_id"] == 1346604539013705728
    assert stats.upper_bounds["tweet_id"] == 1346615999009755142
    assert stats.lower_bounds["text"] == "!! PLEASE STOP A"
    assert stats.upper_bounds["text"] == "🪶Cultural approq"
    assert stats.lower_bounds["timestamp"] == "2021-01-05T23:48"
    assert stats.upper_bounds["timestamp"] == "2021-01-06T00:35"

@skip_if(is_arm() or is_windows() or is_mac())
def test_iceberg_connector():
    """
    End-to-end check: register the Iceberg store and run a filtered query.
    """
    # BUG FIX (hygiene): dropped the unused RelationStatistics import.
    from pyiceberg.catalog import load_catalog

    set_up_iceberg()

    catalog = load_catalog(
        "default",
        **{
            "uri": f"sqlite:///{BASE_PATH}/pyiceberg_catalog.db",
            "warehouse": f"file://{BASE_PATH}",
        },
    )

    opteryx.register_store("iceberg", IcebergConnector, catalog=catalog)
    table = opteryx.query("SELECT * FROM iceberg.tweets WHERE followers = 10")
    # 353 rows have exactly 10 followers in the 100k-row test fixture.
    assert table.shape[0] == 353

@skip_if(is_arm() or is_windows() or is_mac())
def test_iceberg_get_stats():
    """
    Fetch relation statistics through the connector and compare them with
    known ground-truth values for the test table.
    """
    from pyiceberg.catalog import load_catalog
    from opteryx.connectors import IcebergConnector, connector_factory

    set_up_iceberg()

    catalog_config = {
        "uri": f"sqlite:///{BASE_PATH}/pyiceberg_catalog.db",
        "warehouse": f"file://{BASE_PATH}",
    }
    catalog = load_catalog("default", **catalog_config)

    opteryx.register_store("iceberg", IcebergConnector, catalog=catalog, io=DiskConnector)
    connector = connector_factory("iceberg.tweets", None)
    connector.get_dataset_schema()
    stats = connector.relation_statistics

    assert stats.record_count == 100000

    # Expected (lower, upper) bound pairs per column.
    expected_bounds = {
        "followers": (0, 8266250),
        "user_name": ("", "🫖🔫"),
        "tweet_id": (1346604539013705728, 1346615999009755142),
        "text": ("!! PLEASE STOP A", "🪶Cultural approq"),
        "timestamp": ("2021-01-05T23:48", "2021-01-06T00:35"),
    }
    for column, (low, high) in expected_bounds.items():
        assert stats.lower_bounds[column] == low
        assert stats.upper_bounds[column] == high


@skip_if(is_arm() or is_windows() or is_mac())
def test_iceberg_remote():

Expand Down

0 comments on commit 7569798

Please sign in to comment.