Skip to content

Commit

Permalink
Extended aggregates support
Browse files Browse the repository at this point in the history
  • Loading branch information
eddelbuettel committed Nov 28, 2023
1 parent 54f9214 commit f389063
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 16 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ export(tdb_select)
export(tile)
export(tile_order)
export(tiledb_array)
export(tiledb_array_apply_aggregate)
export(tiledb_array_close)
export(tiledb_array_create)
export(tiledb_array_delete_fragments)
Expand Down
26 changes: 26 additions & 0 deletions R/Array.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,29 @@ tiledb_array_has_enumeration <- function(arr) {
}
return(libtiledb_array_has_enumeration_vector(ctx@ptr, arr@ptr))
}

##' Run an aggregate query on the given array and attribute
##'
##' @param qry A TileDB Query object
##' @param attrname The name of an attribute
##' @param operation The name of aggregation operation
##' @param nullable A boolean toggle whether the attribute is nullable
##' @return The value of the aggregation
##' @export
tiledb_array_apply_aggregate <- function(array, attrname, operation, nullable = TRUE) {
stopifnot("The 'query' argument must be a TileDB Array object" = is(array, "tiledb_array"),
"The 'attrname' argument must be character" = is.character(attrname),
"The 'operation' argument must be character" = is.character(operation),
"The 'nullable' argument must be logical" = is.logical(nullable))
## TODO: match.arg for operation

if (tiledb_array_is_open(array))
array <- tiledb_array_close(array)

query <- tiledb_query(array, "READ")

if (tiledb_query_get_layout(query) != "UNORDERED")
query <- tiledb_query_set_layout(query, "UNORDERED") # TODO: allow GLOBAL_ORDER too?

libtiledb_query_apply_aggregate(query@ptr, attrname, operation, nullable)
}
2 changes: 1 addition & 1 deletion R/Query.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2017-2022 TileDB Inc.
# Copyright (c) 2017-2023 TileDB Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down
23 changes: 23 additions & 0 deletions man/tiledb_array_apply_aggregate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

94 changes: 79 additions & 15 deletions src/libtiledb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3881,6 +3881,22 @@ XPtr<tiledb::Context> libtiledb_query_get_ctx(XPtr<tiledb::Query> query) {
return make_xptr<tiledb::Context>(new tiledb::Context(ctx));
}

template <typename T>
SEXP apply_unary_aggregate(XPtr<tiledb::Query> query, std::string operator_name, bool nullable = false) {
#if TILEDB_VERSION >= TileDB_Version(2,18,0)
T result = 0;
std::vector<uint8_t> nulls = { 0 };
uint64_t size = 1;
query->set_data_buffer(operator_name, &result, size);
if (nullable) query->set_validity_buffer(operator_name, nulls);
query->submit();
SEXP res = Rcpp::wrap(result);
return res;
#else
return Rcpp::wrap(R_NaReal);
#endif
}

// [[Rcpp::export]]
SEXP libtiledb_query_apply_aggregate(XPtr<tiledb::Query> query,
std::string attribute_name,
Expand All @@ -3889,35 +3905,83 @@ SEXP libtiledb_query_apply_aggregate(XPtr<tiledb::Query> query,
#if TILEDB_VERSION >= TileDB_Version(2,18,0)
check_xptr_tag<tiledb::Query>(query);
tiledb::QueryChannel channel = tiledb::QueryExperimental::get_default_channel(*query.get());
tiledb::ChannelOperation operation;
if (operator_name == "Sum") {
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::SumOperator>(*query.get(), attribute_name);
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::SumOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "Min") {
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MinOperator>(*query.get(), attribute_name);
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MinOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "Max") {
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MaxOperator>(*query.get(), attribute_name);
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MaxOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "Mean") {
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MeanOperator>(*query.get(), attribute_name);
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MeanOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "NullCount") {
operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::NullCountOperator>(*query.get(), attribute_name);
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::NullCountOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "Count") {
channel.apply_aggregate(operator_name, tiledb::CountOperation());
} else {
Rcpp::stop("Invalid aggregation operator '%s' specified.", operator_name.c_str());
}
channel.apply_aggregate(operator_name, operation);
std::vector<uint8_t> nulls = { 0 };
uint64_t size = 1;
if (operator_name != "NullCount") {
double result = 0;
query->set_data_buffer(operator_name, &result, size);
if (nullable) query->set_validity_buffer(operator_name, nulls);
query->submit();
return Rcpp::wrap(result);
} else {
if (operator_name == "NullCount" || operator_name == "Count") {
// Count and null count take uint64_t.
uint64_t result = 0;
query->set_data_buffer(operator_name, &result, size);
// no validity buffer for NullCount
if (nullable && operator_name != "NullCount") { // no validity buffer for NullCount
query->set_validity_buffer(operator_name, nulls);
}
query->submit();
return Rcpp::wrap(result);
} else if (operator_name == "Mean") {
// Mean always takes in a double.
return apply_unary_aggregate<double>(query, operator_name, nullable);
} else if (operator_name == "Sum") {
// Sum will take int64_t for signed integers, uint64_t for unsigned integers
// and double for floating point values.
tiledb::Context ctx = query->ctx();
auto arr = query->array();
auto sch = tiledb::ArraySchema(ctx, arr.uri());
auto attr = tiledb::Attribute(sch.attribute(attribute_name));
std::string type_name = _tiledb_datatype_to_string(attr.type());
if (type_name == "INT8" || type_name == "INT16" ||
type_name == "INT32" || type_name == "INT64") {
return apply_unary_aggregate<int64_t>(query, operator_name, nullable);
} else if (type_name == "UINT8" || type_name == "UINT16" ||
type_name == "UINT32" || type_name == "UINT64") {
return apply_unary_aggregate<uint64_t>(query, operator_name, nullable);
} else if (type_name == "FLOAT32" || type_name == "FLOAT64") {
return apply_unary_aggregate<double>(query, operator_name, nullable);
} else {
Rcpp::stop("'Sum' operator not valid for attribute '%s' of type '%s'",
attribute_name, type_name);
}
} else if (operator_name == "Min" || operator_name == "Max") {
// Min/max will take whatever the datatype of the column is.
tiledb::Context ctx = query->ctx();
auto arr = query->array();
auto sch = tiledb::ArraySchema(ctx, arr.uri());
auto attr = tiledb::Attribute(sch.attribute(attribute_name));
std::string type_name = _tiledb_datatype_to_string(attr.type());
switch (attr.type()) {
case TILEDB_INT8: return apply_unary_aggregate<int16_t>(query, operator_name, nullable); // int8_t bites char
case TILEDB_INT16: return apply_unary_aggregate<int16_t>(query, operator_name, nullable);
case TILEDB_INT32: return apply_unary_aggregate<int32_t>(query, operator_name, nullable);
case TILEDB_INT64: return apply_unary_aggregate<int64_t>(query, operator_name, nullable);
case TILEDB_UINT8: return apply_unary_aggregate<uint16_t>(query, operator_name, nullable); // uint8_t bites char
case TILEDB_UINT16: return apply_unary_aggregate<uint16_t>(query, operator_name, nullable);
case TILEDB_UINT32: return apply_unary_aggregate<uint32_t>(query, operator_name, nullable);
case TILEDB_UINT64: return apply_unary_aggregate<uint64_t>(query, operator_name, nullable);
case TILEDB_FLOAT32: return apply_unary_aggregate<float>(query, operator_name, nullable);
case TILEDB_FLOAT64: return apply_unary_aggregate<double>(query, operator_name, nullable);
default: Rcpp::stop("'%s' is not defined for attribute '%s' of type '%s'",
operator_name, attribute_name, type_name);
}
} else {
Rcpp::stop("'%s' is not implemented for '%s'", operator_name, attribute_name);
}
#else
return Rcpp::wrap(R_NaReal);
Expand Down

0 comments on commit f389063

Please sign in to comment.