diff --git a/NAMESPACE b/NAMESPACE
index 5e3e18ceb8..fde87fd1ae 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -83,6 +83,7 @@ export(tdb_select)
 export(tile)
 export(tile_order)
 export(tiledb_array)
+export(tiledb_array_apply_aggregate)
 export(tiledb_array_close)
 export(tiledb_array_create)
 export(tiledb_array_delete_fragments)
diff --git a/R/Array.R b/R/Array.R
index 675d924c62..316415dea2 100644
--- a/R/Array.R
+++ b/R/Array.R
@@ -177,3 +177,29 @@ tiledb_array_has_enumeration <- function(arr) {
     }
     return(libtiledb_array_has_enumeration_vector(ctx@ptr, arr@ptr))
 }
+
+##' Run an aggregate query on the given array and attribute
+##'
+##' @param array A TileDB Array object
+##' @param attrname The name of an attribute
+##' @param operation The name of aggregation operation
+##' @param nullable A boolean toggle whether the attribute is nullable
+##' @return The value of the aggregation
+##' @export
+tiledb_array_apply_aggregate <- function(array, attrname, operation, nullable = TRUE) {
+    stopifnot("The 'array' argument must be a TileDB Array object" = is(array, "tiledb_array"),
+              "The 'attrname' argument must be character" = is.character(attrname),
+              "The 'operation' argument must be character" = is.character(operation),
+              "The 'nullable' argument must be logical" = is.logical(nullable))
+    ## TODO: match.arg for operation
+
+    if (tiledb_array_is_open(array))
+        array <- tiledb_array_close(array)
+
+    query <- tiledb_query(array, "READ")
+
+    if (tiledb_query_get_layout(query) != "UNORDERED")
+        query <- tiledb_query_set_layout(query, "UNORDERED") # TODO: allow GLOBAL_ORDER too?
+
+    libtiledb_query_apply_aggregate(query@ptr, attrname, operation, nullable)
+}
diff --git a/R/Query.R b/R/Query.R
index 34a9d47fc3..7d56f0a617 100644
--- a/R/Query.R
+++ b/R/Query.R
@@ -1,6 +1,6 @@
 # MIT License
 #
-# Copyright (c) 2017-2022 TileDB Inc.
+# Copyright (c) 2017-2023 TileDB Inc.
#
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
diff --git a/man/tiledb_array_apply_aggregate.Rd b/man/tiledb_array_apply_aggregate.Rd
new file mode 100644
index 0000000000..3f9bd53e06
--- /dev/null
+++ b/man/tiledb_array_apply_aggregate.Rd
@@ -0,0 +1,23 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/Array.R
+\name{tiledb_array_apply_aggregate}
+\alias{tiledb_array_apply_aggregate}
+\title{Run an aggregate query on the given array and attribute}
+\usage{
+tiledb_array_apply_aggregate(array, attrname, operation, nullable = TRUE)
+}
+\arguments{
+\item{array}{A TileDB Array object}
+
+\item{attrname}{The name of an attribute}
+
+\item{operation}{The name of aggregation operation}
+
+\item{nullable}{A boolean toggle whether the attribute is nullable}
+}
+\value{
+The value of the aggregation
+}
+\description{
+Run an aggregate query on the given array and attribute
+}
diff --git a/src/libtiledb.cpp b/src/libtiledb.cpp
index 3f7c067c61..5459e664cf 100644
--- a/src/libtiledb.cpp
+++ b/src/libtiledb.cpp
@@ -3881,6 +3881,22 @@ XPtr libtiledb_query_get_ctx(XPtr query) {
   return make_xptr(new tiledb::Context(ctx));
 }
 
+template <typename T>  // T: C++ type of the single-element result buffer handed to the query
+SEXP apply_unary_aggregate(XPtr<tiledb::Query> query, std::string operator_name, bool nullable = false) {
+#if TILEDB_VERSION >= TileDB_Version(2,18,0)
+    T result = 0;
+    std::vector<uint8_t> nulls = { 0 };
+    uint64_t size = 1;  // the result buffer holds exactly one element
+    query->set_data_buffer(operator_name, &result, size);
+    if (nullable) query->set_validity_buffer(operator_name, nulls);  // NOTE(review): 'nulls' is never read back after submit()
+    query->submit();
+    SEXP res = Rcpp::wrap(result);
+    return res;
+#else
+    return Rcpp::wrap(R_NaReal);  // fallback when built against TileDB older than 2.18.0
+#endif
+}
+
 // [[Rcpp::export]]
 SEXP libtiledb_query_apply_aggregate(XPtr query,
                                      std::string attribute_name,
@@ -3889,35 +3905,83 @@ SEXP libtiledb_query_apply_aggregate(XPtr query,
 #if TILEDB_VERSION >= TileDB_Version(2,18,0)
     check_xptr_tag(query);
     tiledb::QueryChannel channel =
tiledb::QueryExperimental::get_default_channel(*query.get()); - tiledb::ChannelOperation operation; if (operator_name == "Sum") { - operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + channel.apply_aggregate(operator_name, operation); } else if (operator_name == "Min") { - operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + channel.apply_aggregate(operator_name, operation); } else if (operator_name == "Max") { - operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + channel.apply_aggregate(operator_name, operation); } else if (operator_name == "Mean") { - operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + channel.apply_aggregate(operator_name, operation); } else if (operator_name == "NullCount") { - operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + tiledb::ChannelOperation operation = tiledb::QueryExperimental::create_unary_aggregate(*query.get(), attribute_name); + channel.apply_aggregate(operator_name, operation); + } else if (operator_name == "Count") { + channel.apply_aggregate(operator_name, tiledb::CountOperation()); } else { Rcpp::stop("Invalid aggregation operator '%s' specified.", operator_name.c_str()); } - channel.apply_aggregate(operator_name, operation); std::vector nulls = { 0 }; uint64_t size = 1; - if (operator_name != "NullCount") { - double result = 0; - 
query->set_data_buffer(operator_name, &result, size); - if (nullable) query->set_validity_buffer(operator_name, nulls); - query->submit(); - return Rcpp::wrap(result); - } else { + if (operator_name == "NullCount" || operator_name == "Count") { + // Count and null count take uint64_t. uint64_t result = 0; query->set_data_buffer(operator_name, &result, size); - // no validity buffer for NullCount + if (nullable && operator_name != "NullCount") { // no validity buffer for NullCount + query->set_validity_buffer(operator_name, nulls); + } query->submit(); return Rcpp::wrap(result); + } else if (operator_name == "Mean") { + // Mean always takes in a double. + return apply_unary_aggregate(query, operator_name, nullable); + } else if (operator_name == "Sum") { + // Sum will take int64_t for signed integers, uint64_t for unsigned integers + // and double for floating point values. + tiledb::Context ctx = query->ctx(); + auto arr = query->array(); + auto sch = tiledb::ArraySchema(ctx, arr.uri()); + auto attr = tiledb::Attribute(sch.attribute(attribute_name)); + std::string type_name = _tiledb_datatype_to_string(attr.type()); + if (type_name == "INT8" || type_name == "INT16" || + type_name == "INT32" || type_name == "INT64") { + return apply_unary_aggregate(query, operator_name, nullable); + } else if (type_name == "UINT8" || type_name == "UINT16" || + type_name == "UINT32" || type_name == "UINT64") { + return apply_unary_aggregate(query, operator_name, nullable); + } else if (type_name == "FLOAT32" || type_name == "FLOAT64") { + return apply_unary_aggregate(query, operator_name, nullable); + } else { + Rcpp::stop("'Sum' operator not valid for attribute '%s' of type '%s'", + attribute_name, type_name); + } + } else if (operator_name == "Min" || operator_name == "Max") { + // Min/max will take whatever the datatype of the column is. 
+ tiledb::Context ctx = query->ctx(); + auto arr = query->array(); + auto sch = tiledb::ArraySchema(ctx, arr.uri()); + auto attr = tiledb::Attribute(sch.attribute(attribute_name)); + std::string type_name = _tiledb_datatype_to_string(attr.type()); + switch (attr.type()) { + case TILEDB_INT8: return apply_unary_aggregate(query, operator_name, nullable); // int8_t bites char + case TILEDB_INT16: return apply_unary_aggregate(query, operator_name, nullable); + case TILEDB_INT32: return apply_unary_aggregate(query, operator_name, nullable); + case TILEDB_INT64: return apply_unary_aggregate(query, operator_name, nullable); + case TILEDB_UINT8: return apply_unary_aggregate(query, operator_name, nullable); // uint8_t bites char + case TILEDB_UINT16: return apply_unary_aggregate(query, operator_name, nullable); + case TILEDB_UINT32: return apply_unary_aggregate(query, operator_name, nullable); + case TILEDB_UINT64: return apply_unary_aggregate(query, operator_name, nullable); + case TILEDB_FLOAT32: return apply_unary_aggregate(query, operator_name, nullable); + case TILEDB_FLOAT64: return apply_unary_aggregate(query, operator_name, nullable); + default: Rcpp::stop("'%s' is not defined for attribute '%s' of type '%s'", + operator_name, attribute_name, type_name); + } + } else { + Rcpp::stop("'%s' is not implemented for '%s'", operator_name, attribute_name); } #else return Rcpp::wrap(R_NaReal);