Skip to content

Commit

Permalink
Add support for aggregates (#623)
Browse files Browse the repository at this point in the history
* WIP for aggregates

* Extended aggregates support

* Expanded tiledb_array_apply_aggregate and updated help page

* Add unit tests

* Raise download.file timeout limit from 60s to 180s
  • Loading branch information
eddelbuettel authored Nov 29, 2023
1 parent 19846fd commit 20e0ca4
Show file tree
Hide file tree
Showing 12 changed files with 222 additions and 11 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ export(tdb_select)
export(tile)
export(tile_order)
export(tiledb_array)
export(tiledb_array_apply_aggregate)
export(tiledb_array_close)
export(tiledb_array_create)
export(tiledb_array_delete_fragments)
Expand Down
30 changes: 30 additions & 0 deletions R/Array.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,33 @@ tiledb_array_has_enumeration <- function(arr) {
}
return(libtiledb_array_has_enumeration_vector(ctx@ptr, arr@ptr))
}

##' Run an aggregate query on the given array and attribute
##'
##' The array is closed first if it is currently open, a new \sQuote{READ}
##' query is created on it, and the query layout is set to
##' \sQuote{UNORDERED} unless it already is \sQuote{UNORDERED} or
##' \sQuote{GLOBAL_ORDER}. The aggregate is then evaluated on that query.
##'
##' @param array A TileDB Array object
##' @param attrname The name of an attribute
##' @param operation The name of the aggregation operation, one of
##' \dQuote{Count}, \dQuote{NullCount}, \dQuote{Min}, \dQuote{Max},
##' \dQuote{Mean} or \dQuote{Sum}
##' @param nullable A boolean toggle whether the attribute is nullable
##' @return The value of the aggregation
##' @export
tiledb_array_apply_aggregate <- function(array, attrname,
                                         operation = c("Count", "NullCount", "Min", "Max",
                                                       "Mean", "Sum"),
                                         nullable = TRUE) {
    ## The first message used to refer to a 'query' argument which does not exist
    stopifnot("The 'array' argument must be a TileDB Array object" = is(array, "tiledb_array"),
              "The 'attrname' argument must be character" = is.character(attrname),
              "The 'operation' argument must be character" = is.character(operation),
              "The 'nullable' argument must be logical" = is.logical(nullable))

    ## Resolve to exactly one of the supported operator names (errors otherwise)
    operation <- match.arg(operation)

    ## Start from a closed array before creating the read query
    if (tiledb_array_is_open(array)) {
        array <- tiledb_array_close(array)
    }

    query <- tiledb_query(array, "READ")

    ## Any layout other than these two is replaced by UNORDERED
    if (! tiledb_query_get_layout(query) %in% c("UNORDERED", "GLOBAL_ORDER")) {
        query <- tiledb_query_set_layout(query, "UNORDERED")
    }

    libtiledb_query_apply_aggregate(query@ptr, attrname, operation, nullable)
}
2 changes: 1 addition & 1 deletion R/Query.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# MIT License
#
# Copyright (c) 2017-2022 TileDB Inc.
# Copyright (c) 2017-2023 TileDB Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
Expand Down
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,10 @@ libtiledb_query_get_ctx <- function(query) {
.Call(`_tiledb_libtiledb_query_get_ctx`, query)
}

# Thin R-level binding: passes its four arguments unchanged through .Call()
# to the compiled routine `_tiledb_libtiledb_query_apply_aggregate` and
# returns whatever that routine returns. `nullable` defaults to FALSE here.
libtiledb_query_apply_aggregate <- function(query, attribute_name, operator_name, nullable = FALSE) {
.Call(`_tiledb_libtiledb_query_apply_aggregate`, query, attribute_name, operator_name, nullable)
}

# Thin R-level binding: passes `ctx` through .Call() to the compiled routine
# `_tiledb_libtiledb_query_condition` and returns its result.
libtiledb_query_condition <- function(ctx) {
.Call(`_tiledb_libtiledb_query_condition`, ctx)
}
Expand Down
10 changes: 3 additions & 7 deletions inst/include/tiledb.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,9 @@

#include <tiledb/tiledb>
#if TILEDB_VERSION_MAJOR == 2 && TILEDB_VERSION_MINOR >= 4
#include <tiledb/tiledb_experimental>
#endif
#if TILEDB_VERSION_MAJOR == 2 && TILEDB_VERSION_MINOR >= 17
#include <tiledb/array_experimental.h>
#include <tiledb/attribute_experimental.h>
#include <tiledb/enumeration_experimental.h>
#include <tiledb/query_condition_experimental.h>
// this header includes the other experimental headers
// condition on the appropriate version is still done in each function
#include <tiledb/tiledb_experimental>
#endif

// Use the 'finalizer on exit' toggle in the XPtr template to ensure
Expand Down
31 changes: 31 additions & 0 deletions inst/tinytest/test_aggregates.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Unit tests for tiledb_array_apply_aggregate(), run against an array written
# from the 'penguins' data frame.
library(tinytest)
library(tiledb)

# Skip the whole file when the TileDB library is too old or the optional
# data package is not installed.
if (tiledb_version(TRUE) < "2.18.0") exit_file("Needs TileDB 2.18.0 or later")
if (!requireNamespace("palmerpenguins", quietly=TRUE)) exit_file("Remainder needs 'palmerpenguins'")

# NOTE(review): presumably caps the number of cores TileDB uses during the
# tests -- confirm against limitTileDBCores()
tiledb_ctx(limitTileDBCores())

# Write the data frame to a sparse array at a temporary location and get a
# handle on it.
library(palmerpenguins)
uri <- tempfile()
expect_silent(fromDataFrame(penguins, uri, sparse=TRUE))

expect_silent(arr <- tiledb_array(uri, extended=FALSE))

# Invalid inputs must raise an error.
expect_error(tiledb_array_apply_aggregate(uri, "body_mass_g", "Mean")) # not an array
expect_error(tiledb_array_apply_aggregate(arr, "does_not_exit", "Mean")) # not an attribute
expect_error(tiledb_array_apply_aggregate(arr, "body_mass_g", "UnknownFunction")) # not an operator

# 'body_mass_g': apart from Count, the calls rely on the default nullable=TRUE;
# a NullCount of 2 is expected.
expect_equal(tiledb_array_apply_aggregate(arr, "body_mass_g", "Count", FALSE), 344)
expect_equal(tiledb_array_apply_aggregate(arr, "body_mass_g", "NullCount"), 2)
expect_equal(tiledb_array_apply_aggregate(arr, "body_mass_g", "Min"), 2700)
expect_equal(tiledb_array_apply_aggregate(arr, "body_mass_g", "Max"), 6300)
expect_equal(tiledb_array_apply_aggregate(arr, "body_mass_g", "Sum"), 1437000)
expect_equal(tiledb_array_apply_aggregate(arr, "body_mass_g", "Mean"), 4201.7543869)

# 'year': every successful call passes nullable=FALSE explicitly, and
# NullCount is expected to fail.
expect_equal(tiledb_array_apply_aggregate(arr, "year", "Count", FALSE), 344)
expect_error(tiledb_array_apply_aggregate(arr, "year", "NullCount")) # no nullcount on non-nullable
expect_equal(tiledb_array_apply_aggregate(arr, "year", "Min", FALSE), 2007)
expect_equal(tiledb_array_apply_aggregate(arr, "year", "Max", FALSE), 2009)
expect_equal(tiledb_array_apply_aggregate(arr, "year", "Sum", FALSE), 690762)
expect_equal(tiledb_array_apply_aggregate(arr, "year", "Mean", FALSE), 2008.02906977)
28 changes: 28 additions & 0 deletions man/tiledb_array_apply_aggregate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2415,6 +2415,20 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// libtiledb_query_apply_aggregate
// Forward declaration of the implementation, followed by the C entry point
// reached from R via .Call(). The entry point converts each incoming SEXP to
// the C++ parameter type, calls the implementation, and hands the wrapped
// result back to R.
SEXP libtiledb_query_apply_aggregate(XPtr<tiledb::Query> query, std::string attribute_name, std::string operator_name, bool nullable);
RcppExport SEXP _tiledb_libtiledb_query_apply_aggregate(SEXP querySEXP, SEXP attribute_nameSEXP, SEXP operator_nameSEXP, SEXP nullableSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
// Argument conversion: one input_parameter per SEXP, in signature order
Rcpp::traits::input_parameter< XPtr<tiledb::Query> >::type query(querySEXP);
Rcpp::traits::input_parameter< std::string >::type attribute_name(attribute_nameSEXP);
Rcpp::traits::input_parameter< std::string >::type operator_name(operator_nameSEXP);
Rcpp::traits::input_parameter< bool >::type nullable(nullableSEXP);
rcpp_result_gen = Rcpp::wrap(libtiledb_query_apply_aggregate(query, attribute_name, operator_name, nullable));
return rcpp_result_gen;
END_RCPP
}
// libtiledb_query_condition
XPtr<tiledb::QueryCondition> libtiledb_query_condition(XPtr<tiledb::Context> ctx);
RcppExport SEXP _tiledb_libtiledb_query_condition(SEXP ctxSEXP) {
Expand Down Expand Up @@ -3758,6 +3772,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_tiledb_libtiledb_query_get_schema", (DL_FUNC) &_tiledb_libtiledb_query_get_schema, 2},
{"_tiledb_libtiledb_query_stats", (DL_FUNC) &_tiledb_libtiledb_query_stats, 1},
{"_tiledb_libtiledb_query_get_ctx", (DL_FUNC) &_tiledb_libtiledb_query_get_ctx, 1},
{"_tiledb_libtiledb_query_apply_aggregate", (DL_FUNC) &_tiledb_libtiledb_query_apply_aggregate, 4},
{"_tiledb_libtiledb_query_condition", (DL_FUNC) &_tiledb_libtiledb_query_condition, 1},
{"_tiledb_libtiledb_query_condition_init", (DL_FUNC) &_tiledb_libtiledb_query_condition_init, 5},
{"_tiledb_libtiledb_query_condition_combine", (DL_FUNC) &_tiledb_libtiledb_query_condition_combine, 3},
Expand Down
106 changes: 106 additions & 0 deletions src/libtiledb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3881,6 +3881,112 @@ XPtr<tiledb::Context> libtiledb_query_get_ctx(XPtr<tiledb::Query> query) {
return make_xptr<tiledb::Context>(new tiledb::Context(ctx));
}

// Bind a one-element result buffer of type T to the query field named
// `operator_name`, submit the query, and return the value wrapped as a SEXP.
// When `nullable` is set, a one-element validity buffer is bound to the same
// field before submission. Builds against TileDB older than 2.18.0 return a
// wrapped NA_real_ instead.
template <typename T>
SEXP apply_unary_aggregate(XPtr<tiledb::Query> query, std::string operator_name, bool nullable = false) {
#if TILEDB_VERSION >= TileDB_Version(2,18,0)
    T value = 0;                              // receives the aggregate result
    std::vector<uint8_t> validity = { 0 };    // receives the validity flag, if requested
    const uint64_t nelem = 1;                 // both buffers hold exactly one element

    query->set_data_buffer(operator_name, &value, nelem);
    if (nullable) {
        query->set_validity_buffer(operator_name, validity);
    }
    query->submit();

    return Rcpp::wrap(value);
#else
    return Rcpp::wrap(R_NaReal);
#endif
}

// Apply one aggregate operation to `query` and return its single result.
//
// query           external pointer to the query to run the aggregate on
// attribute_name  attribute the aggregate is computed over (not used when
//                 creating the "Count" operation)
// operator_name   one of "Sum", "Min", "Max", "Mean", "NullCount", "Count";
//                 it also serves as the name of the aggregate field and of
//                 the result buffer; any other value raises an R error
// nullable        when true a validity buffer is bound alongside the result
//                 buffer (never for "NullCount")
//
// The function works in two steps: it first registers the aggregate on the
// query's default channel, then binds a one-element result buffer whose C++
// type depends on the operator (and, for Sum/Min/Max, on the attribute's
// datatype) and submits the query. Builds against TileDB older than 2.18.0
// return a wrapped NA_real_ instead.
// [[Rcpp::export]]
SEXP libtiledb_query_apply_aggregate(XPtr<tiledb::Query> query,
std::string attribute_name,
std::string operator_name,
bool nullable = false) {
#if TILEDB_VERSION >= TileDB_Version(2,18,0)
check_xptr_tag<tiledb::Query>(query);
// Step 1: register the requested aggregate on the default channel under
// the field name `operator_name`.
tiledb::QueryChannel channel = tiledb::QueryExperimental::get_default_channel(*query.get());
if (operator_name == "Sum") {
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::SumOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "Min") {
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MinOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "Max") {
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MaxOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "Mean") {
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::MeanOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "NullCount") {
tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate<tiledb::NullCountOperator>(*query.get(), attribute_name);
channel.apply_aggregate(operator_name, operation);
} else if (operator_name == "Count") {
// Count is not tied to an attribute, so no unary aggregate is created
channel.apply_aggregate(operator_name, tiledb::CountOperation());
} else {
Rcpp::stop("Invalid aggregation operator '%s' specified.", operator_name.c_str());
}
// Step 2: bind a result buffer of the appropriate type and submit.
// `nulls` and `size` are only used by the Count/NullCount branch below; the
// other branches delegate to apply_unary_aggregate<T>().
std::vector<uint8_t> nulls = { 0 };
uint64_t size = 1;
if (operator_name == "NullCount" || operator_name == "Count") {
// Count and null count take uint64_t.
uint64_t result = 0;
query->set_data_buffer(operator_name, &result, size);
if (nullable && operator_name != "NullCount") { // no validity buffer for NullCount
query->set_validity_buffer(operator_name, nulls);
}
query->submit();
return Rcpp::wrap(result);
} else if (operator_name == "Mean") {
// Mean always takes in a double.
return apply_unary_aggregate<double>(query, operator_name, nullable);
} else if (operator_name == "Sum") {
// Sum will take int64_t for signed integers, uint64_t for unsigned integers
// and double for floating point values.
// The schema is loaded from the array's URI to look up the attribute type.
tiledb::Context ctx = query->ctx();
auto arr = query->array();
auto sch = tiledb::ArraySchema(ctx, arr.uri());
auto attr = tiledb::Attribute(sch.attribute(attribute_name));
std::string type_name = _tiledb_datatype_to_string(attr.type());
if (type_name == "INT8" || type_name == "INT16" ||
type_name == "INT32" || type_name == "INT64") {
return apply_unary_aggregate<int64_t>(query, operator_name, nullable);
} else if (type_name == "UINT8" || type_name == "UINT16" ||
type_name == "UINT32" || type_name == "UINT64") {
return apply_unary_aggregate<uint64_t>(query, operator_name, nullable);
} else if (type_name == "FLOAT32" || type_name == "FLOAT64") {
return apply_unary_aggregate<double>(query, operator_name, nullable);
} else {
Rcpp::stop("'Sum' operator not valid for attribute '%s' of type '%s'",
attribute_name, type_name);
}
} else if (operator_name == "Min" || operator_name == "Max") {
// Min/max will take whatever the datatype of the column is.
// Same schema lookup as for Sum; `type_name` is only needed for the error
// message in the default case.
tiledb::Context ctx = query->ctx();
auto arr = query->array();
auto sch = tiledb::ArraySchema(ctx, arr.uri());
auto attr = tiledb::Attribute(sch.attribute(attribute_name));
std::string type_name = _tiledb_datatype_to_string(attr.type());
// NOTE(review): the two 8-bit cases bind a 16-bit buffer, which is wider
// than the attribute's datatype -- confirm TileDB accepts this buffer size
// and that negative INT8 results come back with the correct sign.
switch (attr.type()) {
case TILEDB_INT8: return apply_unary_aggregate<int16_t>(query, operator_name, nullable); // int8_t bites char
case TILEDB_INT16: return apply_unary_aggregate<int16_t>(query, operator_name, nullable);
case TILEDB_INT32: return apply_unary_aggregate<int32_t>(query, operator_name, nullable);
case TILEDB_INT64: return apply_unary_aggregate<int64_t>(query, operator_name, nullable);
case TILEDB_UINT8: return apply_unary_aggregate<uint16_t>(query, operator_name, nullable); // uint8_t bites char
case TILEDB_UINT16: return apply_unary_aggregate<uint16_t>(query, operator_name, nullable);
case TILEDB_UINT32: return apply_unary_aggregate<uint32_t>(query, operator_name, nullable);
case TILEDB_UINT64: return apply_unary_aggregate<uint64_t>(query, operator_name, nullable);
case TILEDB_FLOAT32: return apply_unary_aggregate<float>(query, operator_name, nullable);
case TILEDB_FLOAT64: return apply_unary_aggregate<double>(query, operator_name, nullable);
default: Rcpp::stop("'%s' is not defined for attribute '%s' of type '%s'",
operator_name, attribute_name, type_name);
}
} else {
// Not reachable in practice: every operator name that passes step 1 is
// handled by one of the branches above.
Rcpp::stop("'%s' is not implemented for '%s'", operator_name, attribute_name);
}
#else
return Rcpp::wrap(R_NaReal);
#endif
}

/**
* Query Condition
Expand Down
2 changes: 1 addition & 1 deletion tools/fetchTileDBLib.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ dlurl <- switch(osarg,
url = urlarg)
cat("downloading", dlurl, "\n")
op <- options()
options(timeout=60)
options(timeout=180)
download.file(dlurl, "tiledb.tar.gz", quiet=TRUE)
options(op)
2 changes: 1 addition & 1 deletion tools/fetchTileDBSrc.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ if (arg == "default") {
# Download the source tarball only when it is not already present in the
# current directory. The guard previously tested "tiledb.tar,gz" (comma
# instead of dot), a name that is never written, so the download ran every
# time.
if (!file.exists("tiledb.tar.gz")) {
    cat("Downloading", url, "\n")
    op <- options()
    options(timeout=180)        # be patient with slow connections
    download.file(url, "tiledb.tar.gz", quiet=TRUE)
    options(op)                 # restore the option values saved above
}
Expand Down
2 changes: 1 addition & 1 deletion tools/winlibs.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ver <- dcf[[1, "version"]]
if (!file.exists("../windows/rwinlib-tiledb/include/tiledb/tiledb.h")) {
if (getRversion() < "4") stop("This package requires Rtools40 or newer")
op <- options()
options(timeout=60) # CRAN request to have patient download settings
options(timeout=180) # CRAN request to have patient download settings
download.file(sprintf("https://github.com/TileDB-Inc/rwinlib-tiledb/archive/v%s.zip", ver), "lib.zip", quiet = TRUE)
options(op)
dir.create("../windows", showWarnings = FALSE)
Expand Down

0 comments on commit 20e0ca4

Please sign in to comment.