From 79d7cc7f2d0dd52da38563026d379c7afdee8c84 Mon Sep 17 00:00:00 2001 From: Cocoa Date: Tue, 12 Nov 2024 01:07:50 +0000 Subject: [PATCH 1/6] c_src: bump ADBC C source to ADBC 15 rc1 --- 3rd_party/apache-arrow-adbc/CHANGELOG.md | 60 + 3rd_party/apache-arrow-adbc/CONTRIBUTING.md | 6 +- 3rd_party/apache-arrow-adbc/README.md | 2 +- 3rd_party/apache-arrow-adbc/c/CMakeLists.txt | 4 - 3rd_party/apache-arrow-adbc/c/apidoc/Doxyfile | 13 +- .../c/cmake_modules/AdbcVersion.cmake | 2 +- .../c/driver/bigquery/bigquery_test.cc | 13 +- .../apache-arrow-adbc/c/driver/common/utils.c | 66 +- .../apache-arrow-adbc/c/driver/common/utils.h | 6 +- .../driver/flightsql/sqlite_flightsql_test.cc | 1 + .../c/driver/framework/CMakeLists.txt | 2 +- .../c/driver/framework/base_driver.h | 11 + .../c/driver/framework/catalog.cc | 328 ----- .../c/driver/framework/catalog.h | 162 --- .../c/driver/framework/connection.h | 6 +- .../c/driver/framework/meson.build | 2 +- .../c/driver/framework/objects.cc | 175 ++- .../c/driver/framework/objects.h | 120 +- .../c/driver/framework/statement.h | 3 +- .../c/driver/framework/status.h | 38 + .../c/driver/framework/utility.cc | 179 +++ .../c/driver/framework/utility.h | 73 ++ .../c/driver/postgresql/bind_stream.h | 554 +++------ .../c/driver/postgresql/connection.cc | 1068 +++++++---------- .../c/driver/postgresql/connection.h | 8 +- .../copy/postgres_copy_reader_test.cc | 2 +- .../copy/postgres_copy_writer_test.cc | 162 +-- .../c/driver/postgresql/copy/reader.h | 5 +- .../c/driver/postgresql/copy/writer.h | 10 +- .../c/driver/postgresql/database.cc | 269 +++-- .../c/driver/postgresql/database.h | 23 +- .../c/driver/postgresql/error.cc | 73 +- .../c/driver/postgresql/error.h | 73 +- .../c/driver/postgresql/postgres_type.h | 96 +- .../c/driver/postgresql/postgres_type_test.cc | 18 +- .../c/driver/postgresql/postgresql.cc | 32 +- .../c/driver/postgresql/postgresql_test.cc | 98 +- .../c/driver/postgresql/result_helper.cc | 89 +- .../c/driver/postgresql/result_helper.h | 86 +- .../c/driver/postgresql/result_reader.cc | 106 +- .../c/driver/postgresql/result_reader.h | 29 +- .../c/driver/postgresql/statement.cc | 296 ++--- .../c/driver/postgresql/statement.h | 24 +- .../c/driver/snowflake/snowflake_test.cc | 12 +- .../c/driver/sqlite/sqlite.cc | 62 +- .../c/driver/sqlite/sqlite_test.cc | 21 +- .../c/driver/sqlite/statement_reader.c | 19 +- .../c/driver/sqlite/statement_reader.h | 5 - .../c/driver_manager/CMakeLists.txt | 4 + .../c/driver_manager/adbc_driver_manager.cc | 63 +- .../adbc_driver_manager_test.cc | 18 +- .../c/include/arrow-adbc/adbc.h | 2 +- .../include/arrow-adbc/adbc_driver_manager.h | 5 + 3rd_party/apache-arrow-adbc/c/meson.build | 2 +- .../c/subprojects/nanoarrow.wrap | 10 +- .../c/validation/adbc_validation.h | 54 +- .../validation/adbc_validation_connection.cc | 35 +- .../c/validation/adbc_validation_statement.cc | 147 ++- .../c/validation/adbc_validation_util.cc | 60 +- .../c/validation/adbc_validation_util.h | 127 +- .../c/vendor/nanoarrow/nanoarrow.c | 701 +++++++++-- .../c/vendor/nanoarrow/nanoarrow.h | 600 ++++++++- .../c/vendor/nanoarrow/nanoarrow.hpp | 13 +- .../c/vendor/vendor_nanoarrow.sh | 22 +- 64 files changed, 3756 insertions(+), 2619 deletions(-) delete mode 100644 3rd_party/apache-arrow-adbc/c/driver/framework/catalog.cc delete mode 100644 3rd_party/apache-arrow-adbc/c/driver/framework/catalog.h create mode 100644 3rd_party/apache-arrow-adbc/c/driver/framework/utility.cc create mode 100644 3rd_party/apache-arrow-adbc/c/driver/framework/utility.h diff --git 
a/3rd_party/apache-arrow-adbc/CHANGELOG.md b/3rd_party/apache-arrow-adbc/CHANGELOG.md index 2805d51..eb3a26b 100644 --- a/3rd_party/apache-arrow-adbc/CHANGELOG.md +++ b/3rd_party/apache-arrow-adbc/CHANGELOG.md @@ -699,3 +699,63 @@ - **ci**: update website_build.sh for new versioning scheme (#1972) - **dev/release**: update C# tag (#1973) - **c/vendor/nanoarrow**: Fix -Wreorder warning (#1966) + +## ADBC Libraries 15 (2024-11-08) + +### Versions + +- C/C++/GLib/Go/Python/Ruby: 1.3.0 +- C#: 0.15.0 +- Java: 0.15.0 +- R: 0.15.0 +- Rust: 0.15.0 + +### Feat + +- **c/driver/postgresql**: Enable basic connect/query workflow for Redshift (#2219) +- **rust/drivers/datafusion**: add support for bulk ingest (#2279) +- **csharp/src/Drivers/Apache**: convert Double to Float for Apache Spark on scalar conversion (#2296) +- **go/adbc/driver/snowflake**: update to the latest 1.12.0 gosnowflake driver (#2298) +- **csharp/src/Drivers/BigQuery**: support max stream count setting when creating read session (#2289) +- **rust/drivers**: adbc driver for datafusion (#2267) +- **go/adbc/driver/snowflake**: improve GetObjects performance and semantics (#2254) +- **c**: Implement ingestion and testing for float16, string_view, and binary_view (#2234) +- **r**: Add R BigQuery driver wrapper (#2235) +- **csharp/src/Drivers/Apache/Spark**: add request_timeout_ms option to allow longer HTTP request length (#2218) +- **go/adbc/driver/snowflake**: add support for a client config file (#2197) +- **csharp/src/Client**: Additional parameter support for DbCommand (#2195) +- **csharp/src/Drivers/Apache/Spark**: add option to ignore TLS/SSL certificate exceptions (#2188) +- **csharp/src/Drivers/Apache/Spark**: Perform scalar data type conversion for Spark over HTTP (#2152) +- **csharp/src/Drivers/Apache/Spark**: Azure HDInsight Spark Documentation (#2164) +- **c/driver/postgresql**: Implement ingestion of list types for PostgreSQL (#2153) +- **csharp/src/Drivers/Apache/Spark**: poc - Support for Apache Spark over HTTP (non-Arrow) (#2018) +- **c/driver/postgresql**: add `arrow.opaque` type metadata (#2122) + +### Fix + +- **csharp/src/Drivers/Apache**: fix float data type handling for tests on Databricks Spark (#2283) +- **go/adbc/driver/internal/driverbase**: proper unmarshalling for ConstraintColumnNames (#2285) +- **csharp/src/Drivers/Apache**: fix to workaround concurrency issue (#2282) +- **csharp/src/Drivers/Apache**: correctly handle empty response and add Client tests (#2275) +- **csharp/src/Drivers/Apache**: remove interleaved async look-ahead code (#2273) +- **c/driver_manager**: More robust error reporting for errors that occur before AdbcDatabaseInit() (#2266) +- **rust**: implement database/connection constructors without options (#2242) +- **csharp/src/Drivers**: update System.Text.Json to version 8.0.5 because of known vulnerability (#2238) +- **csharp/src/Drivers/Apache/Spark**: correct batch handling for the HiveServer2Reader (#2215) +- **go/adbc/driver/snowflake**: call GetObjects with null catalog at catalog depth (#2194) +- **csharp/src/Drivers/Apache/Spark**: correct BatchSize implementation for base reader (#2199) +- **csharp/src/Drivers/Apache/Spark**: correct precision/scale handling with zeros in fractional portion (#2198) +- **csharp/src/Drivers/BigQuery**: Fixed GBQ driver issue when results.TableReference is null (#2165) +- **go/adbc/driver/snowflake**: fix setting database and schema context after initial connection (#2169) +- **csharp/src/Drivers/Interop/Snowflake**: add test to demonstrate 
DEFAULT_ROLE behavior (#2151) +- **c/driver/postgresql**: Improve error reporting for queries that error before the COPY header is sent (#2134) + +### Refactor + +- **c/driver/postgresql**: cleanups for result_helper signatures (#2261) +- **c/driver/postgresql**: Use GetObjectsHelper from framework to build objects (#2189) +- **csharp/src/Drivers/Apache/Spark**: use UTF8 string for data conversion, instead of .NET String (#2192) +- **c/driver/postgresql**: Use Status for error handling in BindStream (#2187) +- **c/driver/postgresql**: Use Status instead of AdbcStatusCode/AdbcError in result helper (#2178) +- **c/driver**: Use non-objects framework components in Postgres driver (#2166) +- **c/driver/postgresql**: Use copy writer in BindStream for parameter binding (#2157) diff --git a/3rd_party/apache-arrow-adbc/CONTRIBUTING.md b/3rd_party/apache-arrow-adbc/CONTRIBUTING.md index 771acd1..c9cbd0a 100644 --- a/3rd_party/apache-arrow-adbc/CONTRIBUTING.md +++ b/3rd_party/apache-arrow-adbc/CONTRIBUTING.md @@ -31,8 +31,8 @@ https://github.com/apache/arrow-adbc/issues Some dependencies are required to build and test the various ADBC packages. For C/C++, you will most likely want a [Conda][conda] installation, -with [Mambaforge][mambaforge] being the most convenient distribution. -If you have Mambaforge installed, you can set up a development +with [Miniforge][miniforge] being the most convenient distribution. +If you have Miniforge installed, you can set up a development environment as follows: ```shell @@ -52,7 +52,7 @@ CMake or other build tool appropriately. However, we primarily develop and support Conda users. [conda]: https://docs.conda.io/en/latest/ -[mambaforge]: https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html +[miniforge]: https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html ### Running Integration Tests diff --git a/3rd_party/apache-arrow-adbc/README.md b/3rd_party/apache-arrow-adbc/README.md index 2cf24ec..cb2b9c0 100644 --- a/3rd_party/apache-arrow-adbc/README.md +++ b/3rd_party/apache-arrow-adbc/README.md @@ -57,4 +57,4 @@ User documentation can be found at https://arrow.apache.org/adbc ## Development and Contributing -For detailed instructions on how to build the various ADBC libraries, see CONTRIBUTING.md. +For detailed instructions on how to build the various ADBC libraries, see [CONTRIBUTING.md](CONTRIBUTING.md). diff --git a/3rd_party/apache-arrow-adbc/c/CMakeLists.txt b/3rd_party/apache-arrow-adbc/c/CMakeLists.txt index f090aaf..be69103 100644 --- a/3rd_party/apache-arrow-adbc/c/CMakeLists.txt +++ b/3rd_party/apache-arrow-adbc/c/CMakeLists.txt @@ -35,10 +35,6 @@ add_subdirectory(vendor/nanoarrow) add_subdirectory(driver/common) add_subdirectory(driver/framework) -install(FILES "${REPOSITORY_ROOT}/c/include/adbc.h" DESTINATION include) -install(FILES "${REPOSITORY_ROOT}/c/include/arrow-adbc/adbc.h" - DESTINATION include/arrow-adbc) - if(ADBC_BUILD_TESTS) add_subdirectory(validation) endif() diff --git a/3rd_party/apache-arrow-adbc/c/apidoc/Doxyfile b/3rd_party/apache-arrow-adbc/c/apidoc/Doxyfile index 204d9a6..2f49242 100644 --- a/3rd_party/apache-arrow-adbc/c/apidoc/Doxyfile +++ b/3rd_party/apache-arrow-adbc/c/apidoc/Doxyfile @@ -500,7 +500,7 @@ EXTRACT_ALL = NO # be included in the documentation. # The default value is: NO. -EXTRACT_PRIVATE = NO +EXTRACT_PRIVATE = YES # If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual # methods of a class will be included in the documentation. 
@@ -891,7 +891,7 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../../c/include/arrow-adbc/adbc.h ../../README.md ../../c/include/arrow-adbc/adbc_driver_manager.h +INPUT = ../../c/include/arrow-adbc/adbc.h ../../README.md ../../c/include/arrow-adbc/adbc_driver_manager.h ../../c/driver/framework/ # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses @@ -920,12 +920,7 @@ INPUT_ENCODING = UTF-8 # comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, # *.vhdl, *.ucf, *.qsf and *.ice. -FILE_PATTERNS = *.c \ - *.cc \ - *.cxx \ - *.cpp \ - *.c++ \ - *.java \ +FILE_PATTERNS = *.java \ *.ii \ *.ixx \ *.ipp \ @@ -1007,7 +1002,7 @@ EXCLUDE_PATTERNS = # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories use the pattern */test/* -EXCLUDE_SYMBOLS = +EXCLUDE_SYMBOLS = ADBC ADBC_DRIVER_MANAGER_H # The EXAMPLE_PATH tag can be used to specify one or more files or directories # that contain example code fragments that are included (see the \include diff --git a/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcVersion.cmake b/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcVersion.cmake index dfa1a84..a89cab1 100644 --- a/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcVersion.cmake +++ b/3rd_party/apache-arrow-adbc/c/cmake_modules/AdbcVersion.cmake @@ -21,7 +21,7 @@ # ------------------------------------------------------------ # Version definitions -set(ADBC_VERSION "1.2.0") +set(ADBC_VERSION "1.3.0") string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ADBC_BASE_VERSION "${ADBC_VERSION}") string(REPLACE "." ";" _adbc_version_list "${ADBC_BASE_VERSION}") list(GET _adbc_version_list 0 ADBC_VERSION_MAJOR) diff --git a/3rd_party/apache-arrow-adbc/c/driver/bigquery/bigquery_test.cc b/3rd_party/apache-arrow-adbc/c/driver/bigquery/bigquery_test.cc index ae4ced8..b80f363 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/bigquery/bigquery_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/bigquery/bigquery_test.cc @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. +#include +#include +#include +#include + #include #include #include #include #include #include -#include -#include -#include + #include "validation/adbc_validation.h" #include "validation/adbc_validation_util.h" @@ -120,7 +123,9 @@ class BigQueryQuirks : public adbc_validation::DriverQuirks { create += "` (int64s INT, strings TEXT)"; CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error)); CHECK_OK(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, error)); - sleep(5); + // XXX: is there a better way to wait for BigQuery? (Why does 'CREATE + // TABLE' not wait for commit?) 
+ std::this_thread::sleep_for(std::chrono::seconds(5)); std::string insert = "INSERT INTO `ADBC_TESTING."; insert += name; diff --git a/3rd_party/apache-arrow-adbc/c/driver/common/utils.c b/3rd_party/apache-arrow-adbc/c/driver/common/utils.c index 6daac65..00ebd51 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/common/utils.c +++ b/3rd_party/apache-arrow-adbc/c/driver/common/utils.c @@ -235,70 +235,8 @@ struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, int i }; } -struct SingleBatchArrayStream { - struct ArrowSchema schema; - struct ArrowArray batch; -}; -static const char* SingleBatchArrayStreamGetLastError(struct ArrowArrayStream* stream) { - (void)stream; - return NULL; -} -static int SingleBatchArrayStreamGetNext(struct ArrowArrayStream* stream, - struct ArrowArray* batch) { - if (!stream || !stream->private_data) return EINVAL; - struct SingleBatchArrayStream* impl = - (struct SingleBatchArrayStream*)stream->private_data; - - memcpy(batch, &impl->batch, sizeof(*batch)); - memset(&impl->batch, 0, sizeof(*batch)); - return 0; -} -static int SingleBatchArrayStreamGetSchema(struct ArrowArrayStream* stream, - struct ArrowSchema* schema) { - if (!stream || !stream->private_data) return EINVAL; - struct SingleBatchArrayStream* impl = - (struct SingleBatchArrayStream*)stream->private_data; - - return ArrowSchemaDeepCopy(&impl->schema, schema); -} -static void SingleBatchArrayStreamRelease(struct ArrowArrayStream* stream) { - if (!stream || !stream->private_data) return; - struct SingleBatchArrayStream* impl = - (struct SingleBatchArrayStream*)stream->private_data; - impl->schema.release(&impl->schema); - if (impl->batch.release) impl->batch.release(&impl->batch); - free(impl); - - memset(stream, 0, sizeof(*stream)); -} - -AdbcStatusCode BatchToArrayStream(struct ArrowArray* values, struct ArrowSchema* schema, - struct ArrowArrayStream* stream, - struct AdbcError* error) { - if (!values->release) { - SetError(error, "ArrowArray is not initialized"); - return ADBC_STATUS_INTERNAL; - } else if (!schema->release) { - SetError(error, "ArrowSchema is not initialized"); - return ADBC_STATUS_INTERNAL; - } else if (stream->release) { - SetError(error, "ArrowArrayStream is already initialized"); - return ADBC_STATUS_INTERNAL; - } - - struct SingleBatchArrayStream* impl = - (struct SingleBatchArrayStream*)malloc(sizeof(*impl)); - memcpy(&impl->schema, schema, sizeof(*schema)); - memcpy(&impl->batch, values, sizeof(*values)); - memset(schema, 0, sizeof(*schema)); - memset(values, 0, sizeof(*values)); - stream->private_data = impl; - stream->get_last_error = SingleBatchArrayStreamGetLastError; - stream->get_next = SingleBatchArrayStreamGetNext; - stream->get_schema = SingleBatchArrayStreamGetSchema; - stream->release = SingleBatchArrayStreamRelease; - - return ADBC_STATUS_OK; +bool IsCommonError(const struct AdbcError* error) { + return error->release == ReleaseErrorWithDetails || error->release == ReleaseError; } int StringBuilderInit(struct StringBuilder* builder, size_t initial_size) { diff --git a/3rd_party/apache-arrow-adbc/c/driver/common/utils.h b/3rd_party/apache-arrow-adbc/c/driver/common/utils.h index c61ecb0..d204821 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/common/utils.h +++ b/3rd_party/apache-arrow-adbc/c/driver/common/utils.h @@ -53,6 +53,7 @@ void AppendErrorDetail(struct AdbcError* error, const char* key, const uint8_t* int CommonErrorGetDetailCount(const struct AdbcError* error); struct AdbcErrorDetail CommonErrorGetDetail(const struct AdbcError* error, 
int index); +bool IsCommonError(const struct AdbcError* error); struct StringBuilder { char* buffer; @@ -68,11 +69,6 @@ void StringBuilderReset(struct StringBuilder* builder); #undef ADBC_CHECK_PRINTF_ATTRIBUTE -/// Wrap a single batch as a stream. -AdbcStatusCode BatchToArrayStream(struct ArrowArray* values, struct ArrowSchema* schema, - struct ArrowArrayStream* stream, - struct AdbcError* error); - /// Check an NanoArrow status code. #define CHECK_NA(CODE, EXPR, ERROR) \ do { \ diff --git a/3rd_party/apache-arrow-adbc/c/driver/flightsql/sqlite_flightsql_test.cc b/3rd_party/apache-arrow-adbc/c/driver/flightsql/sqlite_flightsql_test.cc index 454ea02..4797d58 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/flightsql/sqlite_flightsql_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/flightsql/sqlite_flightsql_test.cc @@ -121,6 +121,7 @@ class SqliteFlightSqlQuirks : public adbc_validation::DriverQuirks { bool supports_get_objects() const override { return true; } bool supports_partitioned_data() const override { return true; } bool supports_dynamic_parameter_binding() const override { return true; } + std::string catalog() const override { return "main"; } }; class SqliteFlightSqlTest : public ::testing::Test, public adbc_validation::DatabaseTest { diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/CMakeLists.txt b/3rd_party/apache-arrow-adbc/c/driver/framework/CMakeLists.txt index 3efc3f1..f5c642b 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/CMakeLists.txt +++ b/3rd_party/apache-arrow-adbc/c/driver/framework/CMakeLists.txt @@ -17,7 +17,7 @@ include(FetchContent) -add_library(adbc_driver_framework STATIC catalog.cc objects.cc) +add_library(adbc_driver_framework STATIC objects.cc utility.cc) adbc_configure_target(adbc_driver_framework) set_target_properties(adbc_driver_framework PROPERTIES POSITION_INDEPENDENT_CODE ON) target_include_directories(adbc_driver_framework diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/base_driver.h b/3rd_party/apache-arrow-adbc/c/driver/framework/base_driver.h index b52e474..eecb506 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/base_driver.h +++ b/3rd_party/apache-arrow-adbc/c/driver/framework/base_driver.h @@ -455,11 +455,22 @@ class Driver { } auto error_obj = reinterpret_cast(error->private_data); + if (!error_obj) { + return 0; + } return error_obj->CDetailCount(); } static AdbcErrorDetail CErrorGetDetail(const AdbcError* error, int index) { + if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + return {nullptr, nullptr, 0}; + } + auto error_obj = reinterpret_cast(error->private_data); + if (!error_obj) { + return {nullptr, nullptr, 0}; + } + return error_obj->CDetail(index); } diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/catalog.cc b/3rd_party/apache-arrow-adbc/c/driver/framework/catalog.cc deleted file mode 100644 index d5b89ea..0000000 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/catalog.cc +++ /dev/null @@ -1,328 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "driver/framework/catalog.h" - -#include - -namespace adbc::driver { - -void AdbcMakeArrayStream(struct ArrowSchema* schema, struct ArrowArray* array, - struct ArrowArrayStream* out) { - nanoarrow::VectorArrayStream(schema, array).ToArrayStream(out); -} - -Status AdbcInitConnectionGetInfoSchema(struct ArrowSchema* schema, - struct ArrowArray* array) { - ArrowSchemaInit(schema); - UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); - - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_UINT32)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "info_name")); - schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - - struct ArrowSchema* info_value = schema->children[1]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeUnion(info_value, NANOARROW_TYPE_DENSE_UNION, 6)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value, "info_value")); - - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[0], "string_value")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[1], NANOARROW_TYPE_BOOL)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[1], "bool_value")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[2], NANOARROW_TYPE_INT64)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[2], "int64_value")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[3], NANOARROW_TYPE_INT32)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[3], "int32_bitmask")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[4], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[4], "string_list")); - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(info_value->children[5], NANOARROW_TYPE_MAP)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(info_value->children[5], "int32_to_int32_list_map")); - - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(info_value->children[4]->children[0], - NANOARROW_TYPE_STRING)); - - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[5]->children[0]->children[0], - NANOARROW_TYPE_INT32)); - info_value->children[5]->children[0]->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(info_value->children[5]->children[0]->children[1], - NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO( - Internal, - ArrowSchemaSetType(info_value->children[5]->children[0]->children[1]->children[0], - NANOARROW_TYPE_INT32)); - - struct ArrowError na_error = {0}; - UNWRAP_NANOARROW(na_error, Internal, - ArrowArrayInitFromSchema(array, schema, &na_error)); - UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array)); - - return status::Ok(); -} - -Status AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, - std::string_view info_value) { - UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); - // Append to type variant - struct ArrowStringView value; - value.data = info_value.data(); - 
value.size_bytes = static_cast(info_value.size()); - UNWRAP_ERRNO(Internal, ArrowArrayAppendString(array->children[1]->children[0], value)); - // Append type code/offset - UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/0)); - return status::Ok(); -} - -Status AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, uint32_t info_code, - int64_t info_value) { - UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); - // Append to type variant - UNWRAP_ERRNO(Internal, - ArrowArrayAppendInt(array->children[1]->children[2], info_value)); - // Append type code/offset - UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2)); - return status::Ok(); -} - -Status AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema) { - ArrowSchemaInit(schema); - UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "catalog_name")); - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[1], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[1], "catalog_db_schemas")); - UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema->children[1]->children[0], 2)); - - struct ArrowSchema* db_schema_schema = schema->children[1]->children[0]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(db_schema_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(db_schema_schema->children[0], "db_schema_name")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(db_schema_schema->children[1], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(db_schema_schema->children[1], "db_schema_tables")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeStruct(db_schema_schema->children[1]->children[0], 4)); - - struct ArrowSchema* table_schema = db_schema_schema->children[1]->children[0]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(table_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[0], "table_name")); - table_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(table_schema->children[1], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[1], "table_type")); - table_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(table_schema->children[2], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[2], "table_columns")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeStruct(table_schema->children[2]->children[0], 19)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(table_schema->children[3], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(table_schema->children[3], "table_constraints")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeStruct(table_schema->children[3]->children[0], 4)); - - struct ArrowSchema* column_schema = table_schema->children[2]->children[0]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[0], "column_name")); - column_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[1], NANOARROW_TYPE_INT32)); - UNWRAP_ERRNO(Internal, - 
ArrowSchemaSetName(column_schema->children[1], "ordinal_position")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[2], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[2], "remarks")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[3], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[3], "xdbc_data_type")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[4], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[4], "xdbc_type_name")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[5], NANOARROW_TYPE_INT32)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[5], "xdbc_column_size")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[6], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[6], "xdbc_decimal_digits")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[7], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[7], "xdbc_num_prec_radix")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[8], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[8], "xdbc_nullable")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[9], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[9], "xdbc_column_def")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[10], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[10], "xdbc_sql_data_type")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[11], NANOARROW_TYPE_INT16)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[11], "xdbc_datetime_sub")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[12], NANOARROW_TYPE_INT32)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[12], "xdbc_char_octet_length")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[13], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[13], "xdbc_is_nullable")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[14], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[14], "xdbc_scope_catalog")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[15], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[15], "xdbc_scope_schema")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[16], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[16], "xdbc_scope_table")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[17], NANOARROW_TYPE_BOOL)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(column_schema->children[17], "xdbc_is_autoincrement")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(column_schema->children[18], NANOARROW_TYPE_BOOL)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[18], - "xdbc_is_generatedcolumn")); - - struct ArrowSchema* constraint_schema = table_schema->children[3]->children[0]; - UNWRAP_ERRNO(Internal, - 
ArrowSchemaSetType(constraint_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(constraint_schema->children[0], "constraint_name")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(constraint_schema->children[1], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetName(constraint_schema->children[1], "constraint_type")); - constraint_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(constraint_schema->children[2], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[2], - "constraint_column_names")); - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(constraint_schema->children[2]->children[0], - NANOARROW_TYPE_STRING)); - constraint_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(constraint_schema->children[3], NANOARROW_TYPE_LIST)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[3], - "constraint_column_usage")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetTypeStruct(constraint_schema->children[3]->children[0], 4)); - - struct ArrowSchema* usage_schema = constraint_schema->children[3]->children[0]; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(usage_schema->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[0], "fk_catalog")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(usage_schema->children[1], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[1], "fk_db_schema")); - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(usage_schema->children[2], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[2], "fk_table")); - usage_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; - UNWRAP_ERRNO(Internal, - ArrowSchemaSetType(usage_schema->children[3], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[3], "fk_column_name")); - usage_schema->children[3]->flags &= ~ARROW_FLAG_NULLABLE; - - return status::Ok(); -} - -Status AdbcGetInfo(std::vector infos, struct ArrowArrayStream* out) { - nanoarrow::UniqueSchema schema; - nanoarrow::UniqueArray array; - - UNWRAP_STATUS(AdbcInitConnectionGetInfoSchema(schema.get(), array.get())); - - for (const auto& info : infos) { - UNWRAP_STATUS(std::visit( - [&](auto&& value) -> Status { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return AdbcConnectionGetInfoAppendString(array.get(), info.code, value); - } else if constexpr (std::is_same_v) { - return AdbcConnectionGetInfoAppendInt(array.get(), info.code, value); - } else { - static_assert(!sizeof(T), "info value type not implemented"); - } - return status::Ok(); - }, - info.value)); - UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(array.get())); - } - - struct ArrowError na_error = {0}; - UNWRAP_NANOARROW(na_error, Internal, - ArrowArrayFinishBuildingDefault(array.get(), &na_error)); - nanoarrow::VectorArrayStream(schema.get(), array.get()).ToArrayStream(out); - return status::Ok(); -} - -Status AdbcGetTableTypes(const std::vector& table_types, - struct ArrowArrayStream* out) { - nanoarrow::UniqueArray array; - nanoarrow::UniqueSchema schema; - ArrowSchemaInit(schema.get()); - - UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema.get(), NANOARROW_TYPE_STRUCT)); - UNWRAP_ERRNO(Internal, ArrowSchemaAllocateChildren(schema.get(), /*num_columns=*/1)); - ArrowSchemaInit(schema.get()->children[0]); - UNWRAP_ERRNO(Internal, - 
ArrowSchemaSetType(schema.get()->children[0], NANOARROW_TYPE_STRING)); - UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema.get()->children[0], "table_type")); - schema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - - UNWRAP_ERRNO(Internal, ArrowArrayInitFromSchema(array.get(), schema.get(), NULL)); - UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array.get())); - - for (std::string const& table_type : table_types) { - UNWRAP_ERRNO(Internal, ArrowArrayAppendString(array->children[0], - ArrowCharView(table_type.c_str()))); - UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(array.get())); - } - - UNWRAP_ERRNO(Internal, ArrowArrayFinishBuildingDefault(array.get(), nullptr)); - nanoarrow::VectorArrayStream(schema.get(), array.get()).ToArrayStream(out); - return status::Ok(); -} - -} // namespace adbc::driver diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/catalog.h b/3rd_party/apache-arrow-adbc/c/driver/framework/catalog.h deleted file mode 100644 index 8c0eff1..0000000 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/catalog.h +++ /dev/null @@ -1,162 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "driver/framework/status.h" - -namespace adbc::driver { - -/// \defgroup adbc-framework-catalog Catalog Utilities -/// Utilities for implementing catalog/metadata-related functions. -/// -/// @{ - -/// \brief The GetObjects level. -enum class GetObjectsDepth { - kCatalogs, - kSchemas, - kTables, - kColumns, -}; - -/// \brief Helper to implement GetObjects. 
-struct GetObjectsHelper { - virtual ~GetObjectsHelper() = default; - - struct Table { - std::string_view name; - std::string_view type; - }; - - struct ColumnXdbc { - std::optional xdbc_data_type; - std::optional xdbc_type_name; - std::optional xdbc_column_size; - std::optional xdbc_decimal_digits; - std::optional xdbc_num_prec_radix; - std::optional xdbc_nullable; - std::optional xdbc_column_def; - std::optional xdbc_sql_data_type; - std::optional xdbc_datetime_sub; - std::optional xdbc_char_octet_length; - std::optional xdbc_is_nullable; - std::optional xdbc_scope_catalog; - std::optional xdbc_scope_schema; - std::optional xdbc_scope_table; - std::optional xdbc_is_autoincrement; - std::optional xdbc_is_generatedcolumn; - }; - - struct Column { - std::string_view column_name; - int32_t ordinal_position; - std::optional remarks; - std::optional xdbc; - }; - - struct ConstraintUsage { - std::optional catalog; - std::optional schema; - std::string_view table; - std::string_view column; - }; - - struct Constraint { - std::optional name; - std::string_view type; - std::vector column_names; - std::optional> usage; - }; - - Status Close() { return status::Ok(); } - - /// \brief Fetch all metadata needed. The driver is free to delay loading - /// but this gives it a chance to load data up front. - virtual Status Load(GetObjectsDepth depth, - std::optional catalog_filter, - std::optional schema_filter, - std::optional table_filter, - std::optional column_filter, - const std::vector& table_types) { - return status::NotImplemented("GetObjects"); - } - - virtual Status LoadCatalogs() { - return status::NotImplemented("GetObjects at depth = catalog"); - }; - - virtual Result> NextCatalog() { return std::nullopt; } - - virtual Status LoadSchemas(std::string_view catalog) { - return status::NotImplemented("GetObjects at depth = schema"); - }; - - virtual Result> NextSchema() { return std::nullopt; } - - virtual Status LoadTables(std::string_view catalog, std::string_view schema) { - return status::NotImplemented("GetObjects at depth = table"); - }; - - virtual Result> NextTable() { return std::nullopt; } - - virtual Status LoadColumns(std::string_view catalog, std::string_view schema, - std::string_view table) { - return status::NotImplemented("GetObjects at depth = column"); - }; - - virtual Result> NextColumn() { return std::nullopt; } - - virtual Result> NextConstraint() { return std::nullopt; } -}; - -struct InfoValue { - uint32_t code; - std::variant value; - - explicit InfoValue(uint32_t code, std::variant value) - : code(code), value(std::move(value)) {} -}; - -void AdbcMakeArrayStream(struct ArrowSchema* schema, struct ArrowArray* array, - struct ArrowArrayStream* out); - -Status AdbcGetInfo(std::vector infos, struct ArrowArrayStream* out); - -Status AdbcGetTableTypes(const std::vector& table_types, - struct ArrowArrayStream* out); - -Status AdbcInitConnectionGetInfoSchema(struct ArrowSchema* schema, - struct ArrowArray* array); -Status AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, - std::string_view info_value); -Status AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, uint32_t info_code, - int64_t info_value); -Status AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema); -/// @} - -} // namespace adbc::driver diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/connection.h b/3rd_party/apache-arrow-adbc/c/driver/framework/connection.h index f9e329e..da3aae1 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/connection.h +++ 
b/3rd_party/apache-arrow-adbc/c/driver/framework/connection.h @@ -27,8 +27,8 @@ #include #include "driver/framework/base_driver.h" -#include "driver/framework/catalog.h" #include "driver/framework/objects.h" +#include "driver/framework/utility.h" namespace adbc::driver { /// \brief The CRTP base implementation of an AdbcConnection. @@ -86,7 +86,7 @@ class Connection : public ObjectBase { std::vector codes(info_codes, info_codes + info_codes_length); RAISE_RESULT(error, auto infos, impl().InfoImpl(codes)); - RAISE_STATUS(error, AdbcGetInfo(infos, out)); + RAISE_STATUS(error, MakeGetInfoStream(infos, out)); return ADBC_STATUS_OK; } @@ -204,7 +204,7 @@ class Connection : public ObjectBase { } RAISE_RESULT(error, std::vector table_types, impl().GetTableTypesImpl()); - RAISE_STATUS(error, AdbcGetTableTypes(table_types, out)); + RAISE_STATUS(error, MakeTableTypesStream(table_types, out)); return ADBC_STATUS_OK; } diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/meson.build b/3rd_party/apache-arrow-adbc/c/driver/framework/meson.build index 432cc5b..08be53e 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/meson.build +++ b/3rd_party/apache-arrow-adbc/c/driver/framework/meson.build @@ -18,8 +18,8 @@ adbc_framework_lib = library( 'adbc_driver_framework', sources: [ - 'catalog.cc', 'objects.cc', + 'utility.cc', ], include_directories: [include_dir, c_dir], link_with: [adbc_common_lib], diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/objects.cc b/3rd_party/apache-arrow-adbc/c/driver/framework/objects.cc index 1d5e910..691f6e4 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/objects.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/framework/objects.cc @@ -21,11 +21,172 @@ #include "nanoarrow/nanoarrow.hpp" -#include "driver/framework/catalog.h" #include "driver/framework/status.h" +#include "driver/framework/utility.h" namespace adbc::driver { +Status MakeGetObjectsSchema(struct ArrowSchema* schema) { + ArrowSchemaInit(schema); + UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "catalog_name")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[1], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[1], "catalog_db_schemas")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema->children[1]->children[0], 2)); + + struct ArrowSchema* db_schema_schema = schema->children[1]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(db_schema_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(db_schema_schema->children[0], "db_schema_name")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(db_schema_schema->children[1], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(db_schema_schema->children[1], "db_schema_tables")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(db_schema_schema->children[1]->children[0], 4)); + + struct ArrowSchema* table_schema = db_schema_schema->children[1]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[0], "table_name")); + table_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[1], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, 
ArrowSchemaSetName(table_schema->children[1], "table_type")); + table_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[2], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[2], "table_columns")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(table_schema->children[2]->children[0], 19)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[3], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(table_schema->children[3], "table_constraints")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(table_schema->children[3]->children[0], 4)); + + struct ArrowSchema* column_schema = table_schema->children[2]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[0], "column_name")); + column_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[1], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[1], "ordinal_position")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[2], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[2], "remarks")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[3], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[3], "xdbc_data_type")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[4], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[4], "xdbc_type_name")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[5], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[5], "xdbc_column_size")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[6], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[6], "xdbc_decimal_digits")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[7], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[7], "xdbc_num_prec_radix")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[8], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[8], "xdbc_nullable")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[9], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[9], "xdbc_column_def")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[10], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[10], "xdbc_sql_data_type")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[11], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[11], "xdbc_datetime_sub")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[12], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[12], "xdbc_char_octet_length")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[13], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[13], "xdbc_is_nullable")); + 
UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[14], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[14], "xdbc_scope_catalog")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[15], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[15], "xdbc_scope_schema")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[16], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[16], "xdbc_scope_table")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[17], NANOARROW_TYPE_BOOL)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[17], "xdbc_is_autoincrement")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[18], NANOARROW_TYPE_BOOL)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[18], + "xdbc_is_generatedcolumn")); + + struct ArrowSchema* constraint_schema = table_schema->children[3]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(constraint_schema->children[0], "constraint_name")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[1], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(constraint_schema->children[1], "constraint_type")); + constraint_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[2], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[2], + "constraint_column_names")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(constraint_schema->children[2]->children[0], + NANOARROW_TYPE_STRING)); + constraint_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[3], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[3], + "constraint_column_usage")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(constraint_schema->children[3]->children[0], 4)); + + struct ArrowSchema* usage_schema = constraint_schema->children[3]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[0], "fk_catalog")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[1], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[1], "fk_db_schema")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[2], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[2], "fk_table")); + usage_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[3], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[3], "fk_column_name")); + usage_schema->children[3]->flags &= ~ARROW_FLAG_NULLABLE; + + return status::Ok(); +} + namespace { /// \brief A helper to convert std::string_view to Nanoarrow's ArrowStringView. 
ArrowStringView ToStringView(std::string_view s) { @@ -115,7 +276,7 @@ struct GetObjectsBuilder { private: Status InitArrowArray() { - UNWRAP_STATUS(AdbcInitConnectionObjectsSchema(schema)); + UNWRAP_STATUS(MakeGetObjectsSchema(schema)); UNWRAP_NANOARROW(na_error, Internal, ArrowArrayInitFromSchema(array, schema, &na_error)); UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array)); @@ -123,7 +284,7 @@ struct GetObjectsBuilder { } Status AppendCatalogs() { - UNWRAP_STATUS(helper->LoadCatalogs()); + UNWRAP_STATUS(helper->LoadCatalogs(catalog_filter)); while (true) { UNWRAP_RESULT(auto maybe_catalog, helper->NextCatalog()); if (!maybe_catalog.has_value()) break; @@ -141,7 +302,7 @@ struct GetObjectsBuilder { } Status AppendSchemas(std::string_view catalog) { - UNWRAP_STATUS(helper->LoadSchemas(catalog)); + UNWRAP_STATUS(helper->LoadSchemas(catalog, schema_filter)); while (true) { UNWRAP_RESULT(auto maybe_schema, helper->NextSchema()); if (!maybe_schema.has_value()) break; @@ -162,7 +323,7 @@ struct GetObjectsBuilder { } Status AppendTables(std::string_view catalog, std::string_view schema) { - UNWRAP_STATUS(helper->LoadTables(catalog, schema)); + UNWRAP_STATUS(helper->LoadTables(catalog, schema, table_filter, table_types)); while (true) { UNWRAP_RESULT(auto maybe_table, helper->NextTable()); if (!maybe_table.has_value()) break; @@ -187,7 +348,7 @@ struct GetObjectsBuilder { Status AppendColumns(std::string_view catalog, std::string_view schema, std::string_view table) { - UNWRAP_STATUS(helper->LoadColumns(catalog, schema, table)); + UNWRAP_STATUS(helper->LoadColumns(catalog, schema, table, column_filter)); while (true) { UNWRAP_RESULT(auto maybe_column, helper->NextColumn()); if (!maybe_column.has_value()) break; @@ -365,7 +526,7 @@ Status BuildGetObjects(GetObjectsHelper* helper, GetObjectsDepth depth, table_filter, column_filter, table_types, schema.get(), array.get()) .Build()); - nanoarrow::VectorArrayStream(schema.get(), array.get()).ToArrayStream(out); + MakeArrayStream(schema.get(), array.get(), out); return status::Ok(); } } // namespace adbc::driver diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/objects.h b/3rd_party/apache-arrow-adbc/c/driver/framework/objects.h index ffd2004..3e74e78 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/objects.h +++ b/3rd_party/apache-arrow-adbc/c/driver/framework/objects.h @@ -21,13 +21,125 @@ #include #include -#include - -#include "driver/framework/catalog.h" +#include #include "driver/framework/status.h" #include "driver/framework/type_fwd.h" namespace adbc::driver { + +/// \defgroup adbc-framework-catalog Catalog Utilities +/// Utilities for implementing catalog/metadata-related functions. +/// +/// @{ + +/// \brief Create the ArrowSchema for AdbcConnectionGetObjects(). +Status MakeGetObjectsSchema(ArrowSchema* schema); + +/// \brief The GetObjects level. +enum class GetObjectsDepth { + kCatalogs, + kSchemas, + kTables, + kColumns, +}; + +/// \brief Helper to implement GetObjects. +/// +/// Drivers can implement methods of the GetObjectsHelper in a driver-specific +/// class to get a compliant implementation of AdbcConnectionGetObjects(). 
+struct GetObjectsHelper { + virtual ~GetObjectsHelper() = default; + + struct Table { + std::string_view name; + std::string_view type; + }; + + struct ColumnXdbc { + std::optional xdbc_data_type; + std::optional xdbc_type_name; + std::optional xdbc_column_size; + std::optional xdbc_decimal_digits; + std::optional xdbc_num_prec_radix; + std::optional xdbc_nullable; + std::optional xdbc_column_def; + std::optional xdbc_sql_data_type; + std::optional xdbc_datetime_sub; + std::optional xdbc_char_octet_length; + std::optional xdbc_is_nullable; + std::optional xdbc_scope_catalog; + std::optional xdbc_scope_schema; + std::optional xdbc_scope_table; + std::optional xdbc_is_autoincrement; + std::optional xdbc_is_generatedcolumn; + }; + + struct Column { + std::string_view column_name; + int32_t ordinal_position; + std::optional remarks; + std::optional xdbc; + }; + + struct ConstraintUsage { + std::optional catalog; + std::optional schema; + std::string_view table; + std::string_view column; + }; + + struct Constraint { + std::optional name; + std::string_view type; + std::vector column_names; + std::optional> usage; + }; + + Status Close() { return status::Ok(); } + + /// \brief Fetch all metadata needed. The driver is free to delay loading + /// but this gives it a chance to load data up front. + virtual Status Load(GetObjectsDepth depth, + std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + const std::vector& table_types) { + return status::NotImplemented("GetObjects"); + } + + virtual Status LoadCatalogs(std::optional catalog_filter) { + return status::NotImplemented("GetObjects at depth = catalog"); + }; + + virtual Result> NextCatalog() { return std::nullopt; } + + virtual Status LoadSchemas(std::string_view catalog, + std::optional schema_filter) { + return status::NotImplemented("GetObjects at depth = schema"); + }; + + virtual Result> NextSchema() { return std::nullopt; } + + virtual Status LoadTables(std::string_view catalog, std::string_view schema, + std::optional table_filter, + const std::vector& table_types) { + return status::NotImplemented("GetObjects at depth = table"); + }; + + virtual Result> NextTable() { return std::nullopt; } + + virtual Status LoadColumns(std::string_view catalog, std::string_view schema, + std::string_view table, + std::optional column_filter) { + return status::NotImplemented("GetObjects at depth = column"); + }; + + virtual Result> NextColumn() { return std::nullopt; } + + virtual Result> NextConstraint() { return std::nullopt; } +}; + /// \brief A helper that implements GetObjects. /// The out/helper lifetime are caller-managed. 
Status BuildGetObjects(GetObjectsHelper* helper, GetObjectsDepth depth, @@ -36,5 +148,5 @@ Status BuildGetObjects(GetObjectsHelper* helper, GetObjectsDepth depth, std::optional table_filter, std::optional column_filter, const std::vector& table_types, - struct ArrowArrayStream* out); + ArrowArrayStream* out); } // namespace adbc::driver diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/statement.h b/3rd_party/apache-arrow-adbc/c/driver/framework/statement.h index af9eee0..c073248 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/statement.h +++ b/3rd_party/apache-arrow-adbc/c/driver/framework/statement.h @@ -27,6 +27,7 @@ #include "driver/framework/base_driver.h" #include "driver/framework/status.h" +#include "driver/framework/utility.h" namespace adbc::driver { @@ -87,7 +88,7 @@ class Statement : public BaseStatement { .ToAdbc(error); } if (bind_parameters_.release) bind_parameters_.release(&bind_parameters_); - AdbcMakeArrayStream(schema, values, &bind_parameters_); + MakeArrayStream(schema, values, &bind_parameters_); return ADBC_STATUS_OK; } diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/status.h b/3rd_party/apache-arrow-adbc/c/driver/framework/status.h index d7952a6..22e484d 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/framework/status.h +++ b/3rd_party/apache-arrow-adbc/c/driver/framework/status.h @@ -69,6 +69,19 @@ class Status { impl_->details.push_back({std::move(key), std::move(value)}); } + /// \brief Set the sqlstate of this status + void SetSqlState(std::string sqlstate) { + assert(impl_ != nullptr); + std::memset(impl_->sql_state, 0, sizeof(impl_->sql_state)); + for (size_t i = 0; i < sqlstate.size(); i++) { + if (i >= sizeof(impl_->sql_state)) { + break; + } + + impl_->sql_state[i] = sqlstate[i]; + } + } + /// \brief Export this status to an AdbcError. AdbcStatusCode ToAdbc(AdbcError* adbc_error) const { if (impl_ == nullptr) return ADBC_STATUS_OK; @@ -112,7 +125,29 @@ class Status { return status; } + // Helpers to create statuses with known codes + static Status Ok() { return Status(); } + +#define STATUS_CTOR(NAME, CODE) \ + template \ + static Status NAME(Args&&... args) { \ + std::stringstream ss; \ + ([&] { ss << args; }(), ...); \ + return Status(ADBC_STATUS_##CODE, ss.str()); \ + } + + STATUS_CTOR(Internal, INTERNAL) + STATUS_CTOR(InvalidArgument, INVALID_ARGUMENT) + STATUS_CTOR(InvalidState, INVALID_STATE) + STATUS_CTOR(IO, IO) + STATUS_CTOR(NotFound, NOT_FOUND) + STATUS_CTOR(NotImplemented, NOT_IMPLEMENTED) + STATUS_CTOR(Unknown, UNKNOWN) + +#undef STATUS_CTOR + private: + /// \brief Private Status implementation details struct Impl { // invariant: code is never OK AdbcStatusCode code; @@ -133,6 +168,8 @@ class Status { template friend class Driver; + // Allow access to these for drivers transitioning to the framework + public: int CDetailCount() const { return impl_ ? 
static_cast(impl_->details.size()) : 0; } AdbcErrorDetail CDetail(int index) const { @@ -144,6 +181,7 @@ class Status { detail.second.size()}; } + private: static void CRelease(AdbcError* error) { if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { auto* error_obj = reinterpret_cast(error->private_data); diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/utility.cc b/3rd_party/apache-arrow-adbc/c/driver/framework/utility.cc new file mode 100644 index 0000000..cbcd8bb --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver/framework/utility.cc @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "driver/framework/utility.h" + +#include +#include + +#include "arrow-adbc/adbc.h" +#include "nanoarrow/nanoarrow.hpp" + +namespace adbc::driver { + +void MakeEmptyStream(ArrowSchema* schema, ArrowArrayStream* out) { + nanoarrow::EmptyArrayStream(schema).ToArrayStream(out); +} + +void MakeArrayStream(ArrowSchema* schema, ArrowArray* array, ArrowArrayStream* out) { + if (array->length == 0) { + ArrowArrayRelease(array); + std::memset(array, 0, sizeof(ArrowArray)); + + MakeEmptyStream(schema, out); + } else { + nanoarrow::VectorArrayStream(schema, array).ToArrayStream(out); + } +} + +Status MakeTableTypesStream(const std::vector& table_types, + ArrowArrayStream* out) { + nanoarrow::UniqueArray array; + nanoarrow::UniqueSchema schema; + ArrowSchemaInit(schema.get()); + + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema.get(), NANOARROW_TYPE_STRUCT)); + UNWRAP_ERRNO(Internal, ArrowSchemaAllocateChildren(schema.get(), /*num_columns=*/1)); + ArrowSchemaInit(schema.get()->children[0]); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(schema.get()->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema.get()->children[0], "table_type")); + schema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + + UNWRAP_ERRNO(Internal, ArrowArrayInitFromSchema(array.get(), schema.get(), NULL)); + UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array.get())); + + for (std::string const& table_type : table_types) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(array->children[0], + ArrowCharView(table_type.c_str()))); + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(array.get())); + } + + UNWRAP_ERRNO(Internal, ArrowArrayFinishBuildingDefault(array.get(), nullptr)); + MakeArrayStream(schema.get(), array.get(), out); + return status::Ok(); +} + +namespace { +Status MakeGetInfoInit(ArrowSchema* schema, ArrowArray* array) { + ArrowSchemaInit(schema); + UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); + + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_UINT32)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "info_name")); 
+ schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + + ArrowSchema* info_value = schema->children[1]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeUnion(info_value, NANOARROW_TYPE_DENSE_UNION, 6)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value, "info_value")); + + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[0], "string_value")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[1], NANOARROW_TYPE_BOOL)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[1], "bool_value")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[2], NANOARROW_TYPE_INT64)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[2], "int64_value")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[3], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[3], "int32_bitmask")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[4], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[4], "string_list")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(info_value->children[5], NANOARROW_TYPE_MAP)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(info_value->children[5], "int32_to_int32_list_map")); + + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(info_value->children[4]->children[0], + NANOARROW_TYPE_STRING)); + + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[5]->children[0]->children[0], + NANOARROW_TYPE_INT32)); + info_value->children[5]->children[0]->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[5]->children[0]->children[1], + NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO( + Internal, + ArrowSchemaSetType(info_value->children[5]->children[0]->children[1]->children[0], + NANOARROW_TYPE_INT32)); + + UNWRAP_ERRNO(Internal, ArrowArrayInitFromSchema(array, schema, nullptr)); + UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array)); + + return status::Ok(); +} + +Status MakeGetInfoAppendString(ArrowArray* array, uint32_t info_code, + std::string_view info_value) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); + // Append to type variant + ArrowStringView value; + value.data = info_value.data(); + value.size_bytes = static_cast(info_value.size()); + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(array->children[1]->children[0], value)); + // Append type code/offset + UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/0)); + return status::Ok(); +} + +Status MakeGetInfoAppendInt(ArrowArray* array, uint32_t info_code, int64_t info_value) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); + // Append to type variant + UNWRAP_ERRNO(Internal, + ArrowArrayAppendInt(array->children[1]->children[2], info_value)); + // Append type code/offset + UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2)); + return status::Ok(); +} +} // namespace + +Status MakeGetInfoStream(const std::vector& infos, ArrowArrayStream* out) { + nanoarrow::UniqueSchema schema; + nanoarrow::UniqueArray array; + + UNWRAP_STATUS(MakeGetInfoInit(schema.get(), array.get())); + + for (const auto& info : infos) { + UNWRAP_STATUS(std::visit( + [&](auto&& value) -> Status { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return MakeGetInfoAppendString(array.get(), 
info.code, value); + } else if constexpr (std::is_same_v) { + return MakeGetInfoAppendInt(array.get(), info.code, value); + } else { + static_assert(!sizeof(T), "info value type not implemented"); + } + return status::Ok(); + }, + info.value)); + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(array.get())); + } + + ArrowError na_error = {0}; + UNWRAP_NANOARROW(na_error, Internal, + ArrowArrayFinishBuildingDefault(array.get(), &na_error)); + MakeArrayStream(schema.get(), array.get(), out); + return status::Ok(); +} + +} // namespace adbc::driver diff --git a/3rd_party/apache-arrow-adbc/c/driver/framework/utility.h b/3rd_party/apache-arrow-adbc/c/driver/framework/utility.h new file mode 100644 index 0000000..af60594 --- /dev/null +++ b/3rd_party/apache-arrow-adbc/c/driver/framework/utility.h @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include + +#include "driver/framework/status.h" + +namespace adbc::driver { + +/// \brief Create an ArrowArrayStream with zero batches from a given ArrowSchema. +/// \ingroup adbc-framework-catalog +/// +/// This function takes ownership of schema; the caller is responsible for +/// releasing out. +void MakeEmptyStream(ArrowSchema* schema, ArrowArrayStream* out); + +/// \brief Create an ArrowArrayStream from a given ArrowSchema and ArrowArray. +/// \ingroup adbc-framework-catalog +/// +/// The resulting ArrowArrayStream will contain zero batches if the length of the +/// array is zero, or exactly one batch if the length of the array is non-zero. +/// This function takes ownership of schema and array; the caller is responsible for +/// releasing out. +void MakeArrayStream(ArrowSchema* schema, ArrowArray* array, ArrowArrayStream* out); + +/// \brief Create an ArrowArrayStream representation of a vector of table types. +/// \ingroup adbc-framework-catalog +/// +/// Create an ArrowArrayStream representation of an array of table types +/// that can be used to implement AdbcConnectionGetTableTypes(). The caller is responsible +/// for releasing out on success. +Status MakeTableTypesStream(const std::vector& table_types, + ArrowArrayStream* out); + +/// \brief Representation of a single item in an array to be returned +/// from AdbcConnectionGetInfo(). +/// \ingroup adbc-framework-catalog +struct InfoValue { + uint32_t code; + std::variant value; + + InfoValue(uint32_t code, std::variant value) + : code(code), value(std::move(value)) {} + InfoValue(uint32_t code, const char* value) : InfoValue(code, std::string(value)) {} +}; + +/// \brief Create an ArrowArrayStream to be returned from AdbcConnectionGetInfo(). +/// \ingroup adbc-framework-catalog +/// +/// The caller is responsible for releasing out on success. 
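As an illustration (not part of the patch), a driver's GetInfo implementation might use InfoValue together with the declaration that follows roughly like this; the info codes are the standard ones from adbc.h, while the vendor/driver strings and the GetInfoImpl wrapper are placeholders.

// Sketch: populating AdbcConnectionGetInfo() output via MakeGetInfoStream.
adbc::driver::Status GetInfoImpl(ArrowArrayStream* out) {
  std::vector<adbc::driver::InfoValue> infos = {
      {ADBC_INFO_VENDOR_NAME, "ExampleDB"},                          // string_value
      {ADBC_INFO_DRIVER_NAME, "ADBC Example Driver"},                // string_value
      {ADBC_INFO_DRIVER_ADBC_VERSION, int64_t{ADBC_VERSION_1_1_0}},  // int64_value
  };
  return adbc::driver::MakeGetInfoStream(infos, out);
}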
+Status MakeGetInfoStream(const std::vector& infos, ArrowArrayStream* out); + +} // namespace adbc::driver diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/bind_stream.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/bind_stream.h index 3e440e8..df0b9d2 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/bind_stream.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/bind_stream.h @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -26,10 +27,10 @@ #include #include "copy/writer.h" -#include "driver/common/utils.h" #include "error.h" #include "postgres_type.h" #include "postgres_util.h" +#include "result_helper.h" namespace adbcpq { @@ -44,18 +45,15 @@ struct BindStream { Handle bind_schema; int64_t current_row = -1; - struct ArrowSchemaView bind_schema_view; std::vector bind_schema_fields; + std::vector> bind_field_writers; // OIDs for parameter types std::vector param_types; std::vector param_values; - std::vector param_lengths; std::vector param_formats; - std::vector param_values_offsets; - std::vector param_values_buffer; - // XXX: this assumes fixed-length fields only - will need more - // consideration to deal with variable-length fields + std::vector param_lengths; + Handle param_buffer; bool has_tz_field = false; std::string tz_setting; @@ -73,414 +71,164 @@ struct BindStream { } template - AdbcStatusCode Begin(Callback&& callback, struct AdbcError* error) { - CHECK_NA_DETAIL(INTERNAL, - ArrowArrayStreamGetSchema(&bind.value, &bind_schema.value, &na_error), - &na_error, error); - CHECK_NA_DETAIL(INTERNAL, - ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, &na_error), - &na_error, error); - + Status Begin(Callback&& callback) { + UNWRAP_NANOARROW( + na_error, Internal, + ArrowArrayStreamGetSchema(&bind.value, &bind_schema.value, &na_error)); + + struct ArrowSchemaView bind_schema_view; + UNWRAP_NANOARROW( + na_error, Internal, + ArrowSchemaViewInit(&bind_schema_view, &bind_schema.value, &na_error)); if (bind_schema_view.type != ArrowType::NANOARROW_TYPE_STRUCT) { - SetError(error, "%s", "[libpq] Bind parameters must have type STRUCT"); - return ADBC_STATUS_INVALID_STATE; + return Status::InvalidState("[libpq] Bind parameters must have type STRUCT"); } bind_schema_fields.resize(bind_schema->n_children); for (size_t i = 0; i < bind_schema_fields.size(); i++) { - CHECK_NA(INTERNAL, - ArrowSchemaViewInit(&bind_schema_fields[i], bind_schema->children[i], - /*error*/ nullptr), - error); + UNWRAP_ERRNO(Internal, + ArrowSchemaViewInit(&bind_schema_fields[i], bind_schema->children[i], + /*error*/ nullptr)); } - CHECK_NA_DETAIL( - INTERNAL, - ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, &na_error), - &na_error, error); + UNWRAP_NANOARROW( + na_error, Internal, + ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, &na_error)); + + ArrowBufferInit(¶m_buffer.value); return std::move(callback)(); } - AdbcStatusCode SetParamTypes(const PostgresTypeResolver& type_resolver, - struct AdbcError* error) { + Status SetParamTypes(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, + const bool autocommit) { param_types.resize(bind_schema->n_children); param_values.resize(bind_schema->n_children); param_lengths.resize(bind_schema->n_children); param_formats.resize(bind_schema->n_children, kPgBinaryFormat); - param_values_offsets.reserve(bind_schema->n_children); - - for (size_t i = 0; i < bind_schema_fields.size(); i++) { - PostgresTypeId type_id; - switch (bind_schema_fields[i].type) { - case 
ArrowType::NANOARROW_TYPE_BOOL: - type_id = PostgresTypeId::kBool; - param_lengths[i] = 1; - break; - case ArrowType::NANOARROW_TYPE_INT8: - case ArrowType::NANOARROW_TYPE_INT16: - type_id = PostgresTypeId::kInt2; - param_lengths[i] = 2; - break; - case ArrowType::NANOARROW_TYPE_INT32: - type_id = PostgresTypeId::kInt4; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_INT64: - type_id = PostgresTypeId::kInt8; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_FLOAT: - type_id = PostgresTypeId::kFloat4; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_DOUBLE: - type_id = PostgresTypeId::kFloat8; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - type_id = PostgresTypeId::kText; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_BINARY: - type_id = PostgresTypeId::kBytea; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_DATE32: - type_id = PostgresTypeId::kDate; - param_lengths[i] = 4; - break; - case ArrowType::NANOARROW_TYPE_TIMESTAMP: - type_id = PostgresTypeId::kTimestamp; - param_lengths[i] = 8; - break; - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - type_id = PostgresTypeId::kInterval; - param_lengths[i] = 16; - break; - case ArrowType::NANOARROW_TYPE_DECIMAL128: - case ArrowType::NANOARROW_TYPE_DECIMAL256: - type_id = PostgresTypeId::kNumeric; - param_lengths[i] = 0; - break; - case ArrowType::NANOARROW_TYPE_DICTIONARY: { - struct ArrowSchemaView value_view; - CHECK_NA(INTERNAL, - ArrowSchemaViewInit(&value_view, bind_schema->children[i]->dictionary, - nullptr), - error); - switch (value_view.type) { - case NANOARROW_TYPE_BINARY: - case NANOARROW_TYPE_LARGE_BINARY: - type_id = PostgresTypeId::kBytea; - param_lengths[i] = 0; - break; - case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: - type_id = PostgresTypeId::kText; - param_lengths[i] = 0; - break; - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", - bind_schema->children[i]->name, - "') has unsupported dictionary value parameter type ", - ArrowTypeString(value_view.type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - break; - } - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", bind_schema->children[i]->name, - "') has unsupported parameter type ", - ArrowTypeString(bind_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; + bind_field_writers.resize(bind_schema->n_children); + + for (size_t i = 0; i < bind_field_writers.size(); i++) { + PostgresType type; + UNWRAP_NANOARROW(na_error, Internal, + PostgresType::FromSchema(type_resolver, bind_schema->children[i], + &type, &na_error)); + + // tz-aware timestamps require special handling to set the timezone to UTC + // prior to sending over the binary protocol; must be reset after execute + if (!has_tz_field && type.type_id() == PostgresTypeId::kTimestamptz) { + UNWRAP_STATUS(SetDatabaseTimezoneUTC(pg_conn, autocommit)); + has_tz_field = true; } - param_types[i] = type_resolver.GetOID(type_id); - if (param_types[i] == 0) { - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", bind_schema->children[i]->name, - "') has type with no corresponding PostgreSQL type ", - ArrowTypeString(bind_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - } + std::unique_ptr writer; + UNWRAP_NANOARROW( + na_error, Internal, 
+ MakeCopyFieldWriter(bind_schema->children[i], array_view->children[i], + type_resolver, &writer, &na_error)); - size_t param_values_length = 0; - for (int length : param_lengths) { - param_values_offsets.push_back(param_values_length); - param_values_length += length; + param_types[i] = type.oid(); + param_formats[i] = kPgBinaryFormat; + bind_field_writers[i] = std::move(writer); } - param_values_buffer.resize(param_values_length); - return ADBC_STATUS_OK; - } - AdbcStatusCode Prepare(PGconn* pg_conn, const std::string& query, - struct AdbcError* error, const bool autocommit) { - // tz-aware timestamps require special handling to set the timezone to UTC - // prior to sending over the binary protocol; must be reset after execute - for (int64_t col = 0; col < bind_schema->n_children; col++) { - if ((bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) && - (strcmp("", bind_schema_fields[col].timezone))) { - has_tz_field = true; + return Status::Ok(); + } - if (autocommit) { - PGresult* begin_result = PQexec(pg_conn, "BEGIN"); - if (PQresultStatus(begin_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, begin_result, - "[libpq] Failed to begin transaction for timezone data: %s", - PQerrorMessage(pg_conn)); - PQclear(begin_result); - return code; - } - PQclear(begin_result); - } - - PGresult* get_tz_result = PQexec(pg_conn, "SELECT current_setting('TIMEZONE')"); - if (PQresultStatus(get_tz_result) != PGRES_TUPLES_OK) { - AdbcStatusCode code = SetError(error, get_tz_result, - "[libpq] Could not query current timezone: %s", - PQerrorMessage(pg_conn)); - PQclear(get_tz_result); - return code; - } - - tz_setting = std::string(PQgetvalue(get_tz_result, 0, 0)); - PQclear(get_tz_result); - - PGresult* set_utc_result = PQexec(pg_conn, "SET TIME ZONE 'UTC'"); - if (PQresultStatus(set_utc_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = SetError(error, set_utc_result, - "[libpq] Failed to set time zone to UTC: %s", - PQerrorMessage(pg_conn)); - PQclear(set_utc_result); - return code; - } - PQclear(set_utc_result); - break; - } + Status SetDatabaseTimezoneUTC(PGconn* pg_conn, const bool autocommit) { + if (autocommit) { + PqResultHelper helper(pg_conn, "BEGIN"); + UNWRAP_STATUS(helper.Execute()); } - PGresult* result = PQprepare(pg_conn, /*stmtName=*/"", query.c_str(), - /*nParams=*/bind_schema->n_children, param_types.data()); - if (PQresultStatus(result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(pg_conn), query.c_str()); - PQclear(result); - return code; + PqResultHelper get_tz(pg_conn, "SELECT current_setting('TIMEZONE')"); + UNWRAP_STATUS(get_tz.Execute()); + for (auto row : get_tz) { + tz_setting = row[0].value(); } - PQclear(result); - return ADBC_STATUS_OK; + + PqResultHelper set_utc(pg_conn, "SET TIME ZONE 'UTC'"); + UNWRAP_STATUS(set_utc.Execute()); + + return Status::Ok(); + } + + Status Prepare(PGconn* pg_conn, const std::string& query) { + PqResultHelper helper(pg_conn, query); + UNWRAP_STATUS(helper.Prepare(param_types)); + return Status::Ok(); } - AdbcStatusCode PullNextArray(AdbcError* error) { + Status PullNextArray() { if (current->release != nullptr) ArrowArrayRelease(¤t.value); - CHECK_NA_DETAIL(IO, ArrowArrayStreamGetNext(&bind.value, ¤t.value, &na_error), - &na_error, error); + UNWRAP_NANOARROW(na_error, IO, + ArrowArrayStreamGetNext(&bind.value, ¤t.value, &na_error)); if (current->release != nullptr) { - CHECK_NA_DETAIL( - INTERNAL, 
ArrowArrayViewSetArray(&array_view.value, ¤t.value, &na_error), - &na_error, error); + UNWRAP_NANOARROW( + na_error, Internal, + ArrowArrayViewSetArray(&array_view.value, ¤t.value, &na_error)); } - return ADBC_STATUS_OK; + return Status::Ok(); } - AdbcStatusCode EnsureNextRow(AdbcError* error) { + Status EnsureNextRow() { if (current->release != nullptr) { current_row++; if (current_row < current->length) { - return ADBC_STATUS_OK; + return Status::Ok(); } } // Pull until we have an array with at least one row or the stream is finished do { - RAISE_ADBC(PullNextArray(error)); + UNWRAP_STATUS(PullNextArray()); if (current->release == nullptr) { current_row = -1; - return ADBC_STATUS_OK; + return Status::Ok(); } } while (current->length == 0); current_row = 0; - return ADBC_STATUS_OK; + return Status::Ok(); } - AdbcStatusCode BindAndExecuteCurrentRow(PGconn* pg_conn, PGresult** result_out, - int result_format, AdbcError* error) { - int64_t row = current_row; + Status BindAndExecuteCurrentRow(PGconn* pg_conn, PGresult** result_out, + int result_format) { + param_buffer->size_bytes = 0; + int64_t last_offset = 0; for (int64_t col = 0; col < array_view->n_children; col++) { - if (ArrowArrayViewIsNull(array_view->children[col], row)) { - param_values[col] = nullptr; - continue; + if (!ArrowArrayViewIsNull(array_view->children[col], current_row)) { + // Note that this Write() call currently writes the (int32_t) byte size of the + // field in addition to the serialized value. + UNWRAP_NANOARROW( + na_error, Internal, + bind_field_writers[col]->Write(¶m_buffer.value, current_row, &na_error)); } else { - param_values[col] = param_values_buffer.data() + param_values_offsets[col]; + UNWRAP_ERRNO(Internal, ArrowBufferAppendInt32(¶m_buffer.value, 0)); } - switch (bind_schema_fields[col].type) { - case ArrowType::NANOARROW_TYPE_BOOL: { - const int8_t val = - ArrowBitGet(array_view->children[col]->buffer_views[1].data.as_uint8, row); - std::memcpy(param_values[col], &val, sizeof(int8_t)); - break; - } - - case ArrowType::NANOARROW_TYPE_INT8: { - const int16_t val = - array_view->children[col]->buffer_views[1].data.as_int8[row]; - const uint16_t value = ToNetworkInt16(val); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT16: { - const uint16_t value = ToNetworkInt16( - array_view->children[col]->buffer_views[1].data.as_int16[row]); - std::memcpy(param_values[col], &value, sizeof(int16_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT32: { - const uint32_t value = ToNetworkInt32( - array_view->children[col]->buffer_views[1].data.as_int32[row]); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_INT64: { - const int64_t value = ToNetworkInt64( - array_view->children[col]->buffer_views[1].data.as_int64[row]); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_FLOAT: { - const uint32_t value = ToNetworkFloat4( - array_view->children[col]->buffer_views[1].data.as_float[row]); - std::memcpy(param_values[col], &value, sizeof(uint32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DOUBLE: { - const uint64_t value = ToNetworkFloat8( - array_view->children[col]->buffer_views[1].data.as_double[row]); - std::memcpy(param_values[col], &value, sizeof(uint64_t)); - break; - } - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - case ArrowType::NANOARROW_TYPE_BINARY: { - const ArrowBufferView view = - 
ArrowArrayViewGetBytesUnsafe(array_view->children[col], row); - // TODO: overflow check? - param_lengths[col] = static_cast(view.size_bytes); - param_values[col] = const_cast(view.data.as_char); - break; - } - case ArrowType::NANOARROW_TYPE_DATE32: { - // 2000-01-01 - constexpr int32_t kPostgresDateEpoch = 10957; - const int32_t raw_value = - array_view->children[col]->buffer_views[1].data.as_int32[row]; - if (raw_value < INT32_MIN + kPostgresDateEpoch) { - SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s", col + 1, - "('", bind_schema->children[col]->name, "') Row #", row + 1, - "has value which exceeds postgres date limits"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - const uint32_t value = ToNetworkInt32(raw_value - kPostgresDateEpoch); - std::memcpy(param_values[col], &value, sizeof(int32_t)); - break; - } - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_TIMESTAMP: { - int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row]; - - bool overflow_safe = true; - - auto unit = bind_schema_fields[col].time_unit; - - switch (unit) { - case NANOARROW_TIME_UNIT_SECOND: - overflow_safe = - val <= kMaxSafeSecondsToMicros && val >= kMinSafeSecondsToMicros; - if (overflow_safe) { - val *= 1000000; - } - - break; - case NANOARROW_TIME_UNIT_MILLI: - overflow_safe = - val <= kMaxSafeMillisToMicros && val >= kMinSafeMillisToMicros; - if (overflow_safe) { - val *= 1000; - } - break; - case NANOARROW_TIME_UNIT_MICRO: - break; - case NANOARROW_TIME_UNIT_NANO: - val /= 1000; - break; - } - - if (!overflow_safe) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 "' which exceeds PostgreSQL timestamp limits", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (val < (std::numeric_limits::min)() + kPostgresTimestampEpoch) { - SetError(error, - "[libpq] Field #%" PRId64 " ('%s') Row #%" PRId64 - " has value '%" PRIi64 "' which would underflow", - col + 1, bind_schema->children[col]->name, row + 1, - array_view->children[col]->buffer_views[1].data.as_int64[row]); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) { - const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - } else if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_DURATION) { - // postgres stores an interval as a 64 bit offset in microsecond - // resolution alongside a 32 bit day and 32 bit month - // for now we just send 0 for the day / month values - const uint64_t value = ToNetworkInt64(val); - std::memcpy(param_values[col], &value, sizeof(int64_t)); - std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t)); - } - break; - } - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: { - struct ArrowInterval interval; - ArrowIntervalInit(&interval, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO); - ArrowArrayViewGetIntervalUnsafe(array_view->children[col], row, &interval); - - const uint32_t months = ToNetworkInt32(interval.months); - const uint32_t days = ToNetworkInt32(interval.days); - const uint64_t ms = ToNetworkInt64(interval.ns / 1000); - - std::memcpy(param_values[col], &ms, sizeof(uint64_t)); - std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t)); - std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t), &months, - 
sizeof(uint32_t)); - break; - } - default: - SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('", - bind_schema->children[col]->name, - "') has unsupported type for ingestion ", - ArrowTypeString(bind_schema_fields[col].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; + + int64_t param_length = param_buffer->size_bytes - last_offset - sizeof(int32_t); + if (param_length > (std::numeric_limits::max)()) { + return Status::Internal("Paramter ", col, "serialized to >2GB of binary"); } + + param_lengths[col] = static_cast(param_length); + last_offset = param_buffer->size_bytes; + } + + last_offset = 0; + for (int64_t col = 0; col < array_view->n_children; col++) { + last_offset += sizeof(int32_t); + if (param_lengths[col] == 0) { + param_values[col] = nullptr; + } else { + param_values[col] = reinterpret_cast(param_buffer->data) + last_offset; + } + last_offset += param_lengths[col]; } PGresult* result = @@ -490,60 +238,45 @@ struct BindStream { ExecStatusType pg_status = PQresultStatus(result); if (pg_status != PGRES_COMMAND_OK && pg_status != PGRES_TUPLES_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to execute prepared statement: %s %s", - PQresStatus(pg_status), PQerrorMessage(pg_conn)); + Status status = + MakeStatus(result, "[libpq] Failed to execute prepared statement: {} {}", + PQresStatus(pg_status), PQerrorMessage(pg_conn)); PQclear(result); - return code; + return status; } *result_out = result; - return ADBC_STATUS_OK; + return Status::Ok(); } - AdbcStatusCode Cleanup(PGconn* pg_conn, AdbcError* error) { + Status Cleanup(PGconn* pg_conn) { if (has_tz_field) { - std::string reset_query = "SET TIME ZONE '" + tz_setting + "'"; - PGresult* reset_tz_result = PQexec(pg_conn, reset_query.c_str()); - if (PQresultStatus(reset_tz_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, reset_tz_result, "[libpq] Failed to reset time zone: %s", - PQerrorMessage(pg_conn)); - PQclear(reset_tz_result); - return code; - } - PQclear(reset_tz_result); - - PGresult* commit_result = PQexec(pg_conn, "COMMIT"); - if (PQresultStatus(commit_result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, commit_result, "[libpq] Failed to commit transaction: %s", - PQerrorMessage(pg_conn)); - PQclear(commit_result); - return code; - } - PQclear(commit_result); + PqResultHelper reset(pg_conn, "SET TIME ZONE '" + tz_setting + "'"); + UNWRAP_STATUS(reset.Execute()); + + PqResultHelper commit(pg_conn, "COMMIT"); + UNWRAP_STATUS(reset.Execute()); } - return ADBC_STATUS_OK; + return Status::Ok(); } - AdbcStatusCode ExecuteCopy(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, - int64_t* rows_affected, struct AdbcError* error) { + Status ExecuteCopy(PGconn* pg_conn, const PostgresTypeResolver& type_resolver, + int64_t* rows_affected) { if (rows_affected) *rows_affected = 0; PostgresCopyStreamWriter writer; - CHECK_NA(INTERNAL, writer.Init(&bind_schema.value), error); - CHECK_NA_DETAIL(INTERNAL, writer.InitFieldWriters(type_resolver, &na_error), - &na_error, error); + UNWRAP_ERRNO(Internal, writer.Init(&bind_schema.value)); + UNWRAP_NANOARROW(na_error, Internal, + writer.InitFieldWriters(type_resolver, &na_error)); - CHECK_NA_DETAIL(INTERNAL, writer.WriteHeader(&na_error), &na_error, error); + UNWRAP_NANOARROW(na_error, Internal, writer.WriteHeader(&na_error)); while (true) { - RAISE_ADBC(PullNextArray(error)); + UNWRAP_STATUS(PullNextArray()); if (!current->release) break; - CHECK_NA(INTERNAL, writer.SetArray(¤t.value), error); + 
UNWRAP_ERRNO(Internal, writer.SetArray(¤t.value)); // build writer buffer int write_result; @@ -553,42 +286,38 @@ struct BindStream { // check if not ENODATA at exit if (write_result != ENODATA) { - SetError(error, "Error occurred writing COPY data: %s", PQerrorMessage(pg_conn)); - return ADBC_STATUS_IO; + return Status::IO("Error occurred writing COPY data: ", PQerrorMessage(pg_conn)); } - RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error)); + UNWRAP_STATUS(FlushCopyWriterToConn(pg_conn, writer)); if (rows_affected) *rows_affected += current->length; writer.Rewind(); } // If there were no arrays in the stream, we haven't flushed yet - RAISE_ADBC(FlushCopyWriterToConn(pg_conn, writer, error)); + UNWRAP_STATUS(FlushCopyWriterToConn(pg_conn, writer)); if (PQputCopyEnd(pg_conn, NULL) <= 0) { - SetError(error, "Error message returned by PQputCopyEnd: %s", - PQerrorMessage(pg_conn)); - return ADBC_STATUS_IO; + return Status::IO("Error message returned by PQputCopyEnd: ", + PQerrorMessage(pg_conn)); } PGresult* result = PQgetResult(pg_conn); ExecStatusType pg_status = PQresultStatus(result); if (pg_status != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to execute COPY statement: %s %s", - PQresStatus(pg_status), PQerrorMessage(pg_conn)); + Status status = + MakeStatus(result, "[libpq] Failed to execute COPY statement: {} {}", + PQresStatus(pg_status), PQerrorMessage(pg_conn)); PQclear(result); - return code; + return status; } PQclear(result); - return ADBC_STATUS_OK; + return Status::Ok(); } - AdbcStatusCode FlushCopyWriterToConn(PGconn* pg_conn, - const PostgresCopyStreamWriter& writer, - struct AdbcError* error) { + Status FlushCopyWriterToConn(PGconn* pg_conn, const PostgresCopyStreamWriter& writer) { // https://github.com/apache/arrow-adbc/issues/1921: PostgreSQL has a max // size for a single message that we need to respect (1 GiB - 1). 
Since // the buffer can be chunked up as much as we want, go for 16 MiB as our @@ -602,14 +331,13 @@ struct BindStream { while (remaining > 0) { int64_t to_write = std::min(remaining, kMaxCopyBufferSize); if (PQputCopyData(pg_conn, data, to_write) <= 0) { - SetError(error, "Error writing tuple field data: %s", PQerrorMessage(pg_conn)); - return ADBC_STATUS_IO; + return Status::IO("Error writing tuple field data: ", PQerrorMessage(pg_conn)); } remaining -= to_write; data += to_write; } - return ADBC_STATUS_OK; + return Status::Ok(); } }; } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.cc index b5c0ef1..b5f12ca 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.cc @@ -17,13 +17,16 @@ #include "connection.h" +#include #include #include #include #include #include +#include #include #include +#include #include #include #include @@ -33,10 +36,14 @@ #include "database.h" #include "driver/common/utils.h" -#include "driver/framework/catalog.h" +#include "driver/framework/objects.h" +#include "driver/framework/utility.h" #include "error.h" #include "result_helper.h" +using adbc::driver::Result; +using adbc::driver::Status; + namespace adbcpq { namespace { @@ -50,551 +57,391 @@ static const std::unordered_map kPgTableTypes = { {"table", "r"}, {"view", "v"}, {"materialized_view", "m"}, {"toast_table", "t"}, {"foreign_table", "f"}, {"partitioned_table", "p"}}; -class PqGetObjectsHelper { +static const char* kCatalogQueryAll = "SELECT datname FROM pg_catalog.pg_database"; + +// catalog_name is not a parameter here or on any other queries +// because it will always be the currently connected database. +static const char* kSchemaQueryAll = + "SELECT nspname FROM pg_catalog.pg_namespace WHERE " + "nspname !~ '^pg_' AND nspname <> 'information_schema'"; + +// Parameterized on schema_name, relkind +// Note that when binding relkind as a string it must look like {"r", "v", ...} +// (i.e., double quotes). Binding a binary list element also works. 
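To make the binding shape concrete, here is a sketch (not part of the patch) of executing the kTablesQueryAll string defined just below by hand, assuming an open PGconn*; the schema name and the relkind subset are placeholders.

// Sketch: $1 = schema name, $2 = array literal of relkind codes.  Note the
// double-quoted elements, matching the comment above.
adbc::driver::Status ListTablesExample(PGconn* conn) {
  PqResultHelper tables(conn, kTablesQueryAll);
  UNWRAP_STATUS(tables.Execute({"public", "{\"r\", \"v\"}"}));  // tables + views only
  for (PqResultRow row : tables) {
    auto relname = row[0].value();  // c.relname
    auto reltype = row[1].value();  // "table", "view", ...
    (void)relname;
    (void)reltype;
  }
  return adbc::driver::Status::Ok();
}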
+static const char* kTablesQueryAll = + "SELECT c.relname, CASE c.relkind WHEN 'r' THEN 'table' WHEN 'v' THEN 'view' " + "WHEN 'm' THEN 'materialized view' WHEN 't' THEN 'TOAST table' " + "WHEN 'f' THEN 'foreign table' WHEN 'p' THEN 'partitioned table' END " + "AS reltype FROM pg_catalog.pg_class c " + "LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace " + "WHERE pg_catalog.pg_table_is_visible(c.oid) AND n.nspname = $1 AND c.relkind = " + "ANY($2)"; + +// Parameterized on schema_name, table_name +static const char* kColumnsQueryAll = + "SELECT attr.attname, attr.attnum, " + "pg_catalog.col_description(cls.oid, attr.attnum) " + "FROM pg_catalog.pg_attribute AS attr " + "INNER JOIN pg_catalog.pg_class AS cls ON attr.attrelid = cls.oid " + "INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " + "WHERE attr.attnum > 0 AND NOT attr.attisdropped " + "AND nsp.nspname LIKE $1 AND cls.relname LIKE $2"; + +// Parameterized on schema_name, table_name +static const char* kConstraintsQueryAll = + "WITH fk_unnest AS ( " + " SELECT " + " con.conname, " + " 'FOREIGN KEY' AS contype, " + " conrelid, " + " UNNEST(con.conkey) AS conkey, " + " confrelid, " + " UNNEST(con.confkey) AS confkey " + " FROM pg_catalog.pg_constraint AS con " + " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = conrelid " + " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " + " WHERE con.contype = 'f' AND nsp.nspname = $1 " + " AND cls.relname = $2 " + "), " + "fk_names AS ( " + " SELECT " + " fk_unnest.conname, " + " fk_unnest.contype, " + " fk_unnest.conkey, " + " fk_unnest.confkey, " + " attr.attname, " + " fnsp.nspname AS fschema, " + " fcls.relname AS ftable, " + " fattr.attname AS fattname " + " FROM fk_unnest " + " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = fk_unnest.conrelid " + " INNER JOIN pg_catalog.pg_class AS fcls ON fcls.oid = fk_unnest.confrelid " + " INNER JOIN pg_catalog.pg_namespace AS fnsp ON fnsp.oid = fcls.relnamespace" + " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = " + "fk_unnest.conkey " + " AND attr.attrelid = fk_unnest.conrelid " + " LEFT JOIN pg_catalog.pg_attribute AS fattr ON fattr.attnum = " + "fk_unnest.confkey " + " AND fattr.attrelid = fk_unnest.confrelid " + "), " + "fkeys AS ( " + " SELECT " + " conname, " + " contype, " + " ARRAY_AGG(attname ORDER BY conkey) AS colnames, " + " fschema, " + " ftable, " + " ARRAY_AGG(fattname ORDER BY confkey) AS fcolnames " + " FROM fk_names " + " GROUP BY " + " conname, " + " contype, " + " fschema, " + " ftable " + "), " + "other_constraints AS ( " + " SELECT con.conname, CASE con.contype WHEN 'c' THEN 'CHECK' WHEN 'u' THEN " + " 'UNIQUE' WHEN 'p' THEN 'PRIMARY KEY' END AS contype, " + " ARRAY_AGG(attr.attname) AS colnames " + " FROM pg_catalog.pg_constraint AS con " + " CROSS JOIN UNNEST(conkey) AS conkeys " + " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = con.conrelid " + " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " + " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = conkeys " + " AND cls.oid = attr.attrelid " + " WHERE con.contype IN ('c', 'u', 'p') AND nsp.nspname = $1 " + " AND cls.relname = $2 " + " GROUP BY conname, contype " + ") " + "SELECT " + " conname, contype, colnames, fschema, ftable, fcolnames " + "FROM fkeys " + "UNION ALL " + "SELECT " + " conname, contype, colnames, NULL, NULL, NULL " + "FROM other_constraints"; + +class PostgresGetObjectsHelper : public adbc::driver::GetObjectsHelper { public: - 
PqGetObjectsHelper(PGconn* conn, int depth, const char* catalog, const char* db_schema, - const char* table_name, const char** table_types, - const char* column_name, struct ArrowSchema* schema, - struct ArrowArray* array, struct AdbcError* error) - : conn_(conn), - depth_(depth), - catalog_(catalog), - db_schema_(db_schema), - table_name_(table_name), - table_types_(table_types), - column_name_(column_name), - schema_(schema), - array_(array), - error_(error) { - na_error_ = {0}; - } - - AdbcStatusCode GetObjects() { - RAISE_ADBC(InitArrowArray()); - - catalog_name_col_ = array_->children[0]; - catalog_db_schemas_col_ = array_->children[1]; - catalog_db_schemas_items_ = catalog_db_schemas_col_->children[0]; - db_schema_name_col_ = catalog_db_schemas_items_->children[0]; - db_schema_tables_col_ = catalog_db_schemas_items_->children[1]; - schema_table_items_ = db_schema_tables_col_->children[0]; - table_name_col_ = schema_table_items_->children[0]; - table_type_col_ = schema_table_items_->children[1]; - - table_columns_col_ = schema_table_items_->children[2]; - table_columns_items_ = table_columns_col_->children[0]; - column_name_col_ = table_columns_items_->children[0]; - column_position_col_ = table_columns_items_->children[1]; - column_remarks_col_ = table_columns_items_->children[2]; - - table_constraints_col_ = schema_table_items_->children[3]; - table_constraints_items_ = table_constraints_col_->children[0]; - constraint_name_col_ = table_constraints_items_->children[0]; - constraint_type_col_ = table_constraints_items_->children[1]; - - constraint_column_names_col_ = table_constraints_items_->children[2]; - constraint_column_name_col_ = constraint_column_names_col_->children[0]; - - constraint_column_usages_col_ = table_constraints_items_->children[3]; - constraint_column_usage_items_ = constraint_column_usages_col_->children[0]; - fk_catalog_col_ = constraint_column_usage_items_->children[0]; - fk_db_schema_col_ = constraint_column_usage_items_->children[1]; - fk_table_col_ = constraint_column_usage_items_->children[2]; - fk_column_name_col_ = constraint_column_usage_items_->children[3]; - - RAISE_ADBC(AppendCatalogs()); - RAISE_ADBC(FinishArrowArray()); - return ADBC_STATUS_OK; + explicit PostgresGetObjectsHelper(PGconn* conn) + : current_database_(PQdb(conn)), + all_catalogs_(conn, kCatalogQueryAll), + some_catalogs_(conn, CatalogQuery()), + all_schemas_(conn, kSchemaQueryAll), + some_schemas_(conn, SchemaQuery()), + all_tables_(conn, kTablesQueryAll), + some_tables_(conn, TablesQuery()), + all_columns_(conn, kColumnsQueryAll), + some_columns_(conn, ColumnsQuery()), + all_constraints_(conn, kConstraintsQueryAll), + some_constraints_(conn, ConstraintsQuery()) {} + + // Allow Redshift to execute this query without constraints + // TODO(paleolimbot): Investigate to see if we can simplify the constraits query so that + // it works on both! 
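A plausible call site for the toggle declared next, as a sketch only; how the caller detects Redshift is not shown in this hunk, so the flag below is invented.

// Hypothetical: the connection layer disables the constraints query when it
// knows the server is Redshift, so GetObjects skips it entirely.
PostgresGetObjectsHelper helper(conn);
if (server_is_redshift) {  // invented flag; detection not shown here
  helper.SetEnableConstraints(false);
}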
+ void SetEnableConstraints(bool enable_constraints) { + enable_constraints_ = enable_constraints; } - private: - AdbcStatusCode InitArrowArray() { - RAISE_ADBC(adbc::driver::AdbcInitConnectionObjectsSchema(schema_).ToAdbc(error_)); - - CHECK_NA_DETAIL(INTERNAL, ArrowArrayInitFromSchema(array_, schema_, &na_error_), - &na_error_, error_); - - CHECK_NA(INTERNAL, ArrowArrayStartAppending(array_), error_); - return ADBC_STATUS_OK; + Status Load(adbc::driver::GetObjectsDepth depth, + std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + const std::vector& table_types) override { + return Status::Ok(); } - AdbcStatusCode AppendSchemas(std::string db_name) { - // postgres only allows you to list schemas for the currently connected db - if (!strcmp(db_name.c_str(), PQdb(conn_))) { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size*/ 256)) { - return ADBC_STATUS_INTERNAL; - } - - const char* stmt = - "SELECT nspname FROM pg_catalog.pg_namespace WHERE " - "nspname !~ '^pg_' AND nspname <> 'information_schema'"; - - if (StringBuilderAppend(&query, "%s", stmt)) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } - - std::vector params; - if (db_schema_ != NULL) { - if (StringBuilderAppend(&query, "%s", " AND nspname = $1")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } - params.push_back(db_schema_); - } - - auto result_helper = PqResultHelper{conn_, std::string(query.buffer)}; - StringBuilderReset(&query); + Status LoadCatalogs(std::optional catalog_filter) override { + if (catalog_filter.has_value()) { + UNWRAP_STATUS(some_catalogs_.Execute({std::string(*catalog_filter)})); + next_catalog_ = some_catalogs_.Row(-1); + } else { + UNWRAP_STATUS(all_catalogs_.Execute()); + next_catalog_ = all_catalogs_.Row(-1); + } - RAISE_ADBC(result_helper.Execute(error_, params)); + return Status::Ok(); + }; - for (PqResultRow row : result_helper) { - const char* schema_name = row[0].data; - CHECK_NA(INTERNAL, - ArrowArrayAppendString(db_schema_name_col_, ArrowCharView(schema_name)), - error_); - if (depth_ == ADBC_OBJECT_DEPTH_DB_SCHEMAS) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(db_schema_tables_col_, 1), error_); - } else { - RAISE_ADBC(AppendTables(std::string(schema_name))); - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_items_), error_); - } + Result> NextCatalog() override { + next_catalog_ = next_catalog_.Next(); + if (!next_catalog_.IsValid()) { + return std::nullopt; } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(catalog_db_schemas_col_), error_); - return ADBC_STATUS_OK; + return next_catalog_[0].value(); } - AdbcStatusCode AppendCatalogs() { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size=*/256) != 0) return ADBC_STATUS_INTERNAL; - - if (StringBuilderAppend(&query, "%s", "SELECT datname FROM pg_catalog.pg_database")) { - return ADBC_STATUS_INTERNAL; + Status LoadSchemas(std::string_view catalog, + std::optional schema_filter) override { + // PostgreSQL can only list for the current database + if (catalog != current_database_) { + return Status::Ok(); } - std::vector params; - if (catalog_ != NULL) { - if (StringBuilderAppend(&query, "%s", " WHERE datname = $1")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } - params.push_back(catalog_); + if (schema_filter.has_value()) { + 
UNWRAP_STATUS(some_schemas_.Execute({std::string(*schema_filter)})); + next_schema_ = some_schemas_.Row(-1); + } else { + UNWRAP_STATUS(all_schemas_.Execute()); + next_schema_ = all_schemas_.Row(-1); } + return Status::Ok(); + }; - PqResultHelper result_helper = PqResultHelper{conn_, std::string(query.buffer)}; - StringBuilderReset(&query); - - RAISE_ADBC(result_helper.Execute(error_, params)); - - for (PqResultRow row : result_helper) { - const char* db_name = row[0].data; - CHECK_NA(INTERNAL, - ArrowArrayAppendString(catalog_name_col_, ArrowCharView(db_name)), error_); - if (depth_ == ADBC_OBJECT_DEPTH_CATALOGS) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(catalog_db_schemas_col_, 1), error_); - } else { - RAISE_ADBC(AppendSchemas(std::string(db_name))); - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(array_), error_); + Result> NextSchema() override { + next_schema_ = next_schema_.Next(); + if (!next_schema_.IsValid()) { + return std::nullopt; } - return ADBC_STATUS_OK; + return next_schema_[0].value(); } - AdbcStatusCode AppendTables(std::string schema_name) { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size*/ 512)) { - return ADBC_STATUS_INTERNAL; - } + Status LoadTables(std::string_view catalog, std::string_view schema, + std::optional table_filter, + const std::vector& table_types) override { + std::string table_types_bind = TableTypesArrayLiteral(table_types); - std::vector params = {schema_name}; - const char* stmt = - "SELECT c.relname, CASE c.relkind WHEN 'r' THEN 'table' WHEN 'v' THEN 'view' " - "WHEN 'm' THEN 'materialized view' WHEN 't' THEN 'TOAST table' " - "WHEN 'f' THEN 'foreign table' WHEN 'p' THEN 'partitioned table' END " - "AS reltype FROM pg_catalog.pg_class c " - "LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace " - "WHERE c.relkind IN ('r','v','m','t','f','p') " - "AND pg_catalog.pg_table_is_visible(c.oid) AND n.nspname = $1"; - - if (StringBuilderAppend(&query, "%s", stmt)) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; + if (table_filter.has_value()) { + UNWRAP_STATUS(some_tables_.Execute( + {std::string(schema), table_types_bind, std::string(*table_filter)})); + next_table_ = some_tables_.Row(-1); + } else { + UNWRAP_STATUS(all_tables_.Execute({std::string(schema), table_types_bind})); + next_table_ = all_tables_.Row(-1); } - if (table_name_ != nullptr) { - if (StringBuilderAppend(&query, "%s", " AND c.relname LIKE $2")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } + return Status::Ok(); + }; - params.push_back(std::string(table_name_)); + Result> NextTable() override { + next_table_ = next_table_.Next(); + if (!next_table_.IsValid()) { + return std::nullopt; } - if (table_types_ != nullptr) { - std::vector table_type_filter; - const char** table_types = table_types_; - while (*table_types != NULL) { - auto table_type_str = std::string(*table_types); - auto search = kPgTableTypes.find(table_type_str); - if (search != kPgTableTypes.end()) { - table_type_filter.push_back(search->second); - } - table_types++; - } + return Table{next_table_[0].value(), next_table_[1].value()}; + } - if (!table_type_filter.empty()) { - std::ostringstream oss; - bool first = true; - oss << "("; - for (const auto& str : table_type_filter) { - if (!first) { - oss << ", "; - } - oss << "'" << str << "'"; - first = false; - } - oss << ")"; + Status LoadColumns(std::string_view catalog, std::string_view schema, + std::string_view table, + std::optional column_filter) 
override { + if (column_filter.has_value()) { + UNWRAP_STATUS(some_columns_.Execute( + {std::string(schema), std::string(table), std::string(*column_filter)})); + next_column_ = some_columns_.Row(-1); + } else { + UNWRAP_STATUS(all_columns_.Execute({std::string(schema), std::string(table)})); + next_column_ = all_columns_.Row(-1); + } - if (StringBuilderAppend(&query, "%s%s", " AND c.relkind IN ", - oss.str().c_str())) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } + if (enable_constraints_) { + if (column_filter.has_value()) { + UNWRAP_STATUS(some_constraints_.Execute( + {std::string(schema), std::string(table), std::string(*column_filter)})) + next_constraint_ = some_constraints_.Row(-1); } else { - // no matching table type means no records should come back - if (StringBuilderAppend(&query, "%s", " AND false")) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } + UNWRAP_STATUS( + all_constraints_.Execute({std::string(schema), std::string(table)})); + next_constraint_ = all_constraints_.Row(-1); } } - auto result_helper = PqResultHelper{conn_, query.buffer}; - StringBuilderReset(&query); + return Status::Ok(); + }; - RAISE_ADBC(result_helper.Execute(error_, params)); - for (PqResultRow row : result_helper) { - const char* table_name = row[0].data; - const char* table_type = row[1].data; + Result> NextColumn() override { + next_column_ = next_column_.Next(); + if (!next_column_.IsValid()) { + return std::nullopt; + } - CHECK_NA(INTERNAL, - ArrowArrayAppendString(table_name_col_, ArrowCharView(table_name)), - error_); - CHECK_NA(INTERNAL, - ArrowArrayAppendString(table_type_col_, ArrowCharView(table_type)), - error_); - if (depth_ == ADBC_OBJECT_DEPTH_TABLES) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(table_columns_col_, 1), error_); - CHECK_NA(INTERNAL, ArrowArrayAppendNull(table_constraints_col_, 1), error_); - } else { - auto table_name_s = std::string(table_name); - RAISE_ADBC(AppendColumns(schema_name, table_name_s)); - RAISE_ADBC(AppendConstraints(schema_name, table_name_s)); - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(schema_table_items_), error_); + Column col; + col.column_name = next_column_[0].value(); + UNWRAP_RESULT(col.ordinal_position, next_column_[1].ParseInteger()); + if (!next_column_[2].is_null) { + col.remarks = next_column_[2].value(); } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(db_schema_tables_col_), error_); - return ADBC_STATUS_OK; + return col; } - AdbcStatusCode AppendColumns(std::string schema_name, std::string table_name) { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size*/ 512)) { - return ADBC_STATUS_INTERNAL; - } - - std::vector params = {schema_name, table_name}; - const char* stmt = - "SELECT attr.attname, attr.attnum, " - "pg_catalog.col_description(cls.oid, attr.attnum) " - "FROM pg_catalog.pg_attribute AS attr " - "INNER JOIN pg_catalog.pg_class AS cls ON attr.attrelid = cls.oid " - "INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " - "WHERE attr.attnum > 0 AND NOT attr.attisdropped " - "AND nsp.nspname LIKE $1 AND cls.relname LIKE $2"; - - if (StringBuilderAppend(&query, "%s", stmt)) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; + Result> NextConstraint() override { + next_constraint_ = next_constraint_.Next(); + if (!next_constraint_.IsValid()) { + return std::nullopt; } - if (column_name_ != NULL) { - if (StringBuilderAppend(&query, "%s", " AND attr.attname LIKE $3")) { - 
StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } + Constraint out; + out.name = next_constraint_[0].data; + out.type = next_constraint_[1].data; - params.push_back(std::string(column_name_)); + UNWRAP_RESULT(constraint_fcolumn_names_, next_constraint_[2].ParseTextArray()); + std::vector fcolumn_names_view; + for (const std::string& item : constraint_fcolumn_names_) { + fcolumn_names_view.push_back(item); } + out.column_names = std::move(fcolumn_names_view); - auto result_helper = PqResultHelper{conn_, query.buffer}; - StringBuilderReset(&query); + if (out.type == "FOREIGN KEY") { + assert(!next_constraint_[3].is_null); + assert(!next_constraint_[3].is_null); + assert(!next_constraint_[4].is_null); + assert(!next_constraint_[5].is_null); - RAISE_ADBC(result_helper.Execute(error_, params)); + out.usage = std::vector(); + UNWRAP_RESULT(constraint_fkey_names_, next_constraint_[5].ParseTextArray()); - for (PqResultRow row : result_helper) { - const char* column_name = row[0].data; - const char* position = row[1].data; + for (const auto& item : constraint_fkey_names_) { + ConstraintUsage usage; + usage.catalog = current_database_; + usage.schema = next_constraint_[3].data; + usage.table = next_constraint_[4].data; + usage.column = item; - CHECK_NA(INTERNAL, - ArrowArrayAppendString(column_name_col_, ArrowCharView(column_name)), - error_); - int ival = atol(position); - CHECK_NA(INTERNAL, - ArrowArrayAppendInt(column_position_col_, static_cast(ival)), - error_); - if (row[2].is_null) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(column_remarks_col_, 1), error_); - } else { - const char* remarks = row[2].data; - CHECK_NA(INTERNAL, - ArrowArrayAppendString(column_remarks_col_, ArrowCharView(remarks)), - error_); - } - - // no xdbc_ values for now - for (auto i = 3; i < 19; i++) { - CHECK_NA(INTERNAL, ArrowArrayAppendNull(table_columns_items_->children[i], 1), - error_); + out.usage->push_back(usage); } - - CHECK_NA(INTERNAL, ArrowArrayFinishElement(table_columns_items_), error_); } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(table_columns_col_), error_); - return ADBC_STATUS_OK; + return out; } - // libpq PQexecParams can use either text or binary transfers - // For now we are using text transfer internally, so arrays are sent - // back like {element1, element2} within a const char* - std::vector PqTextArrayToVector(std::string text_array) { - text_array.erase(0, 1); - text_array.erase(text_array.size() - 1); - - std::vector elements; - std::stringstream ss(std::move(text_array)); - std::string tmp; - - while (getline(ss, tmp, ',')) { - elements.push_back(std::move(tmp)); - } - - return elements; + private: + std::string current_database_; + + // Ready-to-Execute() queries + PqResultHelper all_catalogs_; + PqResultHelper some_catalogs_; + PqResultHelper all_schemas_; + PqResultHelper some_schemas_; + PqResultHelper all_tables_; + PqResultHelper some_tables_; + PqResultHelper all_columns_; + PqResultHelper some_columns_; + PqResultHelper all_constraints_; + PqResultHelper some_constraints_; + + // On Redshift, the constraints query fails + bool enable_constraints_{true}; + + // Iterator state for the catalogs/schema/table/column queries + PqResultRow next_catalog_; + PqResultRow next_schema_; + PqResultRow next_table_; + PqResultRow next_column_; + PqResultRow next_constraint_; + + // Owning variants required because the framework versions of these + // are all based on string_view and the result helper can only parse arrays + // into std::vector. 
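Concretely (an illustration, not part of the patch): after ParseTextArray() fills one of these owning vectors, NextConstraint() builds string_views that point into it, so the views stay valid only until the next row is parsed.

// e.g. constraint_fcolumn_names_ owns {"col_a", "col_b"} as std::string values;
// the framework-facing vector just views them:
std::vector<std::string_view> views(constraint_fcolumn_names_.begin(),
                                    constraint_fcolumn_names_.end());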
+ std::vector constraint_fcolumn_names_; + std::vector constraint_fkey_names_; + + // Queries that are slightly modified versions of the generic queries that allow + // the filter for that level to be passed through as a parameter. Defined here + // because global strings should be const char* according to cpplint and using + // the + operator to concatenate them is the most concise way to construct them. + + // Parameterized on catalog_name + static std::string CatalogQuery() { + return std::string(kCatalogQueryAll) + " WHERE datname = $1"; } - AdbcStatusCode AppendConstraints(std::string schema_name, std::string table_name) { - struct StringBuilder query; - std::memset(&query, 0, sizeof(query)); - if (StringBuilderInit(&query, /*initial_size*/ 4096)) { - return ADBC_STATUS_INTERNAL; - } - - std::vector params = {schema_name, table_name}; - const char* stmt = - "WITH fk_unnest AS ( " - " SELECT " - " con.conname, " - " 'FOREIGN KEY' AS contype, " - " conrelid, " - " UNNEST(con.conkey) AS conkey, " - " confrelid, " - " UNNEST(con.confkey) AS confkey " - " FROM pg_catalog.pg_constraint AS con " - " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = conrelid " - " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " - " WHERE con.contype = 'f' AND nsp.nspname LIKE $1 " - " AND cls.relname LIKE $2 " - "), " - "fk_names AS ( " - " SELECT " - " fk_unnest.conname, " - " fk_unnest.contype, " - " fk_unnest.conkey, " - " fk_unnest.confkey, " - " attr.attname, " - " fnsp.nspname AS fschema, " - " fcls.relname AS ftable, " - " fattr.attname AS fattname " - " FROM fk_unnest " - " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = fk_unnest.conrelid " - " INNER JOIN pg_catalog.pg_class AS fcls ON fcls.oid = fk_unnest.confrelid " - " INNER JOIN pg_catalog.pg_namespace AS fnsp ON fnsp.oid = fcls.relnamespace" - " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = " - "fk_unnest.conkey " - " AND attr.attrelid = fk_unnest.conrelid " - " LEFT JOIN pg_catalog.pg_attribute AS fattr ON fattr.attnum = " - "fk_unnest.confkey " - " AND fattr.attrelid = fk_unnest.confrelid " - "), " - "fkeys AS ( " - " SELECT " - " conname, " - " contype, " - " ARRAY_AGG(attname ORDER BY conkey) AS colnames, " - " fschema, " - " ftable, " - " ARRAY_AGG(fattname ORDER BY confkey) AS fcolnames " - " FROM fk_names " - " GROUP BY " - " conname, " - " contype, " - " fschema, " - " ftable " - "), " - "other_constraints AS ( " - " SELECT con.conname, CASE con.contype WHEN 'c' THEN 'CHECK' WHEN 'u' THEN " - " 'UNIQUE' WHEN 'p' THEN 'PRIMARY KEY' END AS contype, " - " ARRAY_AGG(attr.attname) AS colnames " - " FROM pg_catalog.pg_constraint AS con " - " CROSS JOIN UNNEST(conkey) AS conkeys " - " INNER JOIN pg_catalog.pg_class AS cls ON cls.oid = con.conrelid " - " INNER JOIN pg_catalog.pg_namespace AS nsp ON nsp.oid = cls.relnamespace " - " INNER JOIN pg_catalog.pg_attribute AS attr ON attr.attnum = conkeys " - " AND cls.oid = attr.attrelid " - " WHERE con.contype IN ('c', 'u', 'p') AND nsp.nspname LIKE $1 " - " AND cls.relname LIKE $2 " - " GROUP BY conname, contype " - ") " - "SELECT " - " conname, contype, colnames, fschema, ftable, fcolnames " - "FROM fkeys " - "UNION ALL " - "SELECT " - " conname, contype, colnames, NULL, NULL, NULL " - "FROM other_constraints"; - - if (StringBuilderAppend(&query, "%s", stmt)) { - StringBuilderReset(&query); - return ADBC_STATUS_INTERNAL; - } - - if (column_name_ != NULL) { - if (StringBuilderAppend(&query, "%s", " WHERE conname LIKE $3")) { - StringBuilderReset(&query); - 
return ADBC_STATUS_INTERNAL; - } - - params.push_back(std::string(column_name_)); - } + // Parameterized on schema_name + static std::string SchemaQuery() { + return std::string(kSchemaQueryAll) + " AND nspname = $1"; + } - auto result_helper = PqResultHelper{conn_, query.buffer}; - StringBuilderReset(&query); + // Parameterized on schema_name, relkind, table_name + static std::string TablesQuery() { + return std::string(kTablesQueryAll) + " AND c.relname LIKE $3"; + } - RAISE_ADBC(result_helper.Execute(error_, params)); + // Parameterized on schema_name, table_name, column_name + static std::string ColumnsQuery() { + return std::string(kColumnsQueryAll) + " AND attr.attname LIKE $3"; + } - for (PqResultRow row : result_helper) { - const char* constraint_name = row[0].data; - const char* constraint_type = row[1].data; + // Parameterized on schema_name, table_name, column_name + static std::string ConstraintsQuery() { + return std::string(kConstraintsQueryAll) + " WHERE conname LIKE $3"; + } - CHECK_NA( - INTERNAL, - ArrowArrayAppendString(constraint_name_col_, ArrowCharView(constraint_name)), - error_); + std::string TableTypesArrayLiteral(const std::vector& table_types) { + std::stringstream table_types_bind; + table_types_bind << "{"; + int table_types_bind_len = 0; - CHECK_NA( - INTERNAL, - ArrowArrayAppendString(constraint_type_col_, ArrowCharView(constraint_type)), - error_); + if (table_types.empty()) { + for (const auto& item : kPgTableTypes) { + if (table_types_bind_len > 0) { + table_types_bind << ", "; + } - auto constraint_column_names = PqTextArrayToVector(std::string(row[2].data)); - for (const auto& constraint_column_name : constraint_column_names) { - CHECK_NA(INTERNAL, - ArrowArrayAppendString(constraint_column_name_col_, - ArrowCharView(constraint_column_name.c_str())), - error_); + table_types_bind << "\"" << item.second << "\""; + table_types_bind_len++; } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(constraint_column_names_col_), error_); - - if (!strcmp(constraint_type, "FOREIGN KEY")) { - assert(!row[3].is_null); - assert(!row[4].is_null); - assert(!row[5].is_null); - - const char* constraint_ftable_schema = row[3].data; - const char* constraint_ftable_name = row[4].data; - auto constraint_fcolumn_names = PqTextArrayToVector(std::string(row[5].data)); - for (const auto& constraint_fcolumn_name : constraint_fcolumn_names) { - CHECK_NA(INTERNAL, - ArrowArrayAppendString(fk_catalog_col_, ArrowCharView(PQdb(conn_))), - error_); - CHECK_NA(INTERNAL, - ArrowArrayAppendString(fk_db_schema_col_, - ArrowCharView(constraint_ftable_schema)), - error_); - CHECK_NA(INTERNAL, - ArrowArrayAppendString(fk_table_col_, - ArrowCharView(constraint_ftable_name)), - error_); - CHECK_NA(INTERNAL, - ArrowArrayAppendString(fk_column_name_col_, - ArrowCharView(constraint_fcolumn_name.c_str())), - error_); - - CHECK_NA(INTERNAL, ArrowArrayFinishElement(constraint_column_usage_items_), - error_); + } else { + for (auto type : table_types) { + const auto maybe_item = kPgTableTypes.find(std::string(type)); + if (maybe_item == kPgTableTypes.end()) { + continue; } - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(constraint_column_usages_col_), error_); - CHECK_NA(INTERNAL, ArrowArrayFinishElement(table_constraints_items_), error_); - } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(table_constraints_col_), error_); - return ADBC_STATUS_OK; - } + if (table_types_bind_len > 0) { + table_types_bind << ", "; + } - AdbcStatusCode FinishArrowArray() { - CHECK_NA_DETAIL(INTERNAL, 
ArrowArrayFinishBuildingDefault(array_, &na_error_), - &na_error_, error_); + table_types_bind << "\"" << maybe_item->second << "\""; + table_types_bind_len++; + } + } - return ADBC_STATUS_OK; + table_types_bind << "}"; + return table_types_bind.str(); } - - PGconn* conn_ = nullptr; - int depth_; - const char* catalog_ = nullptr; - const char* db_schema_ = nullptr; - const char* table_name_ = nullptr; - const char** table_types_ = nullptr; - const char* column_name_ = nullptr; - struct ArrowSchema* schema_ = nullptr; - struct ArrowArray* array_ = nullptr; - struct AdbcError* error_ = nullptr; - struct ArrowError na_error_; - struct ArrowArray* catalog_name_col_ = nullptr; - struct ArrowArray* catalog_db_schemas_col_ = nullptr; - struct ArrowArray* catalog_db_schemas_items_ = nullptr; - struct ArrowArray* db_schema_name_col_ = nullptr; - struct ArrowArray* db_schema_tables_col_ = nullptr; - struct ArrowArray* schema_table_items_ = nullptr; - struct ArrowArray* table_name_col_ = nullptr; - struct ArrowArray* table_type_col_ = nullptr; - struct ArrowArray* table_columns_col_ = nullptr; - struct ArrowArray* table_columns_items_ = nullptr; - struct ArrowArray* column_name_col_ = nullptr; - struct ArrowArray* column_position_col_ = nullptr; - struct ArrowArray* column_remarks_col_ = nullptr; - struct ArrowArray* table_constraints_col_ = nullptr; - struct ArrowArray* table_constraints_items_ = nullptr; - struct ArrowArray* constraint_name_col_ = nullptr; - struct ArrowArray* constraint_type_col_ = nullptr; - struct ArrowArray* constraint_column_names_col_ = nullptr; - struct ArrowArray* constraint_column_name_col_ = nullptr; - struct ArrowArray* constraint_column_usages_col_ = nullptr; - struct ArrowArray* constraint_column_usage_items_ = nullptr; - struct ArrowArray* fk_catalog_col_ = nullptr; - struct ArrowArray* fk_db_schema_col_ = nullptr; - struct ArrowArray* fk_table_col_ = nullptr; - struct ArrowArray* fk_column_name_col_ = nullptr; }; // A notice processor that does nothing with notices. 
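TableTypesArrayLiteral() above turns the caller's table-type filter into a PostgreSQL array literal (it also expands an empty filter to every known type, which is omitted here). A standalone sketch of the same idea follows; the contents of the table-type map and the `relkind = ANY(...)` usage note are assumptions, since the query that consumes the literal is not part of this hunk.

```cpp
// Sketch: build a PostgreSQL array literal from requested ADBC table types.
#include <map>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical subset of the driver's kPgTableTypes mapping.
static const std::map<std::string, std::string> kPgTableTypesSketch = {
    {"table", "r"}, {"view", "v"}, {"materialized_view", "m"}};

std::string MakeRelkindArrayLiteral(const std::vector<std::string>& table_types) {
  std::stringstream out;
  out << "{";
  int n = 0;
  for (const auto& type : table_types) {
    auto it = kPgTableTypesSketch.find(type);
    if (it == kPgTableTypesSketch.end()) continue;  // unknown types are skipped
    if (n++ > 0) out << ", ";
    out << "\"" << it->second << "\"";
  }
  out << "}";
  return out.str();
}

// MakeRelkindArrayLiteral({"table", "view"}) yields {"r", "v"}; the literal can
// then be bound as a single text parameter and matched server-side with
// something like `c.relkind = ANY($2::text[])`.
```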
In the future we can log @@ -634,116 +481,120 @@ AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { return ADBC_STATUS_OK; } -AdbcStatusCode PostgresConnection::PostgresConnectionGetInfoImpl( - const uint32_t* info_codes, size_t info_codes_length, struct ArrowSchema* schema, - struct ArrowArray* array, struct AdbcError* error) { - RAISE_ADBC(adbc::driver::AdbcInitConnectionGetInfoSchema(schema, array).ToAdbc(error)); +AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, + const uint32_t* info_codes, + size_t info_codes_length, + struct ArrowArrayStream* out, + struct AdbcError* error) { + if (!info_codes) { + info_codes = kSupportedInfoCodes; + info_codes_length = sizeof(kSupportedInfoCodes) / sizeof(kSupportedInfoCodes[0]); + } + + std::vector infos; for (size_t i = 0; i < info_codes_length; i++) { switch (info_codes[i]) { case ADBC_INFO_VENDOR_NAME: - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString(array, info_codes[i], - "PostgreSQL") - .ToAdbc(error)); + infos.push_back({info_codes[i], std::string(VendorName())}); break; case ADBC_INFO_VENDOR_VERSION: { - const char* stmt = "SHOW server_version_num"; - auto result_helper = PqResultHelper{conn_, std::string(stmt)}; - RAISE_ADBC(result_helper.Execute(error)); - auto it = result_helper.begin(); - if (it == result_helper.end()) { - SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt); - return ADBC_STATUS_INTERNAL; + if (VendorName() == "Redshift") { + const std::array& version = VendorVersion(); + std::string version_string = std::to_string(version[0]) + "." + + std::to_string(version[1]) + "." + + std::to_string(version[2]); + infos.push_back({info_codes[i], std::move(version_string)}); + + } else { + // Gives a version in the form 140000 instead of 14.0.0 + const char* stmt = "SHOW server_version_num"; + auto result_helper = PqResultHelper{conn_, std::string(stmt)}; + RAISE_STATUS(error, result_helper.Execute()); + auto it = result_helper.begin(); + if (it == result_helper.end()) { + SetError(error, "[libpq] PostgreSQL returned no rows for '%s'", stmt); + return ADBC_STATUS_INTERNAL; + } + const char* server_version_num = (*it)[0].data; + infos.push_back({info_codes[i], server_version_num}); } - const char* server_version_num = (*it)[0].data; - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString(array, info_codes[i], - server_version_num) - .ToAdbc(error)); break; } case ADBC_INFO_DRIVER_NAME: - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString( - array, info_codes[i], "ADBC PostgreSQL Driver") - .ToAdbc(error)); + infos.push_back({info_codes[i], "ADBC PostgreSQL Driver"}); break; case ADBC_INFO_DRIVER_VERSION: // TODO(lidavidm): fill in driver version - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString(array, info_codes[i], - "(unknown)") - .ToAdbc(error)); + infos.push_back({info_codes[i], "(unknown)"}); break; case ADBC_INFO_DRIVER_ARROW_VERSION: - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendString(array, info_codes[i], - NANOARROW_VERSION) - .ToAdbc(error)); + infos.push_back({info_codes[i], NANOARROW_VERSION}); break; case ADBC_INFO_DRIVER_ADBC_VERSION: - RAISE_ADBC(adbc::driver::AdbcConnectionGetInfoAppendInt(array, info_codes[i], - ADBC_VERSION_1_1_0) - .ToAdbc(error)); + infos.push_back({info_codes[i], ADBC_VERSION_1_1_0}); break; default: // Ignore continue; } - CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error); } - struct ArrowError na_error = {0}; - CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array, 
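As the comment above notes, `SHOW server_version_num` reports an integer such as 140000 rather than a dotted version, and the driver forwards that raw value as the vendor-version info string. If a dotted form is ever wanted, the conversion is trivial; the helper below is purely illustrative and assumes a PostgreSQL 10+ server, where the encoding is major * 10000 + minor (older releases encoded three components).

```cpp
// Hypothetical helper (not part of the driver): format server_version_num.
#include <string>

std::string FormatServerVersionNum(int server_version_num) {
  // PostgreSQL 10+ encodes the version as major * 10000 + minor.
  int major = server_version_num / 10000;
  int minor = server_version_num % 10000;
  return std::to_string(major) + "." + std::to_string(minor);
}

// FormatServerVersionNum(140000) == "14.0"
// FormatServerVersionNum(140005) == "14.5"
```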
&na_error), &na_error, - error); - + RAISE_ADBC(adbc::driver::MakeGetInfoStream(infos, out).ToAdbc(error)); return ADBC_STATUS_OK; } -AdbcStatusCode PostgresConnection::GetInfo(struct AdbcConnection* connection, - const uint32_t* info_codes, - size_t info_codes_length, - struct ArrowArrayStream* out, - struct AdbcError* error) { - if (!info_codes) { - info_codes = kSupportedInfoCodes; - info_codes_length = sizeof(kSupportedInfoCodes) / sizeof(kSupportedInfoCodes[0]); - } - - struct ArrowSchema schema; - std::memset(&schema, 0, sizeof(schema)); - struct ArrowArray array; - std::memset(&array, 0, sizeof(array)); - - AdbcStatusCode status = PostgresConnectionGetInfoImpl(info_codes, info_codes_length, - &schema, &array, error); - if (status != ADBC_STATUS_OK) { - if (schema.release) schema.release(&schema); - if (array.release) array.release(&array); - return status; - } - - return BatchToArrayStream(&array, &schema, out, error); -} - AdbcStatusCode PostgresConnection::GetObjects( - struct AdbcConnection* connection, int depth, const char* catalog, - const char* db_schema, const char* table_name, const char** table_types, + struct AdbcConnection* connection, int c_depth, const char* catalog, + const char* db_schema, const char* table_name, const char** table_type, const char* column_name, struct ArrowArrayStream* out, struct AdbcError* error) { - struct ArrowSchema schema; - std::memset(&schema, 0, sizeof(schema)); - struct ArrowArray array; - std::memset(&array, 0, sizeof(array)); + PostgresGetObjectsHelper helper(conn_); + helper.SetEnableConstraints(VendorName() != "Redshift"); + + const auto catalog_filter = + catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt; + const auto schema_filter = + db_schema ? std::make_optional(std::string_view(db_schema)) : std::nullopt; + const auto table_filter = + table_name ? std::make_optional(std::string_view(table_name)) : std::nullopt; + const auto column_filter = + column_name ? 
std::make_optional(std::string_view(column_name)) : std::nullopt; + std::vector table_type_filter; + while (table_type && *table_type) { + if (*table_type) { + table_type_filter.push_back(std::string_view(*table_type)); + } + table_type++; + } - PqGetObjectsHelper helper = - PqGetObjectsHelper(conn_, depth, catalog, db_schema, table_name, table_types, - column_name, &schema, &array, error); - AdbcStatusCode status = helper.GetObjects(); + using adbc::driver::GetObjectsDepth; - if (status != ADBC_STATUS_OK) { - if (schema.release) schema.release(&schema); - if (array.release) array.release(&array); - return status; + GetObjectsDepth depth = GetObjectsDepth::kColumns; + switch (c_depth) { + case ADBC_OBJECT_DEPTH_CATALOGS: + depth = GetObjectsDepth::kCatalogs; + break; + case ADBC_OBJECT_DEPTH_COLUMNS: + depth = GetObjectsDepth::kColumns; + break; + case ADBC_OBJECT_DEPTH_DB_SCHEMAS: + depth = GetObjectsDepth::kSchemas; + break; + case ADBC_OBJECT_DEPTH_TABLES: + depth = GetObjectsDepth::kTables; + break; + default: + return Status::InvalidArgument("[libpq] GetObjects: invalid depth ", c_depth) + .ToAdbc(error); } - return BatchToArrayStream(&array, &schema, out, error); + auto status = BuildGetObjects(&helper, depth, catalog_filter, schema_filter, + table_filter, column_filter, table_type_filter, out); + RAISE_STATUS(error, helper.Close()); + RAISE_STATUS(error, status); + + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, @@ -752,11 +603,12 @@ AdbcStatusCode PostgresConnection::GetOption(const char* option, char* value, if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_CATALOG) == 0) { output = PQdb(conn_); } else if (std::strcmp(option, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { - PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA"}; - RAISE_ADBC(result_helper.Execute(error)); + PqResultHelper result_helper{conn_, "SELECT CURRENT_SCHEMA()"}; + RAISE_STATUS(error, result_helper.Execute()); auto it = result_helper.begin(); if (it == result_helper.end()) { - SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + SetError(error, + "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA()'"); return ADBC_STATUS_INTERNAL; } output = (*it)[0].data; @@ -923,7 +775,8 @@ AdbcStatusCode PostgresConnectionGetStatisticsImpl(PGconn* conn, const char* db_ { PqResultHelper result_helper{conn, query}; - RAISE_ADBC(result_helper.Execute(error, {db_schema, table_name ? table_name : "%"})); + RAISE_STATUS(error, + result_helper.Execute({db_schema, table_name ? 
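For reference, this is what the rewritten GetObjects path looks like from the caller's side of the ADBC C API: the depth constant feeds the `GetObjectsDepth` switch above, and the NULL-terminated table-type list becomes `table_type_filter`. The helper name, filter values, and include path are illustrative only.

```cpp
// Caller-side usage sketch (hypothetical helper, not part of the driver).
#include <arrow-adbc/adbc.h>  // include path per this vendored layout

AdbcStatusCode ListFooTables(struct AdbcConnection* connection,
                             struct ArrowArrayStream* out, struct AdbcError* error) {
  // Restrict to ordinary tables and views; the list must be NULL-terminated.
  const char* table_types[] = {"table", "view", nullptr};
  return AdbcConnectionGetObjects(connection, ADBC_OBJECT_DEPTH_TABLES,
                                  /*catalog=*/nullptr, /*db_schema=*/nullptr,
                                  /*table_name=*/"foo%", table_types,
                                  /*column_name=*/nullptr, out, error);
}
```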
table_name : "%"})); for (PqResultRow row : result_helper) { auto reltuples = row[5].ParseDouble(); @@ -1076,7 +929,8 @@ AdbcStatusCode PostgresConnection::GetStatistics(const char* catalog, return status; } - return BatchToArrayStream(&array, &schema, out, error); + adbc::driver::MakeArrayStream(&schema, &array, out); + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnectionGetStatisticNamesImpl(struct ArrowSchema* schema, @@ -1125,7 +979,9 @@ AdbcStatusCode PostgresConnection::GetStatisticNames(struct ArrowArrayStream* ou if (array.release) array.release(&array); return status; } - return BatchToArrayStream(&array, &schema, out, error); + + adbc::driver::MakeArrayStream(&schema, &array, out); + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, @@ -1157,21 +1013,13 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, PqResultHelper result_helper = PqResultHelper{conn_, std::string(query.c_str())}; - auto result = result_helper.Execute(error, params); - if (result != ADBC_STATUS_OK) { - auto error_code = std::string(error->sqlstate, 5); - if ((error_code == "42P01") || (error_code == "42602")) { - return ADBC_STATUS_NOT_FOUND; - } - return result; - } + RAISE_STATUS(error, result_helper.Execute(params)); auto uschema = nanoarrow::UniqueSchema(); ArrowSchemaInit(uschema.get()); CHECK_NA(INTERNAL, ArrowSchemaSetTypeStruct(uschema.get(), result_helper.NumRows()), error); - ArrowError na_error; int row_counter = 0; for (auto row : result_helper) { const char* colname = row[0].data; @@ -1179,14 +1027,15 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, static_cast(std::strtol(row[1].data, /*str_end=*/nullptr, /*base=*/10)); PostgresType pg_type; - if (type_resolver_->Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) { - SetError(error, "%s%d%s%s%s%" PRIu32, "Column #", row_counter + 1, " (\"", colname, - "\") has unknown type code ", pg_oid); + if (type_resolver_->FindWithDefault(pg_oid, &pg_type) != NANOARROW_OK) { + SetError(error, "%s%d%s%s%s%" PRIu32, "Error resolving type code for column #", + row_counter + 1, " (\"", colname, "\") with oid ", pg_oid); final_status = ADBC_STATUS_NOT_IMPLEMENTED; break; } CHECK_NA(INTERNAL, - pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter]), + pg_type.WithFieldName(colname).SetSchema(uschema->children[row_counter], + std::string(VendorName())), error); row_counter++; } @@ -1195,54 +1044,17 @@ AdbcStatusCode PostgresConnection::GetTableSchema(const char* catalog, return final_status; } -AdbcStatusCode PostgresConnectionGetTableTypesImpl(struct ArrowSchema* schema, - struct ArrowArray* array, - struct AdbcError* error) { - // See 'relkind' in https://www.postgresql.org/docs/current/catalog-pg-class.html - auto uschema = nanoarrow::UniqueSchema(); - ArrowSchemaInit(uschema.get()); - - CHECK_NA(INTERNAL, ArrowSchemaSetType(uschema.get(), NANOARROW_TYPE_STRUCT), error); - CHECK_NA(INTERNAL, ArrowSchemaAllocateChildren(uschema.get(), /*num_columns=*/1), - error); - ArrowSchemaInit(uschema.get()->children[0]); - CHECK_NA(INTERNAL, - ArrowSchemaSetType(uschema.get()->children[0], NANOARROW_TYPE_STRING), error); - CHECK_NA(INTERNAL, ArrowSchemaSetName(uschema.get()->children[0], "table_type"), error); - uschema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; - - CHECK_NA(INTERNAL, ArrowArrayInitFromSchema(array, uschema.get(), NULL), error); - CHECK_NA(INTERNAL, ArrowArrayStartAppending(array), error); - - for (auto const& table_type : 
kPgTableTypes) { - CHECK_NA(INTERNAL, - ArrowArrayAppendString(array->children[0], - ArrowCharView(table_type.first.c_str())), - error); - CHECK_NA(INTERNAL, ArrowArrayFinishElement(array), error); - } - - CHECK_NA(INTERNAL, ArrowArrayFinishBuildingDefault(array, NULL), error); - - uschema.move(schema); - return ADBC_STATUS_OK; -} - AdbcStatusCode PostgresConnection::GetTableTypes(struct AdbcConnection* connection, struct ArrowArrayStream* out, struct AdbcError* error) { - struct ArrowSchema schema; - std::memset(&schema, 0, sizeof(schema)); - struct ArrowArray array; - std::memset(&array, 0, sizeof(array)); - - AdbcStatusCode status = PostgresConnectionGetTableTypesImpl(&schema, &array, error); - if (status != ADBC_STATUS_OK) { - if (schema.release) schema.release(&schema); - if (array.release) array.release(&array); - return status; + std::vector table_types; + table_types.reserve(kPgTableTypes.size()); + for (auto const& table_type : kPgTableTypes) { + table_types.push_back(table_type.first); } - return BatchToArrayStream(&array, &schema, out, error); + + RAISE_STATUS(error, adbc::driver::MakeTableTypesStream(table_types, out)); + return ADBC_STATUS_OK; } AdbcStatusCode PostgresConnection::Init(struct AdbcDatabase* database, @@ -1324,8 +1136,12 @@ AdbcStatusCode PostgresConnection::SetOption(const char* key, const char* value, return ADBC_STATUS_OK; } else if (std::strcmp(key, ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) == 0) { // PostgreSQL doesn't accept a parameter here - PqResultHelper result_helper{conn_, std::string("SET search_path TO ") + value}; - RAISE_ADBC(result_helper.Execute(error)); + char* value_esc = PQescapeIdentifier(conn_, value, strlen(value)); + std::string query = std::string("SET search_path TO ") + value_esc; + PQfreemem(value_esc); + + PqResultHelper result_helper{conn_, query}; + RAISE_STATUS(error, result_helper.Execute()); return ADBC_STATUS_OK; } SetError(error, "%s%s", "[libpq] Unknown option ", key); @@ -1351,4 +1167,10 @@ AdbcStatusCode PostgresConnection::SetOptionInt(const char* key, int64_t value, return ADBC_STATUS_NOT_IMPLEMENTED; } +std::string_view PostgresConnection::VendorName() { return database_->VendorName(); } + +const std::array& PostgresConnection::VendorVersion() { + return database_->VendorVersion(); +} + } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.h index 2a3b59c..7683875 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/connection.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -73,13 +74,10 @@ class PostgresConnection { return type_resolver_; } bool autocommit() const { return autocommit_; } + std::string_view VendorName(); + const std::array& VendorVersion(); private: - AdbcStatusCode PostgresConnectionGetInfoImpl(const uint32_t* info_codes, - size_t info_codes_length, - struct ArrowSchema* schema, - struct ArrowArray* array, - struct AdbcError* error); std::shared_ptr database_; std::shared_ptr type_resolver_; PGconn* conn_; diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/postgres_copy_reader_test.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/postgres_copy_reader_test.cc index 60e0b6a..7b9fe23 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/postgres_copy_reader_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/postgres_copy_reader_test.cc @@ -27,7 +27,7 @@ class 
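GetTableTypes now reports the keys of kPgTableTypes through MakeTableTypesStream instead of building the Arrow arrays by hand. The sketch below shows one way a consumer could drain the resulting ArrowArrayStream with nanoarrow; the single-string-column layout follows the ADBC GetTableTypes contract, the include path is an assumption for this vendored layout, and error handling is deliberately minimal.

```cpp
// Consumer-side sketch (not part of the driver): print each table_type value.
#include <cstdio>

#include "nanoarrow/nanoarrow.h"  // include path is an assumption

int PrintTableTypes(struct ArrowArrayStream* stream) {
  struct ArrowSchema schema;
  struct ArrowArray array;
  struct ArrowError na_error;
  if (stream->get_schema(stream, &schema) != 0) return -1;

  // End of stream is signalled by a batch with a NULL release callback.
  while (stream->get_next(stream, &array) == 0 && array.release != nullptr) {
    struct ArrowArrayView view;
    if (ArrowArrayViewInitFromSchema(&view, &schema, &na_error) != NANOARROW_OK ||
        ArrowArrayViewSetArray(&view, &array, &na_error) != NANOARROW_OK) {
      array.release(&array);
      break;
    }
    for (int64_t i = 0; i < array.length; i++) {
      struct ArrowStringView item = ArrowArrayViewGetStringUnsafe(view.children[0], i);
      std::printf("%.*s\n", static_cast<int>(item.size_bytes), item.data);
    }
    ArrowArrayViewReset(&view);
    array.release(&array);
  }

  schema.release(&schema);
  return 0;
}
```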
PostgresCopyStreamTester { public: ArrowErrorCode Init(const PostgresType& root_type, ArrowError* error = nullptr) { NANOARROW_RETURN_NOT_OK(reader_.Init(root_type)); - NANOARROW_RETURN_NOT_OK(reader_.InferOutputSchema(error)); + NANOARROW_RETURN_NOT_OK(reader_.InferOutputSchema("PostgreSQL Tester", error)); NANOARROW_RETURN_NOT_OK(reader_.InitFieldReaders(error)); return NANOARROW_OK; } diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/postgres_copy_writer_test.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/postgres_copy_writer_test.cc index 618f27c..5010848 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/postgres_copy_writer_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/postgres_copy_writer_test.cc @@ -836,32 +836,16 @@ TEST_P(PostgresCopyListTest, PostgresCopyWriteListSmallInt) { adbc_validation::Handle array; struct ArrowError na_error; - ASSERT_EQ(ArrowSchemaInitFromType(&schema.value, NANOARROW_TYPE_STRUCT), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaAllocateChildren(&schema.value, 1), NANOARROW_OK); - - ASSERT_EQ(ArrowSchemaInitFromType(schema->children[0], GetParam()), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaSetName(schema->children[0], "col"), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaSetType(schema->children[0]->children[0], NANOARROW_TYPE_INT16), - NANOARROW_OK); - - ASSERT_EQ(ArrowArrayInitFromSchema(&array.value, &schema.value, nullptr), NANOARROW_OK); - ASSERT_EQ(ArrowArrayStartAppending(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -123), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 0), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 123), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendNull(array->children[0], 1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); + ASSERT_EQ(adbc_validation::MakeSchema( + &schema.value, {adbc_validation::SchemaField::Nested( + "col", GetParam(), {{"item", NANOARROW_TYPE_INT16}})}), + ADBC_STATUS_OK); - ASSERT_EQ(ArrowArrayFinishBuildingDefault(&array.value, &na_error), NANOARROW_OK); + ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{-123, -1}, std::vector{0, 1, 123}, + std::nullopt}), + ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); @@ -882,32 +866,16 @@ TEST_P(PostgresCopyListTest, PostgresCopyWriteListInteger) { adbc_validation::Handle array; struct ArrowError na_error; - ASSERT_EQ(ArrowSchemaInitFromType(&schema.value, NANOARROW_TYPE_STRUCT), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaAllocateChildren(&schema.value, 1), NANOARROW_OK); - - ASSERT_EQ(ArrowSchemaInitFromType(schema->children[0], GetParam()), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaSetName(schema->children[0], "col"), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaSetType(schema->children[0]->children[0], NANOARROW_TYPE_INT32), - NANOARROW_OK); - - ASSERT_EQ(ArrowArrayInitFromSchema(&array.value, &schema.value, 
nullptr), NANOARROW_OK); - ASSERT_EQ(ArrowArrayStartAppending(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -123), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 0), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 123), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendNull(array->children[0], 1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); + ASSERT_EQ(adbc_validation::MakeSchema( + &schema.value, {adbc_validation::SchemaField::Nested( + "col", GetParam(), {{"item", NANOARROW_TYPE_INT32}})}), + ADBC_STATUS_OK); - ASSERT_EQ(ArrowArrayFinishBuildingDefault(&array.value, &na_error), NANOARROW_OK); + ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{-123, -1}, std::vector{0, 1, 123}, + std::nullopt}), + ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); @@ -942,32 +910,16 @@ TEST_P(PostgresCopyListTest, PostgresCopyWriteListBigInt) { adbc_validation::Handle array; struct ArrowError na_error; - ASSERT_EQ(ArrowSchemaInitFromType(&schema.value, NANOARROW_TYPE_STRUCT), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaAllocateChildren(&schema.value, 1), NANOARROW_OK); - - ASSERT_EQ(ArrowSchemaInitFromType(schema->children[0], GetParam()), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaSetName(schema->children[0], "col"), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaSetType(schema->children[0]->children[0], NANOARROW_TYPE_INT64), - NANOARROW_OK); - - ASSERT_EQ(ArrowArrayInitFromSchema(&array.value, &schema.value, nullptr), NANOARROW_OK); - ASSERT_EQ(ArrowArrayStartAppending(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -123), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 0), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 123), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendNull(array->children[0], 1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); + ASSERT_EQ(adbc_validation::MakeSchema( + &schema.value, {adbc_validation::SchemaField::Nested( + "col", GetParam(), {{"item", NANOARROW_TYPE_INT64}})}), + ADBC_STATUS_OK); - ASSERT_EQ(ArrowArrayFinishBuildingDefault(&array.value, &na_error), NANOARROW_OK); + ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{-123, -1}, std::vector{0, 1, 123}, + std::nullopt}), + ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; 
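The test refactors in this file replace hand-rolled nanoarrow array building with the adbc_validation MakeSchema()/MakeBatch() helpers. The sketch below applies the same pattern to a flat column rather than a list; it assumes the flat-column overloads behave like the nested calls visible above, and the include path is a guess for this layout.

```cpp
// Sketch of the validation-helper pattern, applied to a single int64 column.
#include <optional>

#include "validation/adbc_validation_util.h"  // include path is an assumption

void MakeSingleInt64Batch() {
  adbc_validation::Handle<struct ArrowSchema> schema;
  adbc_validation::Handle<struct ArrowArray> array;
  struct ArrowError na_error;

  // One nullable int64 column named "col"...
  (void)adbc_validation::MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_INT64}});
  // ...with three rows: 1, 2, NULL.
  (void)adbc_validation::MakeBatch<int64_t>(&schema.value, &array.value, &na_error,
                                            {1, 2, std::nullopt});
}
```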
ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); @@ -1002,38 +954,17 @@ TEST_P(PostgresCopyListTest, PostgresCopyWriteListVarchar) { adbc_validation::Handle array; struct ArrowError na_error; - ASSERT_EQ(ArrowSchemaInitFromType(&schema.value, NANOARROW_TYPE_STRUCT), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaAllocateChildren(&schema.value, 1), NANOARROW_OK); - - ASSERT_EQ(ArrowSchemaInitFromType(schema->children[0], GetParam()), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaSetName(schema->children[0], "col"), NANOARROW_OK); - ASSERT_EQ(ArrowSchemaSetType(schema->children[0]->children[0], NANOARROW_TYPE_STRING), - NANOARROW_OK); - - ASSERT_EQ(ArrowArrayInitFromSchema(&array.value, &schema.value, nullptr), NANOARROW_OK); - ASSERT_EQ(ArrowArrayStartAppending(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendString(array->children[0]->children[0], ArrowCharView("foo")), - NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendString(array->children[0]->children[0], ArrowCharView("bar")), - NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendString(array->children[0]->children[0], ArrowCharView("baz")), - NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendString(array->children[0]->children[0], ArrowCharView("qux")), - NANOARROW_OK); ASSERT_EQ( - ArrowArrayAppendString(array->children[0]->children[0], ArrowCharView("quux")), - NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendNull(array->children[0], 1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); + adbc_validation::MakeSchema( + &schema.value, {adbc_validation::SchemaField::Nested( + "col", GetParam(), {{"item", NANOARROW_TYPE_STRING}})}), + ADBC_STATUS_OK); - ASSERT_EQ(ArrowArrayFinishBuildingDefault(&array.value, &na_error), NANOARROW_OK); + ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{"foo", "bar"}, + std::vector{"baz", "qux", "quux"}, std::nullopt}), + ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); @@ -1079,23 +1010,10 @@ TEST_F(PostgresCopyTest, PostgresCopyWriteFixedSizeListInteger) { ASSERT_EQ(ArrowSchemaSetType(schema->children[0]->children[0], NANOARROW_TYPE_INT32), NANOARROW_OK); - ASSERT_EQ(ArrowArrayInitFromSchema(&array.value, &schema.value, nullptr), NANOARROW_OK); - ASSERT_EQ(ArrowArrayStartAppending(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], 2), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayAppendInt(array->children[0]->children[0], -2), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(array->children[0]), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayAppendNull(array->children[0], 1), NANOARROW_OK); - ASSERT_EQ(ArrowArrayFinishElement(&array.value), NANOARROW_OK); - - ASSERT_EQ(ArrowArrayFinishBuildingDefault(&array.value, &na_error), NANOARROW_OK); + 
ASSERT_EQ(adbc_validation::MakeBatch>( + &schema.value, &array.value, &na_error, + {std::vector{1, 2}, std::vector{-1, -2}, std::nullopt}), + ADBC_STATUS_OK); PostgresCopyStreamWriteTester tester; ASSERT_EQ(tester.Init(&schema.value, &array.value, *type_resolver_), NANOARROW_OK); diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/reader.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/reader.h index 983f392..07f91d5 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/reader.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/reader.h @@ -972,10 +972,11 @@ class PostgresCopyStreamReader { return NANOARROW_OK; } - ArrowErrorCode InferOutputSchema(ArrowError* error) { + ArrowErrorCode InferOutputSchema(const std::string& vendor_name, ArrowError* error) { schema_.reset(); ArrowSchemaInit(schema_.get()); - NANOARROW_RETURN_NOT_OK(root_reader_.InputType().SetSchema(schema_.get())); + NANOARROW_RETURN_NOT_OK( + root_reader_.InputType().SetSchema(schema_.get(), vendor_name)); return NANOARROW_OK; } diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/writer.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/writer.h index b97628f..e88ed69 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/writer.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/copy/writer.h @@ -590,8 +590,9 @@ static inline ArrowErrorCode MakeCopyFieldWriter( *out = T::Create(array_view); return NANOARROW_OK; } + case NANOARROW_TYPE_UINT32: case NANOARROW_TYPE_INT64: - case NANOARROW_TYPE_UINT32: { + case NANOARROW_TYPE_UINT64: { using T = PostgresCopyNetworkEndianFieldWriter; *out = T::Create(array_view); return NANOARROW_OK; @@ -612,6 +613,7 @@ static inline ArrowErrorCode MakeCopyFieldWriter( return ADBC_STATUS_NOT_IMPLEMENTED; } } + case NANOARROW_TYPE_HALF_FLOAT: case NANOARROW_TYPE_FLOAT: { using T = PostgresCopyFloatFieldWriter; *out = T::Create(array_view); @@ -637,8 +639,12 @@ static inline ArrowErrorCode MakeCopyFieldWriter( return NANOARROW_OK; } case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: { + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: { using T = PostgresCopyBinaryFieldWriter; *out = T::Create(array_view); return NANOARROW_OK; diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.cc index 97242ad..cdbad75 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.cc @@ -17,6 +17,8 @@ #include "database.h" +#include +#include #include #include #include @@ -28,6 +30,7 @@ #include #include "driver/common/utils.h" +#include "result_helper.h" namespace adbcpq { @@ -54,8 +57,19 @@ AdbcStatusCode PostgresDatabase::GetOptionDouble(const char* option, double* val } AdbcStatusCode PostgresDatabase::Init(struct AdbcError* error) { - // Connect to validate the parameters. 
- return RebuildTypeResolver(error); + // Connect to initialize the version information and build the type table + PGconn* conn = nullptr; + RAISE_ADBC(Connect(&conn, error)); + + Status status = InitVersions(conn); + if (!status.ok()) { + RAISE_ADBC(Disconnect(&conn, nullptr)); + return status.ToAdbc(error); + } + + status = RebuildTypeResolver(conn); + RAISE_ADBC(Disconnect(&conn, nullptr)); + return status.ToAdbc(error); } AdbcStatusCode PostgresDatabase::Release(struct AdbcError* error) { @@ -123,20 +137,87 @@ AdbcStatusCode PostgresDatabase::Disconnect(PGconn** conn, struct AdbcError* err return ADBC_STATUS_OK; } -// Helpers for building the type resolver from queries -static inline int32_t InsertPgAttributeResult( - PGresult* result, const std::shared_ptr& resolver); +namespace { + +// Parse an individual version in the form of "xxx.xxx.xxx". +// If the version components aren't numeric, they will be zero. +std::array ParseVersion(std::string_view version) { + std::array out{}; + size_t component = 0; + size_t component_begin = 0; + size_t component_end = 0; + + // While there are remaining version components and we haven't reached the end of the + // string + while (component_begin < version.size() && component < out.size()) { + // Find the next character that marks a version component separation or the end of the + // string + component_end = version.find_first_of(".-", component_begin); + if (component_end == version.npos) { + component_end = version.size(); + } -static inline int32_t InsertPgTypeResult( - PGresult* result, const std::shared_ptr& resolver); + // Try to parse the component as an integer (assigning zero if this fails) + int value = 0; + std::from_chars(version.data() + component_begin, version.data() + component_end, + value); + out[component] = value; -AdbcStatusCode PostgresDatabase::RebuildTypeResolver(struct AdbcError* error) { - PGconn* conn = nullptr; - AdbcStatusCode final_status = Connect(&conn, error); - if (final_status != ADBC_STATUS_OK) { - return final_status; + // Move on to the next component + component_begin = component_end + 1; + component_end = component_begin; + component++; + } + + return out; +} + +// Parse the PostgreSQL version() string that looks like: +// PostgreSQL 8.0.2 on i686-pc-linux-gnu, compiled by GCC gcc (GCC) 3.4.2 20041017 (Red +// Hat 3.4.2-6.fc3), Redshift 1.0.77467 +std::array ParsePrefixedVersion(std::string_view version_info, + std::string_view prefix) { + size_t pos = version_info.find(prefix); + if (pos == version_info.npos) { + return {0, 0, 0}; } + // Skip the prefix and any leading whitespace + pos = version_info.find_first_not_of(' ', pos + prefix.size()); + if (pos == version_info.npos) { + return {0, 0, 0}; + } + + return ParseVersion(version_info.substr(pos)); +} + +} // namespace + +Status PostgresDatabase::InitVersions(PGconn* conn) { + PqResultHelper helper(conn, "SELECT version();"); + UNWRAP_STATUS(helper.Execute()); + if (helper.NumRows() != 1 || helper.NumColumns() != 1) { + return Status::Internal("Expected 1 row and 1 column for SELECT version(); but got ", + helper.NumRows(), "/", helper.NumColumns()); + } + + std::string_view version_info = helper.Row(0)[0].value(); + postgres_server_version_ = ParsePrefixedVersion(version_info, "PostgreSQL"); + redshift_server_version_ = ParsePrefixedVersion(version_info, "Redshift"); + + return Status::Ok(); +} + +// Helpers for building the type resolver from queries +static std::string BuildPgTypeQuery(bool has_typarray); + +static Status InsertPgAttributeResult( + 
const PqResultHelper& result, const std::shared_ptr& resolver); + +static Status InsertPgTypeResult(const PqResultHelper& result, + const std::shared_ptr& resolver); + +Status PostgresDatabase::RebuildTypeResolver(PGconn* conn) { // We need a few queries to build the resolver. The current strategy might // fail for some recursive definitions (e.g., arrays of records of arrays). // First, one on the pg_attribute table to resolve column names/oids for @@ -156,147 +237,131 @@ ORDER BY // recursive definitions (e.g., record types with array column). This currently won't // handle range types because those rows don't have child OID information. Arrays types // are inserted after a successful insert of the element type. - const std::string kTypeQuery = R"( -SELECT - oid, - typname, - typreceive, - typbasetype, - typarray, - typrelid -FROM - pg_catalog.pg_type -WHERE - (typreceive != 0 OR typname = 'aclitem') AND typtype != 'r' AND typreceive::TEXT != 'array_recv' -ORDER BY - oid -)"; + std::string type_query = + BuildPgTypeQuery(/*has_typarray*/ redshift_server_version_[0] == 0); // Create a new type resolver (this instance's type_resolver_ member // will be updated at the end if this succeeds). auto resolver = std::make_shared(); // Insert record type definitions (this includes table schemas) - PGresult* result = PQexec(conn, kColumnsQuery.c_str()); - ExecStatusType pq_status = PQresultStatus(result); - if (pq_status == PGRES_TUPLES_OK) { - InsertPgAttributeResult(result, resolver); - } else { - SetError(error, "%s%s", - "[libpq] Failed to build type mapping table: ", PQerrorMessage(conn)); - final_status = ADBC_STATUS_IO; - } - - PQclear(result); + PqResultHelper columns(conn, kColumnsQuery.c_str()); + UNWRAP_STATUS(columns.Execute()); + UNWRAP_STATUS(InsertPgAttributeResult(columns, resolver)); // Attempt filling the resolver a few times to handle recursive definitions. int32_t max_attempts = 3; + PqResultHelper types(conn, type_query); for (int32_t i = 0; i < max_attempts; i++) { - result = PQexec(conn, kTypeQuery.c_str()); - ExecStatusType pq_status = PQresultStatus(result); - if (pq_status == PGRES_TUPLES_OK) { - InsertPgTypeResult(result, resolver); - } else { - SetError(error, "%s%s", - "[libpq] Failed to build type mapping table: ", PQerrorMessage(conn)); - final_status = ADBC_STATUS_IO; - } - - PQclear(result); - if (final_status != ADBC_STATUS_OK) { - break; - } + UNWRAP_STATUS(types.Execute()); + UNWRAP_STATUS(InsertPgTypeResult(types, resolver)); } - // Disconnect since PostgreSQL connections can be heavy. 
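ParseVersion()/ParsePrefixedVersion() above drive the Redshift-versus-PostgreSQL detection behind VendorName()/VendorVersion(). Applying them to the sample version() string quoted in the comment gives the result shown below; the wrapper function is purely illustrative and assumes the two parsers are visible (in the driver they sit in an anonymous namespace in this file).

```cpp
// Illustrative only: exercise the parsers above on the quoted version() output.
#include <array>
#include <cstdio>
#include <string_view>

void VersionParsingExample() {
  constexpr std::string_view kVersionInfo =
      "PostgreSQL 8.0.2 on i686-pc-linux-gnu, compiled by GCC gcc (GCC) 3.4.2 "
      "20041017 (Red Hat 3.4.2-6.fc3), Redshift 1.0.77467";

  std::array<int, 3> pg = ParsePrefixedVersion(kVersionInfo, "PostgreSQL");
  std::array<int, 3> rs = ParsePrefixedVersion(kVersionInfo, "Redshift");

  // Prints: PostgreSQL 8.0.2 / Redshift 1.0.77467
  std::printf("PostgreSQL %d.%d.%d / Redshift %d.%d.%d\n", pg[0], pg[1], pg[2], rs[0],
              rs[1], rs[2]);
}
```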
- { - AdbcStatusCode status = Disconnect(&conn, error); - if (status != ADBC_STATUS_OK) final_status = status; - } + type_resolver_ = std::move(resolver); + return Status::Ok(); +} - if (final_status == ADBC_STATUS_OK) { - type_resolver_ = std::move(resolver); +static std::string BuildPgTypeQuery(bool has_typarray) { + std::string maybe_typarray_col; + std::string maybe_array_recv_filter; + if (has_typarray) { + maybe_typarray_col = ", typarray"; + maybe_array_recv_filter = "AND typreceive::TEXT != 'array_recv'"; } - return final_status; + return std::string() + "SELECT oid, typname, typreceive, typbasetype, typrelid" + + maybe_typarray_col + " FROM pg_catalog.pg_type " + + " WHERE (typreceive != 0 OR typsend != 0) AND typtype != 'r' " + + maybe_array_recv_filter; } -static inline int32_t InsertPgAttributeResult( - PGresult* result, const std::shared_ptr& resolver) { - int num_rows = PQntuples(result); +static Status InsertPgAttributeResult( + const PqResultHelper& result, const std::shared_ptr& resolver) { + int num_rows = result.NumRows(); std::vector> columns; - uint32_t current_type_oid = 0; - int32_t n_added = 0; + int64_t current_type_oid = 0; + + if (result.NumColumns() != 3) { + return Status::Internal( + "Expected 3 columns from type resolver pg_attribute query but got ", + result.NumColumns()); + } for (int row = 0; row < num_rows; row++) { - const uint32_t type_oid = static_cast( - std::strtol(PQgetvalue(result, row, 0), /*str_end=*/nullptr, /*base=*/10)); - const char* col_name = PQgetvalue(result, row, 1); - const uint32_t col_oid = static_cast( - std::strtol(PQgetvalue(result, row, 2), /*str_end=*/nullptr, /*base=*/10)); + PqResultRow item = result.Row(row); + UNWRAP_RESULT(int64_t type_oid, item[0].ParseInteger()); + std::string_view col_name = item[1].value(); + UNWRAP_RESULT(int64_t col_oid, item[2].ParseInteger()); if (type_oid != current_type_oid && !columns.empty()) { resolver->InsertClass(current_type_oid, columns); columns.clear(); current_type_oid = type_oid; - n_added++; } - columns.push_back({col_name, col_oid}); + columns.push_back({std::string(col_name), static_cast(col_oid)}); } if (!columns.empty()) { - resolver->InsertClass(current_type_oid, columns); - n_added++; + resolver->InsertClass(static_cast(current_type_oid), columns); } - return n_added; + return Status::Ok(); } -static inline int32_t InsertPgTypeResult( - PGresult* result, const std::shared_ptr& resolver) { - int num_rows = PQntuples(result); - PostgresTypeResolver::Item item; - int32_t n_added = 0; +static Status InsertPgTypeResult(const PqResultHelper& result, + const std::shared_ptr& resolver) { + if (result.NumColumns() != 5 && result.NumColumns() != 6) { + return Status::Internal( + "Expected 5 or 6 columns from type resolver pg_type query but got ", + result.NumColumns()); + } + + int num_rows = result.NumRows(); + int num_cols = result.NumColumns(); + PostgresTypeResolver::Item type_item; for (int row = 0; row < num_rows; row++) { - const uint32_t oid = static_cast( - std::strtol(PQgetvalue(result, row, 0), /*str_end=*/nullptr, /*base=*/10)); - const char* typname = PQgetvalue(result, row, 1); - const char* typreceive = PQgetvalue(result, row, 2); - const uint32_t typbasetype = static_cast( - std::strtol(PQgetvalue(result, row, 3), /*str_end=*/nullptr, /*base=*/10)); - const uint32_t typarray = static_cast( - std::strtol(PQgetvalue(result, row, 4), /*str_end=*/nullptr, /*base=*/10)); - const uint32_t typrelid = static_cast( - std::strtol(PQgetvalue(result, row, 5), /*str_end=*/nullptr, 
/*base=*/10)); + PqResultRow item = result.Row(row); + UNWRAP_RESULT(int64_t oid, item[0].ParseInteger()); + const char* typname = item[1].data; + const char* typreceive = item[2].data; + UNWRAP_RESULT(int64_t typbasetype, item[3].ParseInteger()); + UNWRAP_RESULT(int64_t typrelid, item[4].ParseInteger()); + + int64_t typarray; + if (num_cols == 6) { + UNWRAP_RESULT(typarray, item[5].ParseInteger()); + } else { + typarray = 0; + } // Special case the aclitem because it shows up in a bunch of internal tables if (strcmp(typname, "aclitem") == 0) { typreceive = "aclitem_recv"; } - item.oid = oid; - item.typname = typname; - item.typreceive = typreceive; - item.class_oid = typrelid; - item.base_oid = typbasetype; + type_item.oid = static_cast(oid); + type_item.typname = typname; + type_item.typreceive = typreceive; + type_item.class_oid = static_cast(typrelid); + type_item.base_oid = static_cast(typbasetype); - int result = resolver->Insert(item, nullptr); + int result = resolver->Insert(type_item, nullptr); // If there's an array type and the insert succeeded, add that now too if (result == NANOARROW_OK && typarray != 0) { std::string array_typname = "_" + std::string(typname); - item.oid = typarray; - item.typname = array_typname.c_str(); - item.typreceive = "array_recv"; - item.child_oid = oid; + type_item.oid = typarray; + type_item.typname = array_typname.c_str(); + type_item.typreceive = "array_recv"; + type_item.child_oid = static_cast(oid); - resolver->Insert(item, nullptr); + resolver->Insert(type_item, nullptr); } } - return n_added; + return Status::Ok(); } } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.h index d246ea0..e0a0026 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/database.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -24,9 +25,12 @@ #include #include +#include "driver/framework/status.h" #include "postgres_type.h" namespace adbcpq { +using adbc::driver::Status; + class PostgresDatabase { public: PostgresDatabase(); @@ -58,12 +62,29 @@ class PostgresDatabase { return type_resolver_; } - AdbcStatusCode RebuildTypeResolver(struct AdbcError* error); + Status InitVersions(PGconn* conn); + Status RebuildTypeResolver(PGconn* conn); + std::string_view VendorName() { + if (redshift_server_version_[0] != 0) { + return "Redshift"; + } else { + return "PostgreSQL"; + } + } + const std::array& VendorVersion() { + if (redshift_server_version_[0] != 0) { + return redshift_server_version_; + } else { + return postgres_server_version_; + } + } private: int32_t open_connections_; std::string uri_; std::shared_ptr type_resolver_; + std::array postgres_server_version_{}; + std::array redshift_server_version_{}; }; } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.cc index 276aadc..173868b 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.cc @@ -17,8 +17,8 @@ #include "error.h" -#include #include +#include #include #include #include @@ -29,72 +29,27 @@ namespace adbcpq { -namespace { -struct DetailField { - int code; - std::string key; -}; - -static const std::vector kDetailFields = { - {PG_DIAG_COLUMN_NAME, "PG_DIAG_COLUMN_NAME"}, - {PG_DIAG_CONTEXT, "PG_DIAG_CONTEXT"}, - {PG_DIAG_CONSTRAINT_NAME, 
"PG_DIAG_CONSTRAINT_NAME"}, - {PG_DIAG_DATATYPE_NAME, "PG_DIAG_DATATYPE_NAME"}, - {PG_DIAG_INTERNAL_POSITION, "PG_DIAG_INTERNAL_POSITION"}, - {PG_DIAG_INTERNAL_QUERY, "PG_DIAG_INTERNAL_QUERY"}, - {PG_DIAG_MESSAGE_PRIMARY, "PG_DIAG_MESSAGE_PRIMARY"}, - {PG_DIAG_MESSAGE_DETAIL, "PG_DIAG_MESSAGE_DETAIL"}, - {PG_DIAG_MESSAGE_HINT, "PG_DIAG_MESSAGE_HINT"}, - {PG_DIAG_SEVERITY_NONLOCALIZED, "PG_DIAG_SEVERITY_NONLOCALIZED"}, - {PG_DIAG_SQLSTATE, "PG_DIAG_SQLSTATE"}, - {PG_DIAG_STATEMENT_POSITION, "PG_DIAG_STATEMENT_POSITION"}, - {PG_DIAG_SCHEMA_NAME, "PG_DIAG_SCHEMA_NAME"}, - {PG_DIAG_TABLE_NAME, "PG_DIAG_TABLE_NAME"}, -}; -} // namespace - AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, ...) { + if (error && error->release) { + // TODO: combine the errors if possible + error->release(error); + } + va_list args; va_start(args, format); - SetErrorVariadic(error, format, args); + std::string message; + message.resize(1024); + int chars_needed = vsnprintf(message.data(), message.size(), format, args); va_end(args); - AdbcStatusCode code = ADBC_STATUS_IO; - - const char* sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); - if (sqlstate) { - // https://www.postgresql.org/docs/current/errcodes-appendix.html - // This can be extended in the future - if (std::strcmp(sqlstate, "57014") == 0) { - code = ADBC_STATUS_CANCELLED; - } else if (std::strcmp(sqlstate, "42P01") == 0 || - std::strcmp(sqlstate, "42602") == 0) { - code = ADBC_STATUS_NOT_FOUND; - } else if (std::strncmp(sqlstate, "42", 0) == 0) { - // Class 42 — Syntax Error or Access Rule Violation - code = ADBC_STATUS_INVALID_ARGUMENT; - } - - static_assert(sizeof(error->sqlstate) == 5, ""); - // N.B. strncpy generates warnings when used for this purpose - int i = 0; - for (; sqlstate[i] != '\0' && i < 5; i++) { - error->sqlstate[i] = sqlstate[i]; - } - for (; i < 5; i++) { - error->sqlstate[i] = '\0'; - } + if (chars_needed > 0) { + message.resize(chars_needed); + } else { + message.resize(0); } - for (const auto& field : kDetailFields) { - const char* value = PQresultErrorField(result, field.code); - if (value) { - AppendErrorDetail(error, field.key.c_str(), reinterpret_cast(value), - std::strlen(value)); - } - } - return code; + return MakeStatus(result, "{}", message).ToAdbc(error); } } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.h index 825de97..f24d417 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/error.h @@ -19,11 +19,42 @@ #pragma once +#include +#include + #include #include +#include + +#include "driver/framework/status.h" + +using adbc::driver::Status; + namespace adbcpq { +struct DetailField { + int code; + std::string key; +}; + +static const std::vector kDetailFields = { + {PG_DIAG_COLUMN_NAME, "PG_DIAG_COLUMN_NAME"}, + {PG_DIAG_CONTEXT, "PG_DIAG_CONTEXT"}, + {PG_DIAG_CONSTRAINT_NAME, "PG_DIAG_CONSTRAINT_NAME"}, + {PG_DIAG_DATATYPE_NAME, "PG_DIAG_DATATYPE_NAME"}, + {PG_DIAG_INTERNAL_POSITION, "PG_DIAG_INTERNAL_POSITION"}, + {PG_DIAG_INTERNAL_QUERY, "PG_DIAG_INTERNAL_QUERY"}, + {PG_DIAG_MESSAGE_PRIMARY, "PG_DIAG_MESSAGE_PRIMARY"}, + {PG_DIAG_MESSAGE_DETAIL, "PG_DIAG_MESSAGE_DETAIL"}, + {PG_DIAG_MESSAGE_HINT, "PG_DIAG_MESSAGE_HINT"}, + {PG_DIAG_SEVERITY_NONLOCALIZED, "PG_DIAG_SEVERITY_NONLOCALIZED"}, + {PG_DIAG_SQLSTATE, "PG_DIAG_SQLSTATE"}, + {PG_DIAG_STATEMENT_POSITION, "PG_DIAG_STATEMENT_POSITION"}, + 
{PG_DIAG_SCHEMA_NAME, "PG_DIAG_SCHEMA_NAME"}, + {PG_DIAG_TABLE_NAME, "PG_DIAG_TABLE_NAME"}, +}; + // The printf checking attribute doesn't work properly on gcc 4.8 // and results in spurious compiler warnings #if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 5) @@ -33,10 +64,50 @@ namespace adbcpq { #endif /// \brief Set an error based on a PGresult, inferring the proper ADBC status -/// code from the PGresult. +/// code from the PGresult. Deprecated and is currently a thin wrapper around +/// MakeStatus() below. AdbcStatusCode SetError(struct AdbcError* error, PGresult* result, const char* format, ...) ADBC_CHECK_PRINTF_ATTRIBUTE(3, 4); #undef ADBC_CHECK_PRINTF_ATTRIBUTE +template +Status MakeStatus(PGresult* result, const char* format_string, Args&&... args) { + auto message = ::fmt::vformat(format_string, ::fmt::make_format_args(args...)); + + AdbcStatusCode code = ADBC_STATUS_IO; + char sqlstate_out[5]; + std::memset(sqlstate_out, 0, sizeof(sqlstate_out)); + + if (result == nullptr) { + return Status(code, message); + } + + const char* sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); + if (sqlstate) { + // https://www.postgresql.org/docs/current/errcodes-appendix.html + // This can be extended in the future + if (std::strcmp(sqlstate, "57014") == 0) { + code = ADBC_STATUS_CANCELLED; + } else if (std::strcmp(sqlstate, "42P01") == 0 || + std::strcmp(sqlstate, "42602") == 0) { + code = ADBC_STATUS_NOT_FOUND; + } else if (std::strncmp(sqlstate, "42", 0) == 0) { + // Class 42 — Syntax Error or Access Rule Violation + code = ADBC_STATUS_INVALID_ARGUMENT; + } + } + + Status status(code, message); + status.SetSqlState(sqlstate); + for (const auto& field : kDetailFields) { + const char* value = PQresultErrorField(result, field.code); + if (value) { + status.AddDetail(field.key, value); + } + } + + return status; +} + } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type.h index 02748cf..d2a5356 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type.h @@ -111,7 +111,11 @@ enum class PostgresTypeId { kXid8, kXid, kXml, - kUserDefined + kUserDefined, + // This is not an actual type, but there are cases where all we have is an Oid + // that was not inserted into the type resolver. We can't use "unknown" or "opaque" + // or "void" because those names show up in actual pg_type tables. + kUnnamedArrowOpaque }; // Returns the receive function name as defined in the typrecieve column @@ -139,6 +143,11 @@ class PostgresType { PostgresType() : PostgresType(PostgresTypeId::kUninitialized) {} + static PostgresType Unnamed(uint32_t oid) { + return PostgresType(PostgresTypeId::kUnnamedArrowOpaque) + .WithPgTypeInfo(oid, "unnamed"); + } + void AppendChild(const std::string& field_name, const PostgresType& type) { PostgresType child(type); children_.push_back(child.WithFieldName(field_name)); @@ -184,6 +193,19 @@ class PostgresType { int64_t n_children() const { return static_cast(children_.size()); } const PostgresType& child(int64_t i) const { return children_[i]; } + // The name used to communicate this type in a CREATE TABLE statement. 
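MakeStatus() above is now the single place where a PGresult is translated into an ADBC status code, SQLSTATE, and detail fields. A hypothetical call site is sketched below; the function name, query message, and the assumption that this header is included are all illustrative, and ToAdbc() is the framework Status conversion already used throughout this patch.

```cpp
// Hypothetical call site for MakeStatus() (illustrative only).
#include <libpq-fe.h>

#include "error.h"  // assumed: brings in MakeStatus() and Status

AdbcStatusCode ReportFailedQuery(PGresult* result, struct AdbcError* error) {
  // Builds a Status whose ADBC code is inferred from the SQLSTATE (e.g.
  // 57014 -> CANCELLED, 42P01 -> NOT_FOUND per the mapping above) and which
  // carries the PG_DIAG_* fields as error details, then surfaces it through
  // the C error struct.
  Status status = adbcpq::MakeStatus(result, "[libpq] Query failed: {}",
                                     PQresultErrorMessage(result));
  PQclear(result);
  return status.ToAdbc(error);
}
```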
+ // These are not necessarily the most idiomatic names to use but PostgreSQL + // will accept typname() according to the "aliases" column in + // https://www.postgresql.org/docs/current/datatype.html + const std::string sql_type_name() const { + switch (type_id_) { + case PostgresTypeId::kArray: + return children_[0].sql_type_name() + " ARRAY"; + default: + return typname_; + } + } + // Sets appropriate fields of an ArrowSchema that has been initialized using // ArrowSchemaInit. This is a recursive operation (i.e., nested types will // initialize and set the appropriate number of children). Returns NANOARROW_OK @@ -191,7 +213,8 @@ class PostgresType { // do not have a corresponding Arrow type are returned as Binary with field // metadata ADBC:posgresql:typname. These types can be represented as their // binary COPY representation in the output. - ArrowErrorCode SetSchema(ArrowSchema* schema) const { + ArrowErrorCode SetSchema(ArrowSchema* schema, + const std::string& vendor_name = "PostgreSQL") const { switch (type_id_) { // ---- Primitive types -------------------- case PostgresTypeId::kBool: @@ -222,7 +245,7 @@ class PostgresType { // ---- Numeric/Decimal------------------- case PostgresTypeId::kNumeric: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_STRING)); - NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema)); + NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema, vendor_name)); break; @@ -277,13 +300,14 @@ class PostgresType { case PostgresTypeId::kRecord: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema, n_children())); for (int64_t i = 0; i < n_children(); i++) { - NANOARROW_RETURN_NOT_OK(children_[i].SetSchema(schema->children[i])); + NANOARROW_RETURN_NOT_OK( + children_[i].SetSchema(schema->children[i], vendor_name)); } break; case PostgresTypeId::kArray: NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_LIST)); - NANOARROW_RETURN_NOT_OK(children_[0].SetSchema(schema->children[0])); + NANOARROW_RETURN_NOT_OK(children_[0].SetSchema(schema->children[0], vendor_name)); break; case PostgresTypeId::kUserDefined: @@ -292,7 +316,7 @@ class PostgresType { // can still return the bytes postgres gives us and attach the type name as // metadata NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema, NANOARROW_TYPE_BINARY)); - NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema)); + NANOARROW_RETURN_NOT_OK(AddPostgresTypeMetadata(schema, vendor_name)); break; } @@ -312,8 +336,12 @@ class PostgresType { std::vector children_; static constexpr const char* kPostgresTypeKey = "ADBC:postgresql:typname"; + static constexpr const char* kExtensionName = "ARROW:extension:name"; + static constexpr const char* kOpaqueExtensionName = "arrow.opaque"; + static constexpr const char* kExtensionMetadata = "ARROW:extension:metadata"; - ArrowErrorCode AddPostgresTypeMetadata(ArrowSchema* schema) const { + ArrowErrorCode AddPostgresTypeMetadata(ArrowSchema* schema, + const std::string& vendor_name) const { // the typname_ may not always be set: an instance of this class can be // created with just the type id. That's why there is this here fallback to // resolve the type name of built-in types. 
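The metadata-building code that follows attaches the arrow.opaque extension annotation alongside the legacy ADBC:postgresql:typname key. A consumer-side sketch for detecting that annotation with nanoarrow's metadata helpers is below; the function name and include path are illustrative, not part of the driver.

```cpp
// Sketch: recognize the arrow.opaque annotation on a column schema.
#include <string>

#include "nanoarrow/nanoarrow.h"  // include path is an assumption

bool IsOpaquePostgresColumn(const struct ArrowSchema* schema, std::string* type_json) {
  if (schema->metadata == nullptr) return false;

  // ArrowMetadataGetValue leaves the output view untouched when the key is
  // missing, so initialize it to an empty view first.
  struct ArrowStringView name = {nullptr, 0};
  if (ArrowMetadataGetValue(schema->metadata, ArrowCharView("ARROW:extension:name"),
                            &name) != NANOARROW_OK ||
      name.data == nullptr) {
    return false;
  }
  if (std::string(name.data, name.size_bytes) != "arrow.opaque") return false;

  struct ArrowStringView metadata = {nullptr, 0};
  ArrowMetadataGetValue(schema->metadata, ArrowCharView("ARROW:extension:metadata"),
                        &metadata);
  if (metadata.data != nullptr) {
    // e.g. {"type_name": "numeric", "vendor_name": "PostgreSQL"}
    *type_json = std::string(metadata.data, metadata.size_bytes);
  }
  return true;
}
```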
@@ -322,8 +350,25 @@ class PostgresType {
     nanoarrow::UniqueBuffer buffer;
     ArrowMetadataBuilderInit(buffer.get(), nullptr);
+    // TODO(lidavidm): we have deprecated this in favor of arrow.opaque;
+    // remove once we feel enough time has passed
     NANOARROW_RETURN_NOT_OK(ArrowMetadataBuilderAppend(
         buffer.get(), ArrowCharView(kPostgresTypeKey), ArrowCharView(typname)));
+
+    // Add the Opaque extension type metadata
+    std::string metadata = R"({"type_name": ")";
+    metadata += typname;
+    metadata += R"(", "vendor_name": ")" + vendor_name + R"("})";
+    NANOARROW_RETURN_NOT_OK(
+        ArrowMetadataBuilderAppend(buffer.get(), ArrowCharView(kExtensionName),
+                                   ArrowCharView(kOpaqueExtensionName)));
+    NANOARROW_RETURN_NOT_OK(
+        ArrowMetadataBuilderAppend(buffer.get(), ArrowCharView(kExtensionMetadata),
+                                   ArrowStringView{
+                                       metadata.c_str(),
+                                       static_cast<int64_t>(metadata.size()),
+                                   }));
+
     NANOARROW_RETURN_NOT_OK(
         ArrowSchemaSetMetadata(schema, reinterpret_cast<const char*>(buffer->data)));
@@ -362,7 +407,18 @@ class PostgresTypeResolver {
       return EINVAL;
     }

-    *type_out = (*result).second;
+    *type_out = result->second;
+    return NANOARROW_OK;
+  }
+
+  ArrowErrorCode FindWithDefault(uint32_t oid, PostgresType* type_out) {
+    auto result = mapping_.find(oid);
+    if (result == mapping_.end()) {
+      *type_out = PostgresType::Unnamed(oid);
+    } else {
+      *type_out = result->second;
+    }
+
     return NANOARROW_OK;
   }

@@ -525,16 +581,40 @@ inline ArrowErrorCode PostgresType::FromSchema(const PostgresTypeResolver& resol
       return resolver.Find(resolver.GetOID(PostgresTypeId::kInt4), out, error);
     case NANOARROW_TYPE_UINT32:
     case NANOARROW_TYPE_INT64:
+    case NANOARROW_TYPE_UINT64:
       return resolver.Find(resolver.GetOID(PostgresTypeId::kInt8), out, error);
+    case NANOARROW_TYPE_HALF_FLOAT:
     case NANOARROW_TYPE_FLOAT:
       return resolver.Find(resolver.GetOID(PostgresTypeId::kFloat4), out, error);
     case NANOARROW_TYPE_DOUBLE:
       return resolver.Find(resolver.GetOID(PostgresTypeId::kFloat8), out, error);
     case NANOARROW_TYPE_STRING:
+    case NANOARROW_TYPE_LARGE_STRING:
+    case NANOARROW_TYPE_STRING_VIEW:
       return resolver.Find(resolver.GetOID(PostgresTypeId::kText), out, error);
     case NANOARROW_TYPE_BINARY:
+    case NANOARROW_TYPE_LARGE_BINARY:
     case NANOARROW_TYPE_FIXED_SIZE_BINARY:
+    case NANOARROW_TYPE_BINARY_VIEW:
       return resolver.Find(resolver.GetOID(PostgresTypeId::kBytea), out, error);
+    case NANOARROW_TYPE_DATE32:
+    case NANOARROW_TYPE_DATE64:
+      return resolver.Find(resolver.GetOID(PostgresTypeId::kDate), out, error);
+    case NANOARROW_TYPE_TIME32:
+    case NANOARROW_TYPE_TIME64:
+      return resolver.Find(resolver.GetOID(PostgresTypeId::kTime), out, error);
+    case NANOARROW_TYPE_DURATION:
+    case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
+      return resolver.Find(resolver.GetOID(PostgresTypeId::kInterval), out, error);
+    case NANOARROW_TYPE_TIMESTAMP:
+      if (strcmp("", schema_view.timezone) == 0) {
+        return resolver.Find(resolver.GetOID(PostgresTypeId::kTimestamp), out, error);
+      } else {
+        return resolver.Find(resolver.GetOID(PostgresTypeId::kTimestamptz), out, error);
+      }
+    case NANOARROW_TYPE_DECIMAL128:
+    case NANOARROW_TYPE_DECIMAL256:
+      return resolver.Find(resolver.GetOID(PostgresTypeId::kNumeric), out, error);
     case NANOARROW_TYPE_LIST:
     case NANOARROW_TYPE_LARGE_LIST:
     case NANOARROW_TYPE_FIXED_SIZE_LIST: {
diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type_test.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type_test.cc
index 9d6152f..2c76f4c 100644
--- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type_test.cc
+++ 
b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgres_type_test.cc @@ -174,6 +174,14 @@ TEST(PostgresTypeTest, PostgresTypeSetSchema) { &typnameMetadataValue); EXPECT_EQ(std::string(typnameMetadataValue.data, typnameMetadataValue.size_bytes), "numeric"); + ArrowMetadataGetValue(schema->metadata, ArrowCharView("ARROW:extension:name"), + &typnameMetadataValue); + EXPECT_EQ(std::string(typnameMetadataValue.data, typnameMetadataValue.size_bytes), + "arrow.opaque"); + ArrowMetadataGetValue(schema->metadata, ArrowCharView("ARROW:extension:metadata"), + &typnameMetadataValue); + EXPECT_EQ(std::string(typnameMetadataValue.data, typnameMetadataValue.size_bytes), + R"({"type_name": "numeric", "vendor_name": "PostgreSQL"})"); schema.reset(); ArrowSchemaInit(schema.get()); @@ -312,11 +320,10 @@ TEST(PostgresTypeTest, PostgresTypeFromSchema) { schema.reset(); ArrowError error; - ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO), + ASSERT_EQ(ArrowSchemaInitFromType(schema.get(), NANOARROW_TYPE_INTERVAL_MONTHS), NANOARROW_OK); EXPECT_EQ(PostgresType::FromSchema(resolver, schema.get(), &type, &error), ENOTSUP); - EXPECT_STREQ(error.message, - "Can't map Arrow type 'interval_month_day_nano' to Postgres type"); + EXPECT_STREQ(error.message, "Can't map Arrow type 'interval_months' to Postgres type"); schema.reset(); } @@ -330,6 +337,11 @@ TEST(PostgresTypeTest, PostgresTypeResolver) { EXPECT_EQ(resolver.Find(123, &type, &error), EINVAL); EXPECT_STREQ(ArrowErrorMessage(&error), "Postgres type with oid 123 not found"); + EXPECT_EQ(resolver.FindWithDefault(123, &type), NANOARROW_OK); + EXPECT_EQ(type.oid(), 123); + EXPECT_EQ(type.type_id(), PostgresTypeId::kUnnamedArrowOpaque); + EXPECT_EQ(type.typname(), "unnamed"); + // Check error for Array with unknown child item.oid = 123; item.typname = "some_array"; diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql.cc index 033c446..e43db98 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql.cc @@ -25,8 +25,10 @@ #include "connection.h" #include "database.h" #include "driver/common/utils.h" +#include "driver/framework/status.h" #include "statement.h" +using adbc::driver::Status; using adbcpq::PostgresConnection; using adbcpq::PostgresDatabase; using adbcpq::PostgresStatement; @@ -56,14 +58,36 @@ const struct AdbcError* PostgresErrorFromArrayStream(struct ArrowArrayStream* st // Currently only valid for TupleReader return adbcpq::TupleReader::ErrorFromArrayStream(stream, status); } + +int PostgresErrorGetDetailCount(const struct AdbcError* error) { + if (IsCommonError(error)) { + return CommonErrorGetDetailCount(error); + } + + if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + return 0; + } + + auto error_obj = reinterpret_cast(error->private_data); + return error_obj->CDetailCount(); +} + +struct AdbcErrorDetail PostgresErrorGetDetail(const struct AdbcError* error, int index) { + if (IsCommonError(error)) { + return CommonErrorGetDetail(error, index); + } + + auto error_obj = reinterpret_cast(error->private_data); + return error_obj->CDetail(index); +} } // namespace int AdbcErrorGetDetailCount(const struct AdbcError* error) { - return CommonErrorGetDetailCount(error); + return PostgresErrorGetDetailCount(error); } struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { - return CommonErrorGetDetail(error, 
index); + return PostgresErrorGetDetail(error, index); } const struct AdbcError* AdbcErrorFromArrayStream(struct ArrowArrayStream* stream, @@ -860,8 +884,8 @@ AdbcStatusCode PostgresqlDriverInit(int version, void* raw_driver, if (version >= ADBC_VERSION_1_1_0) { std::memset(driver, 0, ADBC_DRIVER_1_1_0_SIZE); - driver->ErrorGetDetailCount = CommonErrorGetDetailCount; - driver->ErrorGetDetail = CommonErrorGetDetail; + driver->ErrorGetDetailCount = PostgresErrorGetDetailCount; + driver->ErrorGetDetail = PostgresErrorGetDetail; driver->ErrorFromArrayStream = PostgresErrorFromArrayStream; driver->DatabaseGetOption = PostgresDatabaseGetOption; diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql_test.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql_test.cc index c45168e..be32bd8 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/postgresql_test.cc @@ -116,11 +116,24 @@ class PostgresQuirks : public adbc_validation::DriverQuirks { ArrowType IngestSelectRoundTripType(ArrowType ingest_type) const override { switch (ingest_type) { case NANOARROW_TYPE_INT8: + case NANOARROW_TYPE_UINT8: return NANOARROW_TYPE_INT16; + case NANOARROW_TYPE_UINT16: + return NANOARROW_TYPE_INT32; + case NANOARROW_TYPE_UINT32: + case NANOARROW_TYPE_UINT64: + return NANOARROW_TYPE_INT64; + case NANOARROW_TYPE_HALF_FLOAT: + return NANOARROW_TYPE_FLOAT; case NANOARROW_TYPE_DURATION: return NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO; case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: + return NANOARROW_TYPE_BINARY; case NANOARROW_TYPE_DECIMAL128: case NANOARROW_TYPE_DECIMAL256: return NANOARROW_TYPE_STRING; @@ -886,11 +899,6 @@ class PostgresStatementTest : public ::testing::Test, void SetUp() override { ASSERT_NO_FATAL_FAILURE(SetUpTest()); } void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); } - void TestSqlIngestUInt8() { GTEST_SKIP() << "Not implemented"; } - void TestSqlIngestUInt16() { GTEST_SKIP() << "Not implemented"; } - void TestSqlIngestUInt32() { GTEST_SKIP() << "Not implemented"; } - void TestSqlIngestUInt64() { GTEST_SKIP() << "Not implemented"; } - void TestSqlPrepareErrorParamCountMismatch() { GTEST_SKIP() << "Not yet implemented"; } void TestSqlPrepareGetParameterSchema() { GTEST_SKIP() << "Not yet implemented"; } void TestSqlPrepareSelectParams() { GTEST_SKIP() << "Not yet implemented"; } @@ -1139,10 +1147,11 @@ TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) { IsOkStatus(&error)); ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); - ASSERT_THAT(error.message, - ::testing::HasSubstr("Row #1 has value '9223372036854775807' which " - "exceeds PostgreSQL timestamp limits")); + IsStatus(ADBC_STATUS_INTERNAL, &error)); + ASSERT_THAT( + error.message, + ::testing::HasSubstr( + "Row 0 timestamp value 9223372036854775807 with unit 0 would overflow")); } { @@ -1169,10 +1178,11 @@ TEST_F(PostgresStatementTest, SqlIngestTimestampOverflow) { IsOkStatus(&error)); ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), - IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); - 
ASSERT_THAT(error.message, - ::testing::HasSubstr("Row #1 has value '-9223372036854775808' which " - "exceeds PostgreSQL timestamp limits")); + IsStatus(ADBC_STATUS_INTERNAL, &error)); + ASSERT_THAT( + error.message, + ::testing::HasSubstr( + "Row 0 timestamp value -9223372036854775808 with unit 0 would overflow")); } } @@ -1436,6 +1446,66 @@ TEST_F(PostgresStatementTest, ExecuteParameterizedQueryWithRowsAffected) { } } +TEST_F(PostgresStatementTest, SqlExecuteCopyZeroRowOutputError) { + ASSERT_THAT(quirks()->DropTable(&connection, "adbc_test", &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + { + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement, "CREATE TABLE adbc_test (id int primary key, data jsonb)", + &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + } + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "insert into adbc_test (id, data) values (1, null)", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + } + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "insert into adbc_test (id, data) values (2, '1')", &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + } + + { + ASSERT_THAT( + AdbcStatementSetSqlQuery(&statement, + "SELECT id, data from adbc_test JOIN " + "jsonb_array_elements(adbc_test.data) AS foo ON true", + &error), + IsOkStatus(&error)); + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus()); + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_EQ(reader.MaybeNext(), EINVAL); + + AdbcStatusCode status = ADBC_STATUS_OK; + const struct AdbcError* detail = + AdbcErrorFromArrayStream(&reader.stream.value, &status); + ASSERT_NE(nullptr, detail); + ASSERT_EQ(ADBC_STATUS_INVALID_ARGUMENT, status); + ASSERT_EQ("22023", std::string_view(detail->sqlstate, 5)); + } +} + TEST_F(PostgresStatementTest, BatchSizeHint) { ASSERT_THAT(quirks()->EnsureSampleTable(&connection, "batch_size_hint_test", &error), IsOkStatus(&error)); @@ -1719,7 +1789,7 @@ TEST_P(PostgresTypeTest, SelectValue) { // check type ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema( - &reader.schema.value, {{std::nullopt, GetParam().arrow_type, true}})); + &reader.schema.value, {{"", GetParam().arrow_type, true}})); if (GetParam().arrow_type == NANOARROW_TYPE_TIMESTAMP) { if (GetParam().sql_type.find("WITH TIME ZONE") == std::string::npos) { ASSERT_STREQ(reader.schema->children[0]->format, "tsu:"); diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.cc index 157b100..6dd7527 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.cc @@ -20,56 +20,51 @@ #include #include -#include "driver/common/utils.h" +#define ADBC_FRAMEWORK_USE_FMT +#include "driver/framework/status.h" #include "error.h" namespace adbcpq { 
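// A minimal usage sketch of the Status-based helper API in this file. It assumes a
// PGconn* conn, a caller whose return type is adbc::driver::Status (so that
// UNWRAP_STATUS from driver/framework/status.h can early-return a non-OK Status),
// and a hypothetical query; only the names visible in this diff are real.
//
//   PqResultHelper helper(conn, "SELECT 1");
//   UNWRAP_STATUS(helper.Prepare());
//   UNWRAP_STATUS(helper.Execute());
//   for (PqResultRow row : helper) {
//     std::string_view cell = row[0].value();  // raw text of column 0
//   }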
PqResultHelper::~PqResultHelper() { ClearResult(); } -AdbcStatusCode PqResultHelper::PrepareInternal(int n_params, const Oid* param_oids, - struct AdbcError* error) { +Status PqResultHelper::PrepareInternal(int n_params, const Oid* param_oids) const { // TODO: make stmtName a unique identifier? PGresult* result = PQprepare(conn_, /*stmtName=*/"", query_.c_str(), n_params, param_oids); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - AdbcStatusCode code = - SetError(error, result, "[libpq] Failed to prepare query: %s\nQuery was:%s", - PQerrorMessage(conn_), query_.c_str()); + auto status = MakeStatus(result, "Failed to prepare query: {}\nQuery was:{}", + PQerrorMessage(conn_), query_.c_str()); PQclear(result); - return code; + return status; } PQclear(result); - return ADBC_STATUS_OK; + return Status::Ok(); } -AdbcStatusCode PqResultHelper::Prepare(struct AdbcError* error) { - return PrepareInternal(0, nullptr, error); -} +Status PqResultHelper::Prepare() const { return PrepareInternal(0, nullptr); } -AdbcStatusCode PqResultHelper::Prepare(const std::vector& param_oids, - struct AdbcError* error) { - return PrepareInternal(param_oids.size(), param_oids.data(), error); +Status PqResultHelper::Prepare(const std::vector& param_oids) const { + return PrepareInternal(param_oids.size(), param_oids.data()); } -AdbcStatusCode PqResultHelper::DescribePrepared(struct AdbcError* error) { +Status PqResultHelper::DescribePrepared() { ClearResult(); result_ = PQdescribePrepared(conn_, /*stmtName=*/""); if (PQresultStatus(result_) != PGRES_COMMAND_OK) { - AdbcStatusCode code = SetError( - error, result_, "[libpq] Failed to describe prepared statement: %s\nQuery was:%s", + Status status = MakeStatus( + result_, "[libpq] Failed to describe prepared statement: {}\nQuery was:{}", PQerrorMessage(conn_), query_.c_str()); ClearResult(); - return code; + return status; } - return ADBC_STATUS_OK; + return Status::Ok(); } -AdbcStatusCode PqResultHelper::Execute(struct AdbcError* error, - const std::vector& params, - PostgresType* param_types) { +Status PqResultHelper::Execute(const std::vector& params, + PostgresType* param_types) { if (params.size() == 0 && param_types == nullptr && output_format_ == Format::kText) { ClearResult(); result_ = PQexec(conn_, query_.c_str()); @@ -102,16 +97,14 @@ AdbcStatusCode PqResultHelper::Execute(struct AdbcError* error, ExecStatusType status = PQresultStatus(result_); if (status != PGRES_TUPLES_OK && status != PGRES_COMMAND_OK) { - AdbcStatusCode status = - SetError(error, result_, "[libpq] Failed to execute query '%s': %s", - query_.c_str(), PQerrorMessage(conn_)); - return status; + return MakeStatus(result_, "[libpq] Failed to execute query '{}': {}", query_.c_str(), + PQerrorMessage(conn_)); } - return ADBC_STATUS_OK; + return Status::Ok(); } -AdbcStatusCode PqResultHelper::ExecuteCopy(struct AdbcError* error) { +Status PqResultHelper::ExecuteCopy() { // Remove trailing semicolon(s) from the query before feeding it into COPY while (!query_.empty() && query_.back() == ';') { query_.pop_back(); @@ -125,20 +118,19 @@ AdbcStatusCode PqResultHelper::ExecuteCopy(struct AdbcError* error) { static_cast(Format::kBinary)); if (PQresultStatus(result_) != PGRES_COPY_OUT) { - AdbcStatusCode code = SetError( - error, result_, - "[libpq] Failed to execute query: could not begin COPY: %s\nQuery was: %s", + Status status = MakeStatus( + result_, + "[libpq] Failed to execute query: could not begin COPY: {}\nQuery was: {}", PQerrorMessage(conn_), copy_query.c_str()); ClearResult(); - return 
code; + return status; } - return ADBC_STATUS_OK; + return Status::Ok(); } -AdbcStatusCode PqResultHelper::ResolveParamTypes(PostgresTypeResolver& type_resolver, - PostgresType* param_types, - struct AdbcError* error) { +Status PqResultHelper::ResolveParamTypes(PostgresTypeResolver& type_resolver, + PostgresType* param_types) { struct ArrowError na_error; ArrowErrorInit(&na_error); @@ -149,22 +141,22 @@ AdbcStatusCode PqResultHelper::ResolveParamTypes(PostgresTypeResolver& type_reso const Oid pg_oid = PQparamtype(result_, i); PostgresType pg_type; if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) { - SetError(error, "%s%d%s%s%s%d", "[libpq] Parameter #", i + 1, " (\"", - PQfname(result_, i), "\") has unknown type code ", pg_oid); + Status status = Status::NotImplemented("[libpq] Parameter #", i + 1, " (\"", + PQfname(result_, i), + "\") has unknown type code ", pg_oid); ClearResult(); - return ADBC_STATUS_NOT_IMPLEMENTED; + return status; } root_type.AppendChild(PQfname(result_, i), pg_type); } *param_types = root_type; - return ADBC_STATUS_OK; + return Status::Ok(); } -AdbcStatusCode PqResultHelper::ResolveOutputTypes(PostgresTypeResolver& type_resolver, - PostgresType* result_types, - struct AdbcError* error) { +Status PqResultHelper::ResolveOutputTypes(PostgresTypeResolver& type_resolver, + PostgresType* result_types) { struct ArrowError na_error; ArrowErrorInit(&na_error); @@ -175,17 +167,18 @@ AdbcStatusCode PqResultHelper::ResolveOutputTypes(PostgresTypeResolver& type_res const Oid pg_oid = PQftype(result_, i); PostgresType pg_type; if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) { - SetError(error, "%s%d%s%s%s%d", "[libpq] Column #", i + 1, " (\"", - PQfname(result_, i), "\") has unknown type code ", pg_oid); + Status status = + Status::NotImplemented("[libpq] Column #", i + 1, " (\"", PQfname(result_, i), + "\") has unknown type code ", pg_oid); ClearResult(); - return ADBC_STATUS_NOT_IMPLEMENTED; + return status; } root_type.AppendChild(PQfname(result_, i), pg_type); } *result_types = root_type; - return ADBC_STATUS_OK; + return Status::Ok(); } PGresult* PqResultHelper::ReleaseResult() { @@ -194,7 +187,7 @@ PGresult* PqResultHelper::ReleaseResult() { return out; } -int64_t PqResultHelper::AffectedRows() { +int64_t PqResultHelper::AffectedRows() const { if (result_ == nullptr) { return -1; } diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.h index 18de795..1f3f93c 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_helper.h @@ -18,6 +18,8 @@ #pragma once #include +#include +#include #include #include #include @@ -28,6 +30,10 @@ #include #include "copy/reader.h" +#include "driver/framework/status.h" + +using adbc::driver::Result; +using adbc::driver::Status; namespace adbcpq { @@ -45,18 +51,46 @@ struct PqRecord { } return result; } + + Result ParseInteger() const { + const char* last = data + len; + int64_t value = 0; + auto result = std::from_chars(data, last, value, 10); + if (result.ec == std::errc() && result.ptr == last) { + return value; + } else { + return Status::Internal("Can't parse '", data, "' as integer"); + } + } + + Result> ParseTextArray() const { + std::string text_array(data, len); + text_array.erase(0, 1); + text_array.erase(text_array.size() - 1); + + std::vector elements; + std::stringstream ss(std::move(text_array)); + std::string tmp; + + while 
(getline(ss, tmp, ',')) { + elements.push_back(std::move(tmp)); + } + + return elements; + } + + std::string_view value() { return std::string_view(data, len); } }; // Used by PqResultHelper to provide index-based access to the records within each // row of a PGresult class PqResultRow { public: - PqResultRow(PGresult* result, int row_num) : result_(result), row_num_(row_num) { - ncols_ = PQnfields(result); - } + PqResultRow() : result_(nullptr), row_num_(-1) {} + PqResultRow(PGresult* result, int row_num) : result_(result), row_num_(row_num) {} - PqRecord operator[](const int& col_num) { - assert(col_num < ncols_); + PqRecord operator[](int col_num) const { + assert(col_num < PQnfields(result_)); const char* data = PQgetvalue(result_, row_num_, col_num); const int len = PQgetlength(result_, row_num_, col_num); const bool is_null = PQgetisnull(result_, row_num_, col_num); @@ -64,10 +98,15 @@ class PqResultRow { return PqRecord{data, len, is_null}; } + bool IsValid() const { + return result_ && row_num_ >= 0 && row_num_ < PQntuples(result_); + } + + PqResultRow Next() const { return PqResultRow(result_, row_num_ + 1); } + private: PGresult* result_ = nullptr; int row_num_; - int ncols_; }; // Helper to manager the lifecycle of a PQResult. The query argument @@ -95,19 +134,18 @@ class PqResultHelper { void set_param_format(Format format) { param_format_ = format; } void set_output_format(Format format) { output_format_ = format; } - AdbcStatusCode Prepare(struct AdbcError* error); - AdbcStatusCode Prepare(const std::vector& param_oids, struct AdbcError* error); - AdbcStatusCode DescribePrepared(struct AdbcError* error); - AdbcStatusCode Execute(struct AdbcError* error, - const std::vector& params = {}, - PostgresType* param_types = nullptr); - AdbcStatusCode ExecuteCopy(struct AdbcError* error); - AdbcStatusCode ResolveParamTypes(PostgresTypeResolver& type_resolver, - PostgresType* param_types, struct AdbcError* error); - AdbcStatusCode ResolveOutputTypes(PostgresTypeResolver& type_resolver, - PostgresType* result_types, struct AdbcError* error); + Status Prepare() const; + Status Prepare(const std::vector& param_oids) const; + Status DescribePrepared(); + Status Execute(const std::vector& params = {}, + PostgresType* param_types = nullptr); + Status ExecuteCopy(); + Status ResolveParamTypes(PostgresTypeResolver& type_resolver, + PostgresType* param_types); + Status ResolveOutputTypes(PostgresTypeResolver& type_resolver, + PostgresType* result_types); - bool HasResult() { return result_ != nullptr; } + bool HasResult() const { return result_ != nullptr; } void SetResult(PGresult* result) { ClearResult(); @@ -121,7 +159,7 @@ class PqResultHelper { result_ = nullptr; } - int64_t AffectedRows(); + int64_t AffectedRows() const; int NumRows() const { return PQntuples(result_); } @@ -131,6 +169,7 @@ class PqResultHelper { return PQfname(result_, column_number); } Oid FieldType(int column_number) const { return PQftype(result_, column_number); } + PqResultRow Row(int i) const { return PqResultRow(result_, i); } class iterator { const PqResultHelper& outer_; @@ -152,7 +191,7 @@ class PqResultHelper { return outer_.result_ == other.outer_.result_ && curr_row_ == other.curr_row_; } bool operator!=(iterator other) const { return !(*this == other); } - PqResultRow operator*() { return PqResultRow(outer_.result_, curr_row_); } + PqResultRow operator*() const { return PqResultRow(outer_.result_, curr_row_); } using iterator_category = std::forward_iterator_tag; using difference_type = std::ptrdiff_t; 
using value_type = std::vector; @@ -160,8 +199,8 @@ class PqResultHelper { using reference = const std::vector&; }; - iterator begin() { return iterator(*this); } - iterator end() { return iterator(*this, NumRows()); } + iterator begin() const { return iterator(*this); } + iterator end() const { return iterator(*this, NumRows()); } private: PGresult* result_ = nullptr; @@ -170,8 +209,7 @@ class PqResultHelper { Format param_format_ = Format::kText; Format output_format_ = Format::kText; - AdbcStatusCode PrepareInternal(int n_params, const Oid* param_oids, - struct AdbcError* error); + Status PrepareInternal(int n_params, const Oid* param_oids) const; }; } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_reader.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_reader.cc index 21bc2bd..464bad7 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_reader.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_reader.cc @@ -21,9 +21,7 @@ #include #include "copy/reader.h" -#include "driver/common/utils.h" - -#include "error.h" +#include "driver/framework/status.h" namespace adbcpq { @@ -31,8 +29,9 @@ int PqResultArrayReader::GetSchema(struct ArrowSchema* out) { ResetErrors(); if (schema_->release == nullptr) { - AdbcStatusCode status = Initialize(nullptr, &error_); - if (status != ADBC_STATUS_OK) { + Status status = Initialize(nullptr); + if (!status.ok()) { + status.ToAdbc(&error_); return EINVAL; } } @@ -43,10 +42,11 @@ int PqResultArrayReader::GetSchema(struct ArrowSchema* out) { int PqResultArrayReader::GetNext(struct ArrowArray* out) { ResetErrors(); - AdbcStatusCode status; + Status status; if (schema_->release == nullptr) { - AdbcStatusCode status = Initialize(nullptr, &error_); - if (status != ADBC_STATUS_OK) { + status = Initialize(nullptr); + if (!status.ok()) { + status.ToAdbc(&error_); return EINVAL; } } @@ -63,8 +63,9 @@ int PqResultArrayReader::GetNext(struct ArrowArray* out) { } // Keep binding and executing until we have a result to return - status = BindNextAndExecute(nullptr, &error_); - if (status != ADBC_STATUS_OK) { + status = BindNextAndExecute(nullptr); + if (!status.ok()) { + status.ToAdbc(&error_); return EIO; } @@ -133,25 +134,24 @@ const char* PqResultArrayReader::GetLastError() { } } -AdbcStatusCode PqResultArrayReader::Initialize(int64_t* rows_affected, - struct AdbcError* error) { +Status PqResultArrayReader::Initialize(int64_t* rows_affected) { helper_.set_output_format(PqResultHelper::Format::kBinary); helper_.set_param_format(PqResultHelper::Format::kBinary); // If we have to do binding, set up the bind stream an execute until // there is a result with more than zero rows to populate. if (bind_stream_) { - RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); - RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); - RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error)); + UNWRAP_STATUS(bind_stream_->Begin([] { return Status::Ok(); })); - RAISE_ADBC(BindNextAndExecute(nullptr, error)); + UNWRAP_STATUS(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_)); + UNWRAP_STATUS(helper_.Prepare(bind_stream_->param_types)); + UNWRAP_STATUS(BindNextAndExecute(nullptr)); // If there were no arrays in the bind stream, we still need a result // to populate the schema. If there were any arrays in the bind stream, // the last one will still be in helper_ even if it had zero rows. 
if (!helper_.HasResult()) { - RAISE_ADBC(helper_.DescribePrepared(error)); + UNWRAP_STATUS(helper_.DescribePrepared()); } // We can't provide affected row counts if there is a bind stream and @@ -161,7 +161,7 @@ AdbcStatusCode PqResultArrayReader::Initialize(int64_t* rows_affected, *rows_affected = -1; } } else { - RAISE_ADBC(helper_.Execute(error)); + UNWRAP_STATUS(helper_.Execute()); if (rows_affected != nullptr) { *rows_affected = helper_.AffectedRows(); } @@ -169,90 +169,86 @@ AdbcStatusCode PqResultArrayReader::Initialize(int64_t* rows_affected, // Build the schema for which we are about to build results ArrowSchemaInit(schema_.get()); - CHECK_NA_DETAIL(INTERNAL, ArrowSchemaSetTypeStruct(schema_.get(), helper_.NumColumns()), - &na_error_, error); + UNWRAP_NANOARROW(na_error_, Internal, + ArrowSchemaSetTypeStruct(schema_.get(), helper_.NumColumns())); for (int i = 0; i < helper_.NumColumns(); i++) { PostgresType child_type; - CHECK_NA_DETAIL(INTERNAL, - type_resolver_->Find(helper_.FieldType(i), &child_type, &na_error_), - &na_error_, error); + UNWRAP_ERRNO(Internal, + type_resolver_->FindWithDefault(helper_.FieldType(i), &child_type)); - CHECK_NA(INTERNAL, child_type.SetSchema(schema_->children[i]), error); - CHECK_NA(INTERNAL, ArrowSchemaSetName(schema_->children[i], helper_.FieldName(i)), - error); + UNWRAP_ERRNO(Internal, child_type.SetSchema(schema_->children[i], vendor_name_)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(schema_->children[i], helper_.FieldName(i))); std::unique_ptr child_reader; - CHECK_NA_DETAIL( - INTERNAL, - MakeCopyFieldReader(child_type, schema_->children[i], &child_reader, &na_error_), - &na_error_, error); + UNWRAP_NANOARROW( + na_error_, Internal, + MakeCopyFieldReader(child_type, schema_->children[i], &child_reader, &na_error_)); child_reader->Init(child_type); - CHECK_NA_DETAIL(INTERNAL, child_reader->InitSchema(schema_->children[i]), &na_error_, - error); + UNWRAP_NANOARROW(na_error_, Internal, child_reader->InitSchema(schema_->children[i])); field_readers_.push_back(std::move(child_reader)); } - return ADBC_STATUS_OK; + return Status::Ok(); } -AdbcStatusCode PqResultArrayReader::ToArrayStream(int64_t* affected_rows, - struct ArrowArrayStream* out, - struct AdbcError* error) { +Status PqResultArrayReader::ToArrayStream(int64_t* affected_rows, + struct ArrowArrayStream* out) { if (out == nullptr) { // If there is no output requested, we still need to execute and // set affected_rows if needed. We don't need an output schema or to set up a copy // reader, so we can skip those steps by going straight to Execute(). This also // enables us to support queries with multiple statements because we can call PQexec() // instead of PQexecParams(). - RAISE_ADBC(ExecuteAll(affected_rows, error)); - return ADBC_STATUS_OK; + UNWRAP_STATUS(ExecuteAll(affected_rows)); + return Status::Ok(); } // Otherwise, execute until we have a result to return. We need this to provide row // counts for DELETE and CREATE TABLE queries as well as to provide more informative // errors until this reader class is wired up to provide extended AdbcError information. 
- RAISE_ADBC(Initialize(affected_rows, error)); + UNWRAP_STATUS(Initialize(affected_rows)); nanoarrow::ArrayStreamFactory::InitArrayStream( new PqResultArrayReader(this), out); - return ADBC_STATUS_OK; + return Status::Ok(); } -AdbcStatusCode PqResultArrayReader::BindNextAndExecute(int64_t* affected_rows, - AdbcError* error) { +Status PqResultArrayReader::BindNextAndExecute(int64_t* affected_rows) { // Keep pulling from the bind stream and executing as long as // we receive results with zero rows. do { - RAISE_ADBC(bind_stream_->EnsureNextRow(error)); + UNWRAP_STATUS(bind_stream_->EnsureNextRow()); + if (!bind_stream_->current->release) { - RAISE_ADBC(bind_stream_->Cleanup(conn_, error)); + UNWRAP_STATUS(bind_stream_->Cleanup(conn_)); bind_stream_.reset(); - return ADBC_STATUS_OK; + return Status::Ok(); } PGresult* result; - RAISE_ADBC(bind_stream_->BindAndExecuteCurrentRow( - conn_, &result, /*result_format*/ kPgBinaryFormat, error)); + UNWRAP_STATUS(bind_stream_->BindAndExecuteCurrentRow( + conn_, &result, /*result_format*/ kPgBinaryFormat)); helper_.SetResult(result); if (affected_rows) { (*affected_rows) += helper_.AffectedRows(); } } while (helper_.NumRows() == 0); - return ADBC_STATUS_OK; + return Status::Ok(); } -AdbcStatusCode PqResultArrayReader::ExecuteAll(int64_t* affected_rows, AdbcError* error) { +Status PqResultArrayReader::ExecuteAll(int64_t* affected_rows) { // For the case where we don't need a result, we either need to exhaust the bind // stream (if there is one) or execute the query without binding. if (bind_stream_) { - RAISE_ADBC(bind_stream_->Begin([] { return ADBC_STATUS_OK; }, error)); - RAISE_ADBC(bind_stream_->SetParamTypes(*type_resolver_, error)); - RAISE_ADBC(helper_.Prepare(bind_stream_->param_types, error)); + UNWRAP_STATUS(bind_stream_->Begin([] { return Status::Ok(); })); + UNWRAP_STATUS(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_)); + UNWRAP_STATUS(helper_.Prepare(bind_stream_->param_types)); // Reset affected rows to zero before binding and executing any if (affected_rows) { @@ -260,17 +256,17 @@ AdbcStatusCode PqResultArrayReader::ExecuteAll(int64_t* affected_rows, AdbcError } do { - RAISE_ADBC(BindNextAndExecute(affected_rows, error)); + UNWRAP_STATUS(BindNextAndExecute(affected_rows)); } while (bind_stream_); } else { - RAISE_ADBC(helper_.Execute(error)); + UNWRAP_STATUS(helper_.Execute()); if (affected_rows != nullptr) { *affected_rows = helper_.AffectedRows(); } } - return ADBC_STATUS_OK; + return Status::Ok(); } } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_reader.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_reader.h index 51da639..90b35ba 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_reader.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/result_reader.h @@ -17,6 +17,10 @@ #pragma once +#if !defined(NOMINMAX) +#define NOMINMAX +#endif + #include #include #include @@ -34,26 +38,37 @@ class PqResultArrayReader { public: PqResultArrayReader(PGconn* conn, std::shared_ptr type_resolver, std::string query) - : conn_(conn), helper_(conn, std::move(query)), type_resolver_(type_resolver) { + : conn_(conn), + helper_(conn, std::move(query)), + type_resolver_(type_resolver), + autocommit_(false) { ArrowErrorInit(&na_error_); error_ = ADBC_ERROR_INIT; } ~PqResultArrayReader() { ResetErrors(); } + // Ensure the reader knows what the autocommit status was on creation. 
This is used + // so that the temporary timezone setting required for parameter binding can be wrapped + // in a transaction (or not) accordingly. + void SetAutocommit(bool autocommit) { autocommit_ = autocommit; } + void SetBind(struct ArrowArrayStream* stream) { bind_stream_ = std::make_unique(); bind_stream_->SetBind(stream); } + void SetVendorName(std::string_view vendor_name) { + vendor_name_ = std::string(vendor_name); + } + int GetSchema(struct ArrowSchema* out); int GetNext(struct ArrowArray* out); const char* GetLastError(); - AdbcStatusCode ToArrayStream(int64_t* affected_rows, struct ArrowArrayStream* out, - struct AdbcError* error); + Status ToArrayStream(int64_t* affected_rows, struct ArrowArrayStream* out); - AdbcStatusCode Initialize(int64_t* affected_rows, struct AdbcError* error); + Status Initialize(int64_t* affected_rows); private: PGconn* conn_; @@ -62,6 +77,8 @@ class PqResultArrayReader { std::shared_ptr type_resolver_; std::vector> field_readers_; nanoarrow::UniqueSchema schema_; + bool autocommit_; + std::string vendor_name_; struct AdbcError error_; struct ArrowError na_error_; @@ -76,8 +93,8 @@ class PqResultArrayReader { error_ = ADBC_ERROR_INIT; } - AdbcStatusCode BindNextAndExecute(int64_t* affected_rows, AdbcError* error); - AdbcStatusCode ExecuteAll(int64_t* affected_rows, AdbcError* error); + Status BindNextAndExecute(int64_t* affected_rows); + Status ExecuteAll(int64_t* affected_rows); void ResetErrors() { ArrowErrorInit(&na_error_); diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.cc b/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.cc index 224472b..129ddeb 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.cc @@ -40,6 +40,7 @@ #include "connection.h" #include "driver/common/options.h" #include "driver/common/utils.h" +#include "driver/framework/utility.h" #include "error.h" #include "postgres_type.h" #include "postgres_util.h" @@ -50,6 +51,7 @@ namespace adbcpq { int TupleReader::GetSchema(struct ArrowSchema* out) { assert(copy_reader_ != nullptr); + ArrowErrorInit(&na_error_); int na_res = copy_reader_->GetSchema(out); if (out->release == nullptr) { @@ -65,75 +67,74 @@ int TupleReader::GetSchema(struct ArrowSchema* out) { return na_res; } -int TupleReader::InitQueryAndFetchFirst(struct ArrowError* error) { - // Fetch + parse the header - int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0); - data_.size_bytes = get_copy_res; - data_.data.as_char = pgbuf_; +int TupleReader::GetCopyData() { + if (pgbuf_ != nullptr) { + PQfreemem(pgbuf_); + pgbuf_ = nullptr; + } + data_.size_bytes = 0; + data_.data.as_char = nullptr; + + int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0); if (get_copy_res == -2) { - SetError(&error_, "[libpq] Fetch header failed: %s", PQerrorMessage(conn_)); + SetError(&error_, "[libpq] PQgetCopyData() failed: %s", PQerrorMessage(conn_)); status_ = ADBC_STATUS_IO; return AdbcStatusCodeToErrno(status_); } - int na_res = copy_reader_->ReadHeader(&data_, error); - if (na_res != NANOARROW_OK) { - SetError(&error_, "[libpq] ReadHeader failed: %s", error->message); - status_ = ADBC_STATUS_IO; - return AdbcStatusCodeToErrno(status_); + if (get_copy_res == -1) { + // Check the server-side response + PQclear(result_); + result_ = PQgetResult(conn_); + const ExecStatusType pq_status = PQresultStatus(result_); + if (pq_status != PGRES_COMMAND_OK) { + status_ = SetError(&error_, result_, "[libpq] Execution error 
[%s]: %s", + PQresStatus(pq_status), PQresultErrorMessage(result_)); + return AdbcStatusCodeToErrno(status_); + } else { + return ENODATA; + } } + data_.size_bytes = get_copy_res; + data_.data.as_char = pgbuf_; return NANOARROW_OK; } -int TupleReader::AppendRowAndFetchNext(struct ArrowError* error) { +int TupleReader::AppendRowAndFetchNext() { // Parse the result (the header AND the first row are included in the first // call to PQgetCopyData()) - int na_res = copy_reader_->ReadRecord(&data_, error); + int na_res = copy_reader_->ReadRecord(&data_, &na_error_); if (na_res != NANOARROW_OK && na_res != ENODATA) { SetError(&error_, "[libpq] ReadRecord failed at row %" PRId64 ": %s", row_id_, - error->message); + na_error_.message); status_ = ADBC_STATUS_IO; return na_res; } row_id_++; - // Fetch + check - PQfreemem(pgbuf_); - pgbuf_ = nullptr; - int get_copy_res = PQgetCopyData(conn_, &pgbuf_, /*async=*/0); - data_.size_bytes = get_copy_res; - data_.data.as_char = pgbuf_; - - if (get_copy_res == -2) { - SetError(&error_, "[libpq] PQgetCopyData failed at row %" PRId64 ": %s", row_id_, - PQerrorMessage(conn_)); - status_ = ADBC_STATUS_IO; - return AdbcStatusCodeToErrno(status_); - } else if (get_copy_res == -1) { - // Returned when COPY has finished successfully - return ENODATA; - } else if ((copy_reader_->array_size_approx_bytes() + get_copy_res) >= - batch_size_hint_bytes_) { + NANOARROW_RETURN_NOT_OK(GetCopyData()); + if ((copy_reader_->array_size_approx_bytes() + data_.size_bytes) >= + batch_size_hint_bytes_) { // Appending the next row will result in an array larger than requested. // Return EOVERFLOW to force GetNext() to build the current result and return. return EOVERFLOW; - } else { - return NANOARROW_OK; } + + return NANOARROW_OK; } -int TupleReader::BuildOutput(struct ArrowArray* out, struct ArrowError* error) { +int TupleReader::BuildOutput(struct ArrowArray* out) { if (copy_reader_->array_size_approx_bytes() == 0) { out->release = nullptr; return NANOARROW_OK; } - int na_res = copy_reader_->GetArray(out, error); + int na_res = copy_reader_->GetArray(out, &na_error_); if (na_res != NANOARROW_OK) { - SetError(&error_, "[libpq] Failed to build result array: %s", error->message); + SetError(&error_, "[libpq] Failed to build result array: %s", na_error_.message); status_ = ADBC_STATUS_INTERNAL; return na_res; } @@ -147,22 +148,35 @@ int TupleReader::GetNext(struct ArrowArray* out) { return 0; } - struct ArrowError error; - error.message[0] = '\0'; + int na_res; + ArrowErrorInit(&na_error_); if (row_id_ == -1) { - NANOARROW_RETURN_NOT_OK(InitQueryAndFetchFirst(&error)); + na_res = GetCopyData(); + if (na_res == ENODATA) { + is_finished_ = true; + out->release = nullptr; + return 0; + } else if (na_res != NANOARROW_OK) { + return na_res; + } + + na_res = copy_reader_->ReadHeader(&data_, &na_error_); + if (na_res != NANOARROW_OK) { + SetError(&error_, "[libpq] ReadHeader() failed: %s", na_error_.message); + return na_res; + } + row_id_++; } - int na_res; do { - na_res = AppendRowAndFetchNext(&error); + na_res = AppendRowAndFetchNext(); if (na_res == EOVERFLOW) { // The result would be too big to return if we appended the row. When EOVERFLOW is // returned, the copy reader leaves the output in a valid state. 
The data is left in // pg_buf_/data_ and will attempt to be appended on the next call to GetNext() - return BuildOutput(out, &error); + return BuildOutput(out); } } while (na_res == NANOARROW_OK); @@ -175,31 +189,7 @@ int TupleReader::GetNext(struct ArrowArray* out) { // Finish the result properly and return the last result. Note that BuildOutput() may // set tmp.release = nullptr if there were zero rows in the copy reader (can // occur in an overflow scenario). - struct ArrowArray tmp; - NANOARROW_RETURN_NOT_OK(BuildOutput(&tmp, &error)); - - PQclear(result_); - // Check the server-side response - result_ = PQgetResult(conn_); - const ExecStatusType pq_status = PQresultStatus(result_); - if (pq_status != PGRES_COMMAND_OK) { - const char* sqlstate = PQresultErrorField(result_, PG_DIAG_SQLSTATE); - SetError(&error_, result_, "[libpq] Query failed [%s]: %s", PQresStatus(pq_status), - PQresultErrorMessage(result_)); - - if (tmp.release != nullptr) { - tmp.release(&tmp); - } - - if (sqlstate != nullptr && std::strcmp(sqlstate, "57014") == 0) { - status_ = ADBC_STATUS_CANCELLED; - } else { - status_ = ADBC_STATUS_IO; - } - return AdbcStatusCodeToErrno(status_); - } - - ArrowArrayMove(&tmp, out); + NANOARROW_RETURN_NOT_OK(BuildOutput(out)); return NANOARROW_OK; } @@ -307,7 +297,7 @@ AdbcStatusCode PostgresStatement::Bind(struct ArrowArray* values, if (bind_.release) bind_.release(&bind_); // Make a one-value stream - nanoarrow::VectorArrayStream(schema, values).ToArrayStream(&bind_); + adbc::driver::MakeArrayStream(schema, values, &bind_); return ADBC_STATUS_OK; } @@ -329,11 +319,11 @@ AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) { return connection_->Cancel(error); } -AdbcStatusCode PostgresStatement::CreateBulkTable( - const std::string& current_schema, const struct ArrowSchema& source_schema, - const std::vector& source_schema_fields, - std::string* escaped_table, std::string* escaped_field_list, - struct AdbcError* error) { +AdbcStatusCode PostgresStatement::CreateBulkTable(const std::string& current_schema, + const struct ArrowSchema& source_schema, + std::string* escaped_table, + std::string* escaped_field_list, + struct AdbcError* error) { PGconn* conn = connection_->conn(); if (!ingest_.db_schema.empty() && ingest_.temporary) { @@ -416,7 +406,7 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( create += *escaped_table; create += " ("; - for (size_t i = 0; i < source_schema_fields.size(); i++) { + for (int64_t i = 0; i < source_schema.n_children; i++) { if (i > 0) { create += ", "; *escaped_field_list += ", "; @@ -433,82 +423,13 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( *escaped_field_list += escaped; PQfreemem(escaped); - switch (source_schema_fields[i].type) { - case ArrowType::NANOARROW_TYPE_BOOL: - create += " BOOLEAN"; - break; - case ArrowType::NANOARROW_TYPE_INT8: - case ArrowType::NANOARROW_TYPE_INT16: - create += " SMALLINT"; - break; - case ArrowType::NANOARROW_TYPE_INT32: - create += " INTEGER"; - break; - case ArrowType::NANOARROW_TYPE_INT64: - create += " BIGINT"; - break; - case ArrowType::NANOARROW_TYPE_FLOAT: - create += " REAL"; - break; - case ArrowType::NANOARROW_TYPE_DOUBLE: - create += " DOUBLE PRECISION"; - break; - case ArrowType::NANOARROW_TYPE_STRING: - case ArrowType::NANOARROW_TYPE_LARGE_STRING: - create += " TEXT"; - break; - case ArrowType::NANOARROW_TYPE_BINARY: - create += " BYTEA"; - break; - case ArrowType::NANOARROW_TYPE_DATE32: - create += " DATE"; - break; - case ArrowType::NANOARROW_TYPE_TIMESTAMP: - if (strcmp("", 
source_schema_fields[i].timezone)) { - create += " TIMESTAMPTZ"; - } else { - create += " TIMESTAMP"; - } - break; - case ArrowType::NANOARROW_TYPE_DURATION: - case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - create += " INTERVAL"; - break; - case ArrowType::NANOARROW_TYPE_DECIMAL128: - case ArrowType::NANOARROW_TYPE_DECIMAL256: - create += " DECIMAL"; - break; - case ArrowType::NANOARROW_TYPE_DICTIONARY: { - struct ArrowSchemaView value_view; - CHECK_NA(INTERNAL, - ArrowSchemaViewInit(&value_view, source_schema.children[i]->dictionary, - nullptr), - error); - switch (value_view.type) { - case NANOARROW_TYPE_BINARY: - case NANOARROW_TYPE_LARGE_BINARY: - create += " BYTEA"; - break; - case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: - create += " TEXT"; - break; - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", source_schema.children[i]->name, - "') has unsupported dictionary value type for ingestion ", - ArrowTypeString(value_view.type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - break; - } - default: - SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #", - static_cast(i + 1), " ('", source_schema.children[i]->name, - "') has unsupported type for ingestion ", - ArrowTypeString(source_schema_fields[i].type)); - return ADBC_STATUS_NOT_IMPLEMENTED; - } + PostgresType pg_type; + struct ArrowError na_error; + CHECK_NA_DETAIL(INTERNAL, + PostgresType::FromSchema(*type_resolver_, source_schema.children[i], + &pg_type, &na_error), + &na_error, error); + create += " " + pg_type.sql_type_name(); } if (ingest_.mode == IngestMode::kAppend) { @@ -536,8 +457,10 @@ AdbcStatusCode PostgresStatement::ExecuteBind(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error) { PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); + reader.SetAutocommit(connection_->autocommit()); reader.SetBind(&bind_); - RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error)); + reader.SetVendorName(connection_->VendorName()); + RAISE_STATUS(error, reader.ToArrayStream(rows_affected, stream)); return ADBC_STATUS_OK; } @@ -563,41 +486,45 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream, // If we have been requested to avoid COPY or there is no output requested, // execute using the PqResultArrayReader. - if (!stream || !use_copy_) { + if (!stream || !UseCopy()) { PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); - RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error)); + reader.SetVendorName(connection_->VendorName()); + RAISE_STATUS(error, reader.ToArrayStream(rows_affected, stream)); return ADBC_STATUS_OK; } PqResultHelper helper(connection_->conn(), query_); - RAISE_ADBC(helper.Prepare(error)); - RAISE_ADBC(helper.DescribePrepared(error)); + RAISE_STATUS(error, helper.Prepare()); + RAISE_STATUS(error, helper.DescribePrepared()); // Initialize the copy reader and infer the output schema (i.e., error for // unsupported types before issuing the COPY query). This could be lazier // (i.e., executed on the first call to GetSchema() or GetNext()). 
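// Note on the RAISE_STATUS(error, expr) calls used below: the macro comes from
// driver/framework/status.h and is assumed to evaluate expr and, when the resulting
// adbc::driver::Status is not OK, populate *error (cf. Status::ToAdbc) and return
// the mapped AdbcStatusCode from the enclosing function, roughly:
//
//   Status status = (expr);
//   if (!status.ok()) { status.ToAdbc(error); return /* mapped AdbcStatusCode */; }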
PostgresType root_type; - RAISE_ADBC(helper.ResolveOutputTypes(*type_resolver_, &root_type, error)); + RAISE_STATUS(error, helper.ResolveOutputTypes(*type_resolver_, &root_type)); // If there will be no columns in the result, we can also avoid COPY if (root_type.n_children() == 0) { // Could/should move the helper into the reader instead of repreparing PqResultArrayReader reader(connection_->conn(), type_resolver_, query_); - RAISE_ADBC(reader.ToArrayStream(rows_affected, stream, error)); + reader.SetVendorName(connection_->VendorName()); + RAISE_STATUS(error, reader.ToArrayStream(rows_affected, stream)); return ADBC_STATUS_OK; } struct ArrowError na_error; reader_.copy_reader_ = std::make_unique(); CHECK_NA(INTERNAL, reader_.copy_reader_->Init(root_type), error); - CHECK_NA_DETAIL(INTERNAL, reader_.copy_reader_->InferOutputSchema(&na_error), &na_error, - error); + CHECK_NA_DETAIL(INTERNAL, + reader_.copy_reader_->InferOutputSchema( + std::string(connection_->VendorName()), &na_error), + &na_error, error); CHECK_NA_DETAIL(INTERNAL, reader_.copy_reader_->InitFieldReaders(&na_error), &na_error, error); // Execute the COPY query - RAISE_ADBC(helper.ExecuteCopy(error)); + RAISE_STATUS(error, helper.ExecuteCopy()); // We need the PQresult back for the reader reader_.result_ = helper.ReleaseResult(); @@ -640,19 +567,21 @@ AdbcStatusCode PostgresStatement::ExecuteSchema(struct ArrowSchema* schema, param_oids[i] = pg_type.oid(); } - RAISE_ADBC(helper.Prepare(param_oids, error)); + RAISE_STATUS(error, helper.Prepare(param_oids)); } else { - RAISE_ADBC(helper.Prepare(error)); + RAISE_STATUS(error, helper.Prepare()); } - RAISE_ADBC(helper.DescribePrepared(error)); + RAISE_STATUS(error, helper.DescribePrepared()); PostgresType output_type; - RAISE_ADBC(helper.ResolveOutputTypes(*type_resolver_, &output_type, error)); + RAISE_STATUS(error, helper.ResolveOutputTypes(*type_resolver_, &output_type)); nanoarrow::UniqueSchema tmp; ArrowSchemaInit(tmp.get()); - CHECK_NA(INTERNAL, output_type.SetSchema(tmp.get()), error); + CHECK_NA(INTERNAL, + output_type.SetSchema(tmp.get(), std::string(connection_->VendorName())), + error); tmp.move(schema); return ADBC_STATUS_OK; @@ -675,11 +604,12 @@ AdbcStatusCode PostgresStatement::ExecuteIngest(struct ArrowArrayStream* stream, // This is a little unfortunate; we need another DB roundtrip std::string current_schema; { - PqResultHelper result_helper{connection_->conn(), "SELECT CURRENT_SCHEMA"}; - RAISE_ADBC(result_helper.Execute(error)); + PqResultHelper result_helper{connection_->conn(), "SELECT CURRENT_SCHEMA()"}; + RAISE_STATUS(error, result_helper.Execute()); auto it = result_helper.begin(); if (it == result_helper.end()) { - SetError(error, "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA'"); + SetError(error, + "[libpq] PostgreSQL returned no rows for 'SELECT CURRENT_SCHEMA()'"); return ADBC_STATUS_INTERNAL; } current_schema = (*it)[0].data; @@ -690,14 +620,13 @@ AdbcStatusCode PostgresStatement::ExecuteIngest(struct ArrowArrayStream* stream, std::memset(&bind_, 0, sizeof(bind_)); std::string escaped_table; std::string escaped_field_list; - RAISE_ADBC(bind_stream.Begin( - [&]() -> AdbcStatusCode { - return CreateBulkTable(current_schema, bind_stream.bind_schema.value, - bind_stream.bind_schema_fields, &escaped_table, - &escaped_field_list, error); - }, - error)); - RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); + RAISE_STATUS(error, bind_stream.Begin([&]() -> Status { + struct AdbcError tmp_error = ADBC_ERROR_INIT; + AdbcStatusCode 
status_code = + CreateBulkTable(current_schema, bind_stream.bind_schema.value, &escaped_table, + &escaped_field_list, &tmp_error); + return Status::FromAdbc(status_code, tmp_error); + })); std::string query = "COPY "; query += escaped_table; @@ -714,8 +643,9 @@ AdbcStatusCode PostgresStatement::ExecuteIngest(struct ArrowArrayStream* stream, } PQclear(result); - RAISE_ADBC(bind_stream.ExecuteCopy(connection_->conn(), *connection_->type_resolver(), - rows_affected, error)); + RAISE_STATUS(error, + bind_stream.ExecuteCopy(connection_->conn(), *connection_->type_resolver(), + rows_affected)); return ADBC_STATUS_OK; } @@ -744,7 +674,7 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) { result = std::to_string(reader_.batch_size_hint_bytes_); } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) { - if (use_copy_) { + if (UseCopy()) { result = "true"; } else { result = "false"; @@ -916,4 +846,12 @@ void PostgresStatement::ClearResult() { reader_.Release(); } +int PostgresStatement::UseCopy() { + if (use_copy_ == -1) { + return connection_->VendorName() != "Redshift"; + } else { + return use_copy_; + } +} + } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.h b/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.h index 1cd60bf..60ada99 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.h +++ b/3rd_party/apache-arrow-adbc/c/driver/postgresql/statement.h @@ -52,6 +52,7 @@ class TupleReader final { row_id_(-1), batch_size_hint_bytes_(16777216), is_finished_(false) { + ArrowErrorInit(&na_error_); data_.data.as_char = nullptr; data_.size_bytes = 0; } @@ -68,9 +69,9 @@ class TupleReader final { private: friend class PostgresStatement; - int InitQueryAndFetchFirst(struct ArrowError* error); - int AppendRowAndFetchNext(struct ArrowError* error); - int BuildOutput(struct ArrowArray* out, struct ArrowError* error); + int GetCopyData(); + int AppendRowAndFetchNext(); + int BuildOutput(struct ArrowArray* out); static int GetSchemaTrampoline(struct ArrowArrayStream* self, struct ArrowSchema* out); static int GetNextTrampoline(struct ArrowArrayStream* self, struct ArrowArray* out); @@ -79,6 +80,7 @@ class TupleReader final { AdbcStatusCode status_; struct AdbcError error_; + struct ArrowError na_error_; PGconn* conn_; PGresult* result_; char* pgbuf_; @@ -95,7 +97,7 @@ class PostgresStatement { : connection_(nullptr), query_(), prepared_(false), - use_copy_(true), + use_copy_(-1), reader_(nullptr) { std::memset(&bind_, 0, sizeof(bind_)); } @@ -131,11 +133,11 @@ class PostgresStatement { // Helper methods void ClearResult(); - AdbcStatusCode CreateBulkTable( - const std::string& current_schema, const struct ArrowSchema& source_schema, - const std::vector& source_schema_fields, - std::string* escaped_table, std::string* escaped_field_list, - struct AdbcError* error); + AdbcStatusCode CreateBulkTable(const std::string& current_schema, + const struct ArrowSchema& source_schema, + std::string* escaped_table, + std::string* escaped_field_list, + struct AdbcError* error); AdbcStatusCode ExecuteIngest(struct ArrowArrayStream* stream, int64_t* rows_affected, struct AdbcError* error); AdbcStatusCode ExecuteBind(struct ArrowArrayStream* stream, int64_t* rows_affected, @@ -159,7 +161,7 @@ class PostgresStatement { }; // Options - bool use_copy_; + int use_copy_; struct { std::string db_schema; @@ -169,5 +171,7 @@ class 
PostgresStatement { } ingest_; TupleReader reader_; + + int UseCopy(); }; } // namespace adbcpq diff --git a/3rd_party/apache-arrow-adbc/c/driver/snowflake/snowflake_test.cc b/3rd_party/apache-arrow-adbc/c/driver/snowflake/snowflake_test.cc index 6000335..2622861 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/snowflake/snowflake_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/snowflake/snowflake_test.cc @@ -99,7 +99,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { adbc_validation::Handle statement; CHECK_OK(AdbcStatementNew(connection, &statement.value, error)); - std::string create = "CREATE TABLE \""; + std::string create = "CREATE OR REPLACE TABLE \""; create += name; create += "\" (int64s INT, strings TEXT)"; CHECK_OK(AdbcStatementSetSqlQuery(&statement.value, create.c_str(), error)); @@ -131,7 +131,13 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { return NANOARROW_TYPE_DOUBLE; case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_LIST: + case NANOARROW_TYPE_LARGE_LIST: return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_BINARY: + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + return NANOARROW_TYPE_BINARY; default: return ingest_type; } @@ -149,7 +155,11 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { bool supports_dynamic_parameter_binding() const override { return true; } bool supports_error_on_incompatible_schema() const override { return false; } bool ddl_implicit_commit_txn() const override { return true; } + bool supports_ingest_view_types() const override { return false; } + bool supports_ingest_float16() const override { return false; } + std::string db_schema() const override { return schema_; } + std::string catalog() const override { return "ADBC_TESTING"; } const char* uri_; bool skip_{false}; diff --git a/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite.cc b/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite.cc index 6628acd..a5186d0 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite.cc @@ -15,19 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-#include #include #include -#include #include #include #define ADBC_FRAMEWORK_USE_FMT -#include "driver/common/options.h" -#include "driver/common/utils.h" #include "driver/framework/base_driver.h" -#include "driver/framework/catalog.h" #include "driver/framework/connection.h" #include "driver/framework/database.h" #include "driver/framework/statement.h" @@ -223,10 +218,6 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { std::string query = "SELECT DISTINCT name FROM pragma_database_list() WHERE name LIKE ?"; - this->table_filter = table_filter; - this->column_filter = column_filter; - this->table_types = table_types; - UNWRAP_STATUS(SqliteQuery::Scan( conn, query, [&](sqlite3_stmt* stmt) { @@ -250,14 +241,17 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { return status::Ok(); } - Status LoadCatalogs() override { return status::Ok(); }; + Status LoadCatalogs(std::optional catalog_filter) override { + return status::Ok(); + }; Result> NextCatalog() override { if (next_catalog >= catalogs.size()) return std::nullopt; return catalogs[next_catalog++]; } - Status LoadSchemas(std::string_view catalog) override { + Status LoadSchemas(std::string_view catalog, + std::optional schema_filter) override { next_schema = 0; return status::Ok(); }; @@ -267,7 +261,9 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { return schemas[next_schema++]; } - Status LoadTables(std::string_view catalog, std::string_view schema) override { + Status LoadTables(std::string_view catalog, std::string_view schema, + std::optional table_filter, + const std::vector& table_types) override { next_table = 0; tables.clear(); if (!schema.empty()) return status::Ok(); @@ -310,7 +306,8 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { } Status LoadColumns(std::string_view catalog, std::string_view schema, - std::string_view table) override { + std::string_view table, + std::optional column_filter) override { // XXX: pragma_table_info doesn't appear to work with bind parameters // XXX: because we're saving the SqliteQuery, we also need to save the string builder columns_query.Reset(); @@ -487,9 +484,6 @@ struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { }; sqlite3* conn = nullptr; - std::optional table_filter; - std::optional column_filter; - std::vector table_types; std::vector catalogs; std::vector schemas; std::vector> tables; @@ -934,21 +928,25 @@ class SqliteStatement : public driver::Statement { } assert(stmt != nullptr); - AdbcStatusCode status = ADBC_STATUS_OK; + AdbcStatusCode status_code = ADBC_STATUS_OK; + Status status = status::Ok(); struct AdbcError error = ADBC_ERROR_INIT; while (true) { char finished = 0; - status = AdbcSqliteBinderBindNext(&binder_, conn_, stmt, &finished, &error); - if (status != ADBC_STATUS_OK || finished) break; + status_code = AdbcSqliteBinderBindNext(&binder_, conn_, stmt, &finished, &error); + if (status_code != ADBC_STATUS_OK || finished) { + status = Status::FromAdbc(status_code, error); + break; + } int rc = 0; do { rc = sqlite3_step(stmt); } while (rc == SQLITE_ROW); if (rc != SQLITE_DONE) { - SetError(&error, "failed to execute: %s\nquery was: %s", sqlite3_errmsg(conn_), - insert.data()); - status = ADBC_STATUS_INTERNAL; + status = status::fmt::Internal("failed to execute: {}\nquery was: {}", + sqlite3_errmsg(conn_), insert.data()); + status_code = ADBC_STATUS_INTERNAL; break; } row_count++; @@ -956,15 +954,15 @@ class SqliteStatement : public driver::Statement { std::ignore = 
sqlite3_finalize(stmt); if (is_autocommit) { - if (status == ADBC_STATUS_OK) { + if (status_code == ADBC_STATUS_OK) { UNWRAP_STATUS(::adbc::sqlite::SqliteQuery::Execute(conn_, "COMMIT")); } else { UNWRAP_STATUS(::adbc::sqlite::SqliteQuery::Execute(conn_, "ROLLBACK")); } } - if (status != ADBC_STATUS_OK) { - return Status::FromAdbc(status, error); + if (status_code != ADBC_STATUS_OK) { + return status; } return row_count; } @@ -1008,7 +1006,8 @@ class SqliteStatement : public driver::Statement { "parameter count mismatch: expected {} but found {}", expected, actual); } - int64_t rows = 0; + int64_t output_rows = 0; + int64_t changed_rows = 0; SqliteMutexGuard guard(conn_); @@ -1027,7 +1026,11 @@ class SqliteStatement : public driver::Statement { } while (sqlite3_step(stmt_) == SQLITE_ROW) { - rows++; + output_rows++; + } + + if (sqlite3_column_count(stmt_) == 0) { + changed_rows += sqlite3_changes(conn_); } if (!binder_.schema.release) break; @@ -1041,9 +1044,10 @@ class SqliteStatement : public driver::Statement { } if (sqlite3_column_count(stmt_) == 0) { - rows = sqlite3_changes(conn_); + return changed_rows; + } else { + return output_rows; } - return rows; } Result ExecuteUpdateImpl(PreparedState& state) { return ExecuteUpdateImpl(); } diff --git a/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite_test.cc b/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite_test.cc index 320eec0..8ceb747 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver/sqlite/sqlite_test.cc @@ -79,10 +79,16 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { case NANOARROW_TYPE_UINT32: case NANOARROW_TYPE_UINT64: return NANOARROW_TYPE_INT64; + case NANOARROW_TYPE_HALF_FLOAT: case NANOARROW_TYPE_FLOAT: - case NANOARROW_TYPE_DOUBLE: return NANOARROW_TYPE_DOUBLE; case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: + return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: + return NANOARROW_TYPE_BINARY; case NANOARROW_TYPE_DATE32: case NANOARROW_TYPE_TIMESTAMP: return NANOARROW_TYPE_STRING; @@ -328,6 +334,12 @@ class SqliteStatementTest : public ::testing::Test, void TestSqlIngestInterval() { GTEST_SKIP() << "Cannot ingest Interval (not implemented)"; } + void TestSqlIngestListOfInt32() { + GTEST_SKIP() << "Cannot ingest list (not implemented)"; + } + void TestSqlIngestListOfString() { + GTEST_SKIP() << "Cannot ingest list (not implemented)"; + } protected: void ValidateIngestedTemporalData(struct ArrowArrayView* values, ArrowType type, @@ -439,8 +451,11 @@ class SqliteReaderTest : public ::testing::Test { } void Bind(struct ArrowArray* batch, struct ArrowSchema* schema) { - ASSERT_THAT(AdbcSqliteBinderSetArray(&binder, batch, schema, &error), - IsOkStatus(&error)); + Handle stream; + struct ArrowArray batch_internal = *batch; + batch->release = nullptr; + adbc_validation::MakeStream(&stream.value, schema, {batch_internal}); + ASSERT_NO_FATAL_FAILURE(Bind(&stream.value)); } void Bind(struct ArrowArrayStream* stream) { diff --git a/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.c b/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.c index dc036a9..f731516 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.c +++ b/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.c @@ -89,8 +89,11 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, switch (value_view.type) { case 
NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: case NANOARROW_TYPE_BINARY: case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: break; default: SetError(error, "Column %d dictionary has unsupported type %s", i, @@ -105,14 +108,6 @@ AdbcStatusCode AdbcSqliteBinderSet(struct AdbcSqliteBinder* binder, return ADBC_STATUS_OK; } -AdbcStatusCode AdbcSqliteBinderSetArray(struct AdbcSqliteBinder* binder, - struct ArrowArray* values, - struct ArrowSchema* schema, - struct AdbcError* error) { - AdbcSqliteBinderRelease(binder); - RAISE_ADBC(BatchToArrayStream(values, schema, &binder->params, error)); - return AdbcSqliteBinderSet(binder, error); -} // NOLINT(whitespace/indent) AdbcStatusCode AdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder, struct ArrowArrayStream* values, struct AdbcError* error) { @@ -334,7 +329,9 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 } else { switch (binder->types[col]) { case NANOARROW_TYPE_BINARY: - case NANOARROW_TYPE_LARGE_BINARY: { + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: { struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], binder->next_row); status = sqlite3_bind_blob(stmt, col + 1, value.data.as_char, value.size_bytes, @@ -367,6 +364,7 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 status = sqlite3_bind_int64(stmt, col + 1, value); break; } + case NANOARROW_TYPE_HALF_FLOAT: case NANOARROW_TYPE_FLOAT: case NANOARROW_TYPE_DOUBLE: { double value = ArrowArrayViewGetDoubleUnsafe(binder->batch.children[col], @@ -375,7 +373,8 @@ AdbcStatusCode AdbcSqliteBinderBindNext(struct AdbcSqliteBinder* binder, sqlite3 break; } case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_LARGE_STRING: { + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: { struct ArrowBufferView value = ArrowArrayViewGetBytesUnsafe(binder->batch.children[col], binder->next_row); status = sqlite3_bind_text(stmt, col + 1, value.data.as_char, value.size_bytes, diff --git a/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.h b/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.h index 77333a9..2e6b190 100644 --- a/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.h +++ b/3rd_party/apache-arrow-adbc/c/driver/sqlite/statement_reader.h @@ -40,11 +40,6 @@ struct ADBC_EXPORT AdbcSqliteBinder { int64_t next_row; }; -ADBC_EXPORT -AdbcStatusCode AdbcSqliteBinderSetArray(struct AdbcSqliteBinder* binder, - struct ArrowArray* values, - struct ArrowSchema* schema, - struct AdbcError* error); ADBC_EXPORT AdbcStatusCode AdbcSqliteBinderSetArrayStream(struct AdbcSqliteBinder* binder, struct ArrowArrayStream* values, diff --git a/3rd_party/apache-arrow-adbc/c/driver_manager/CMakeLists.txt b/3rd_party/apache-arrow-adbc/c/driver_manager/CMakeLists.txt index f08b41e..0eb17f0 100644 --- a/3rd_party/apache-arrow-adbc/c/driver_manager/CMakeLists.txt +++ b/3rd_party/apache-arrow-adbc/c/driver_manager/CMakeLists.txt @@ -35,6 +35,10 @@ include_directories(SYSTEM ${REPOSITORY_ROOT}/c/include/) include_directories(SYSTEM ${REPOSITORY_ROOT}/c/vendor) include_directories(SYSTEM ${REPOSITORY_ROOT}/c/driver) +install(FILES "${REPOSITORY_ROOT}/c/include/adbc.h" DESTINATION include) +install(FILES "${REPOSITORY_ROOT}/c/include/arrow-adbc/adbc.h" + DESTINATION include/arrow-adbc) + foreach(LIB_TARGET 
${ADBC_LIBRARIES}) target_compile_definitions(${LIB_TARGET} PRIVATE ADBC_EXPORTING) endforeach() diff --git a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager.cc b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager.cc index 44c3d9f..0ce173a 100644 --- a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager.cc +++ b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager.cc @@ -84,6 +84,36 @@ void SetError(struct AdbcError* error, const std::string& message) { error->release = ReleaseError; } +// Copies src_error into error and releases src_error +void SetError(struct AdbcError* error, struct AdbcError* src_error) { + if (!error) return; + if (error->release) error->release(error); + + if (src_error->message) { + size_t message_size = strlen(src_error->message); + error->message = new char[message_size]; + std::memcpy(error->message, src_error->message, message_size); + error->message[message_size] = '\0'; + } else { + error->message = nullptr; + } + + error->release = ReleaseError; + if (src_error->release) { + src_error->release(src_error); + } +} + +struct OwnedError { + struct AdbcError error = ADBC_ERROR_INIT; + + ~OwnedError() { + if (error.release) { + error.release(&error); + } + } +}; + // Driver state /// A driver DLL. @@ -666,7 +696,7 @@ std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) { int AdbcErrorGetDetailCount(const struct AdbcError* error) { if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && - error->private_driver) { + error->private_driver && error->private_driver->ErrorGetDetailCount) { return error->private_driver->ErrorGetDetailCount(error); } return 0; @@ -674,7 +704,7 @@ int AdbcErrorGetDetailCount(const struct AdbcError* error) { struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && - error->private_driver) { + error->private_driver && error->private_driver->ErrorGetDetail) { return error->private_driver->ErrorGetDetail(error, index); } return {nullptr, nullptr, 0}; @@ -900,6 +930,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, database->private_driver, error); } + if (status != ADBC_STATUS_OK) { // Restore private_data so it will be released by AdbcDatabaseRelease database->private_data = args; @@ -910,10 +941,18 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_driver = nullptr; return status; } - status = database->private_driver->DatabaseNew(database, error); + + // Errors that occur during AdbcDatabaseXXX() refer to the driver via + // the private_driver member; however, after we return we have released + // the driver and inspecting the error might segfault. Here, we scope + // the driver-produced error to this function and make a copy if necessary. 
+ OwnedError driver_error; + + status = database->private_driver->DatabaseNew(database, &driver_error.error); if (status != ADBC_STATUS_OK) { if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + SetError(error, &driver_error.error); + database->private_driver->release(database->private_driver, nullptr); } delete database->private_driver; database->private_driver = nullptr; @@ -927,33 +966,34 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* INIT_ERROR(error, database); for (const auto& option : options) { - status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), - option.second.c_str(), error); + status = database->private_driver->DatabaseSetOption( + database, option.first.c_str(), option.second.c_str(), &driver_error.error); if (status != ADBC_STATUS_OK) break; } for (const auto& option : bytes_options) { status = database->private_driver->DatabaseSetOptionBytes( database, option.first.c_str(), reinterpret_cast(option.second.data()), option.second.size(), - error); + &driver_error.error); if (status != ADBC_STATUS_OK) break; } for (const auto& option : int_options) { status = database->private_driver->DatabaseSetOptionInt( - database, option.first.c_str(), option.second, error); + database, option.first.c_str(), option.second, &driver_error.error); if (status != ADBC_STATUS_OK) break; } for (const auto& option : double_options) { status = database->private_driver->DatabaseSetOptionDouble( - database, option.first.c_str(), option.second, error); + database, option.first.c_str(), option.second, &driver_error.error); if (status != ADBC_STATUS_OK) break; } if (status != ADBC_STATUS_OK) { // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); + std::ignore = database->private_driver->DatabaseRelease(database, nullptr); if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + SetError(error, &driver_error.error); + database->private_driver->release(database->private_driver, nullptr); } delete database->private_driver; database->private_driver = nullptr; @@ -962,6 +1002,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_data = nullptr; return status; } + return database->private_driver->DatabaseInit(database, error); } diff --git a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager_test.cc b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager_test.cc index 0d8e362..c2342eb 100644 --- a/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager_test.cc +++ b/3rd_party/apache-arrow-adbc/c/driver_manager/adbc_driver_manager_test.cc @@ -187,10 +187,18 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { case NANOARROW_TYPE_UINT32: case NANOARROW_TYPE_UINT64: return NANOARROW_TYPE_INT64; + case NANOARROW_TYPE_HALF_FLOAT: case NANOARROW_TYPE_FLOAT: - case NANOARROW_TYPE_DOUBLE: return NANOARROW_TYPE_DOUBLE; case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: + return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: + return NANOARROW_TYPE_BINARY; + case NANOARROW_TYPE_DATE32: + case NANOARROW_TYPE_TIMESTAMP: return NANOARROW_TYPE_STRING; default: return ingest_type; @@ -267,8 +275,6 @@ class SqliteStatementTest : public ::testing::Test, void TearDown() override { 
ASSERT_NO_FATAL_FAILURE(TearDownTest()); } void TestSqlIngestUInt64() { GTEST_SKIP() << "Cannot ingest UINT64 (out of range)"; } - void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; } - void TestSqlIngestDate32() { GTEST_SKIP() << "Cannot ingest DATE (not implemented)"; } void TestSqlIngestTimestamp() { GTEST_SKIP() << "Cannot ingest TIMESTAMP (not implemented)"; } @@ -281,6 +287,12 @@ class SqliteStatementTest : public ::testing::Test, void TestSqlIngestInterval() { GTEST_SKIP() << "Cannot ingest Interval (not implemented)"; } + void TestSqlIngestListOfInt32() { + GTEST_SKIP() << "Cannot ingest list (not implemented)"; + } + void TestSqlIngestListOfString() { + GTEST_SKIP() << "Cannot ingest list (not implemented)"; + } protected: SqliteQuirks quirks_; diff --git a/3rd_party/apache-arrow-adbc/c/include/arrow-adbc/adbc.h b/3rd_party/apache-arrow-adbc/c/include/arrow-adbc/adbc.h index d6c2a4a..b965672 100644 --- a/3rd_party/apache-arrow-adbc/c/include/arrow-adbc/adbc.h +++ b/3rd_party/apache-arrow-adbc/c/include/arrow-adbc/adbc.h @@ -1972,7 +1972,7 @@ AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, /// \since ADBC API revision 1.1.0 /// /// \param[in] statement The statement to execute. -/// \param[out] out The result schema. +/// \param[out] schema The result schema. /// \param[out] error An optional location to return an error /// message if necessary. /// diff --git a/3rd_party/apache-arrow-adbc/c/include/arrow-adbc/adbc_driver_manager.h b/3rd_party/apache-arrow-adbc/c/include/arrow-adbc/adbc_driver_manager.h index c65bff2..c32368a 100644 --- a/3rd_party/apache-arrow-adbc/c/include/arrow-adbc/adbc_driver_manager.h +++ b/3rd_party/apache-arrow-adbc/c/include/arrow-adbc/adbc_driver_manager.h @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +/// \file arrow-adbc/adbc_driver_manager.h ADBC Driver Manager +/// +/// A helper library to dynamically load and use multiple ADBC drivers in the +/// same process. 
+ #pragma once #include diff --git a/3rd_party/apache-arrow-adbc/c/meson.build b/3rd_party/apache-arrow-adbc/c/meson.build index 1cf45d1..f636146 100644 --- a/3rd_party/apache-arrow-adbc/c/meson.build +++ b/3rd_party/apache-arrow-adbc/c/meson.build @@ -18,7 +18,7 @@ project( 'arrow-adbc', 'c', 'cpp', - version: '1.2.0', + version: '1.3.0', license: 'Apache-2.0', meson_version: '>=1.3.0', default_options: [ diff --git a/3rd_party/apache-arrow-adbc/c/subprojects/nanoarrow.wrap b/3rd_party/apache-arrow-adbc/c/subprojects/nanoarrow.wrap index 1a7e856..612d111 100644 --- a/3rd_party/apache-arrow-adbc/c/subprojects/nanoarrow.wrap +++ b/3rd_party/apache-arrow-adbc/c/subprojects/nanoarrow.wrap @@ -1,10 +1,8 @@ [wrap-file] -directory = arrow-nanoarrow-apache-arrow-nanoarrow-0.5.0 -source_url = https://github.com/apache/arrow-nanoarrow/archive/refs/tags/apache-arrow-nanoarrow-0.5.0.tar.gz -source_filename = apache-arrow-nanoarrow-0.5.0.tar.gz -source_hash = 0ceeaa1fb005dbc89c8c7d1b39f2dba07344e40aa9d885ee25fb55b4d57e331a -source_fallback_url = https://github.com/mesonbuild/wrapdb/releases/download/nanoarrow_0.5.0-1/apache-arrow-nanoarrow-0.5.0.tar.gz -wrapdb_version = 0.5.0-1 +directory = arrow-nanoarrow-33d2c8b973d8f8f424e02ac92ddeaace2a92f8dd +source_url = https://github.com/apache/arrow-nanoarrow/archive/33d2c8b973d8f8f424e02ac92ddeaace2a92f8dd.tar.gz +source_filename = arrow-nanoarrow-33d2c8b973d8f8f424e02ac92ddeaace2a92f8dd.tar.gz +source_hash = be4d2a6f1467793fe1b02c6ecf12383ed9ecf29557531715a3b9e11578ab18e8 [provide] nanoarrow = nanoarrow_dep diff --git a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.h b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.h index ab665ac..427e39b 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.h +++ b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation.h @@ -27,6 +27,8 @@ #include #include +#include "adbc_validation_util.h" + namespace adbc_validation { #define ADBCV_STRINGIFY(s) #s @@ -160,6 +162,18 @@ class DriverQuirks { return ingest_type; } + /// \brief For a given Arrow type of (possibly nested) ingested data, what Arrow type + /// will the database return when that column is selected? + virtual SchemaField IngestSelectRoundTripType(SchemaField ingest_field) const { + SchemaField out(ingest_field.name, IngestSelectRoundTripType(ingest_field.type), + ingest_field.nullable); + for (const auto& child : ingest_field.children) { + out.children.push_back(IngestSelectRoundTripType(child)); + } + + return out; + } + /// \brief Whether bulk ingest is supported virtual bool supports_bulk_ingest(const char* mode) const { return true; } @@ -224,6 +238,12 @@ class DriverQuirks { /// column matching. 
virtual bool supports_error_on_incompatible_schema() const { return true; } + /// \brief Whether ingestion supports StringView/BinaryView types + virtual bool supports_ingest_view_types() const { return true; } + + /// \brief Whether ingestion supports Float16 + virtual bool supports_ingest_float16() const { return true; } + /// \brief Default catalog to use for tests virtual std::string catalog() const { return ""; } @@ -344,7 +364,7 @@ class StatementTest { void TestNewInit(); void TestRelease(); - // ---- Type-specific tests -------------------- + // ---- Type-specific ingest tests ------------- void TestSqlIngestBool(); @@ -359,13 +379,18 @@ class StatementTest { void TestSqlIngestUInt64(); // Floats + void TestSqlIngestFloat16(); void TestSqlIngestFloat32(); void TestSqlIngestFloat64(); // Strings void TestSqlIngestString(); void TestSqlIngestLargeString(); + void TestSqlIngestStringView(); void TestSqlIngestBinary(); + void TestSqlIngestLargeBinary(); + void TestSqlIngestFixedSizeBinary(); + void TestSqlIngestBinaryView(); // Temporal void TestSqlIngestDuration(); @@ -377,6 +402,10 @@ class StatementTest { // Dictionary-encoded void TestSqlIngestStringDictionary(); + // Nested + void TestSqlIngestListOfInt32(); + void TestSqlIngestListOfString(); + void TestSqlIngestStreamZeroArrays(); // ---- End Type-specific tests ---------------- @@ -409,6 +438,8 @@ class StatementTest { void TestSqlPrepareErrorNoQuery(); void TestSqlPrepareErrorParamCountMismatch(); + void TestSqlBind(); + void TestSqlQueryEmpty(); void TestSqlQueryInts(); void TestSqlQueryFloats(); @@ -440,6 +471,11 @@ class StatementTest { struct AdbcConnection connection; struct AdbcStatement statement; + template + void TestSqlIngestType(SchemaField type, + const std::vector>& values, + bool dictionary_encode); + template void TestSqlIngestType(ArrowType type, const std::vector>& values, bool dictionary_encode); @@ -455,6 +491,14 @@ class StatementTest { const char* timezone); }; +template +void StatementTest::TestSqlIngestType(ArrowType type, + const std::vector>& values, + bool dictionary_encode) { + SchemaField field("col", type); + TestSqlIngestType(field, values, dictionary_encode); +} + #define ADBCV_TEST_STATEMENT(FIXTURE) \ static_assert(std::is_base_of::value, \ ADBCV_STRINGIFY(FIXTURE) " must inherit from StatementTest"); \ @@ -469,17 +513,24 @@ class StatementTest { TEST_F(FIXTURE, SqlIngestUInt16) { TestSqlIngestUInt16(); } \ TEST_F(FIXTURE, SqlIngestUInt32) { TestSqlIngestUInt32(); } \ TEST_F(FIXTURE, SqlIngestUInt64) { TestSqlIngestUInt64(); } \ + TEST_F(FIXTURE, SqlIngestFloat16) { TestSqlIngestFloat16(); } \ TEST_F(FIXTURE, SqlIngestFloat32) { TestSqlIngestFloat32(); } \ TEST_F(FIXTURE, SqlIngestFloat64) { TestSqlIngestFloat64(); } \ TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); } \ TEST_F(FIXTURE, SqlIngestLargeString) { TestSqlIngestLargeString(); } \ + TEST_F(FIXTURE, SqlIngestStringView) { TestSqlIngestStringView(); } \ TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); } \ + TEST_F(FIXTURE, SqlIngestLargeBinary) { TestSqlIngestLargeBinary(); } \ + TEST_F(FIXTURE, SqlIngestFixedSizeBinary) { TestSqlIngestFixedSizeBinary(); } \ + TEST_F(FIXTURE, SqlIngestBinaryView) { TestSqlIngestBinaryView(); } \ TEST_F(FIXTURE, SqlIngestDuration) { TestSqlIngestDuration(); } \ TEST_F(FIXTURE, SqlIngestDate32) { TestSqlIngestDate32(); } \ TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \ TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); } \ TEST_F(FIXTURE, 
SqlIngestInterval) { TestSqlIngestInterval(); } \ TEST_F(FIXTURE, SqlIngestStringDictionary) { TestSqlIngestStringDictionary(); } \ + TEST_F(FIXTURE, SqlIngestListOfInt32) { TestSqlIngestListOfInt32(); } \ + TEST_F(FIXTURE, SqlIngestListOfString) { TestSqlIngestListOfString(); } \ TEST_F(FIXTURE, TestSqlIngestStreamZeroArrays) { TestSqlIngestStreamZeroArrays(); } \ TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \ TEST_F(FIXTURE, SqlIngestColumnEscaping) { TestSqlIngestColumnEscaping(); } \ @@ -508,6 +559,7 @@ class StatementTest { TEST_F(FIXTURE, SqlPrepareErrorParamCountMismatch) { \ TestSqlPrepareErrorParamCountMismatch(); \ } \ + TEST_F(FIXTURE, SqlBind) { TestSqlBind(); } \ TEST_F(FIXTURE, SqlQueryEmpty) { TestSqlQueryEmpty(); } \ TEST_F(FIXTURE, SqlQueryInts) { TestSqlQueryInts(); } \ TEST_F(FIXTURE, SqlQueryFloats) { TestSqlQueryFloats(); } \ diff --git a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_connection.cc b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_connection.cc index 80e1fbb..032f1d3 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_connection.cc +++ b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_connection.cc @@ -425,10 +425,10 @@ void CheckGetObjectsSchema(struct ArrowSchema* schema) { {"constraint_column_names", NANOARROW_TYPE_LIST, NOT_NULL}, {"constraint_column_usage", NANOARROW_TYPE_LIST, NULLABLE}, })); - ASSERT_NO_FATAL_FAILURE(CompareSchema( - constraint_schema->children[2], { - {std::nullopt, NANOARROW_TYPE_STRING, NULLABLE}, - })); + ASSERT_NO_FATAL_FAILURE(CompareSchema(constraint_schema->children[2], + { + {"", NANOARROW_TYPE_STRING, NULLABLE}, + })); struct ArrowSchema* usage_schema = constraint_schema->children[3]->children[0]; ASSERT_NO_FATAL_FAILURE( @@ -744,13 +744,15 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { struct TestCase { std::optional filter; - std::vector column_names; - std::vector ordinal_positions; + // the pair is column name & ordinal position of the column + std::vector> columns; }; std::vector test_cases; - test_cases.push_back({std::nullopt, {"int64s", "strings"}, {1, 2}}); - test_cases.push_back({"in%", {"int64s"}, {1}}); + test_cases.push_back({std::nullopt, {{"int64s", 1}, {"strings", 2}}}); + test_cases.push_back({"in%", {{"int64s", 1}}}); + + const std::string catalog = quirks()->catalog(); for (const auto& test_case : test_cases) { std::string scope = "Filter: "; @@ -758,13 +760,14 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { SCOPED_TRACE(scope); StreamReader reader; + std::vector> columns; std::vector column_names; std::vector ordinal_positions; ASSERT_THAT( AdbcConnectionGetObjects( - &connection, ADBC_OBJECT_DEPTH_COLUMNS, nullptr, nullptr, nullptr, nullptr, - test_case.filter.has_value() ? test_case.filter->c_str() : nullptr, + &connection, ADBC_OBJECT_DEPTH_COLUMNS, catalog.c_str(), nullptr, nullptr, + nullptr, test_case.filter.has_value() ? 
test_case.filter->c_str() : nullptr, &reader.stream.value, &error), IsOkStatus(&error)); ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); @@ -834,10 +837,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { std::string temp(name.data, name.size_bytes); std::transform(temp.begin(), temp.end(), temp.begin(), [](unsigned char c) { return std::tolower(c); }); - column_names.push_back(std::move(temp)); - ordinal_positions.push_back( - static_cast(ArrowArrayViewGetIntUnsafe( - table_columns->children[1], columns_index))); + columns.emplace_back(std::move(temp), + static_cast(ArrowArrayViewGetIntUnsafe( + table_columns->children[1], columns_index))); } } } @@ -847,8 +849,9 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { } while (reader.array->release); ASSERT_TRUE(found_expected_table) << "Did (not) find table in metadata"; - ASSERT_EQ(test_case.column_names, column_names); - ASSERT_EQ(test_case.ordinal_positions, ordinal_positions); + // metadata columns do not guarantee the order they are returned in, just + // validate all the elements are there. + ASSERT_THAT(columns, testing::UnorderedElementsAreArray(test_case.columns)); } } diff --git a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_statement.cc b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_statement.cc index 4316205..cd38862 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_statement.cc +++ b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_statement.cc @@ -79,9 +79,12 @@ void StatementTest::TestRelease() { } template -void StatementTest::TestSqlIngestType(ArrowType type, +void StatementTest::TestSqlIngestType(SchemaField field, const std::vector>& values, bool dictionary_encode) { + // Override the field name + field.name = "col"; + if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); } @@ -92,7 +95,7 @@ void StatementTest::TestSqlIngestType(ArrowType type, Handle schema; Handle array; struct ArrowError na_error; - ASSERT_THAT(MakeSchema(&schema.value, {{"col", type}}), IsOkErrno()); + ASSERT_THAT(MakeSchema(&schema.value, {field}), IsOkErrno()); ASSERT_THAT(MakeBatch(&schema.value, &array.value, &na_error, values), IsOkErrno()); @@ -155,16 +158,15 @@ void StatementTest::TestSqlIngestType(ArrowType type, ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1))); ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); - ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type); - ASSERT_NO_FATAL_FAILURE( - CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}})); + SchemaField round_trip_field = quirks()->IngestSelectRoundTripType(field); + ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value, {round_trip_field})); ASSERT_NO_FATAL_FAILURE(reader.Next()); ASSERT_NE(nullptr, reader.array->release); ASSERT_EQ(values.size(), reader.array->length); ASSERT_EQ(1, reader.array->n_children); - if (round_trip_type == type) { + if (round_trip_field.type == field.type) { // XXX: for now we can't compare values; we would need casting ASSERT_NO_FATAL_FAILURE( CompareArray(reader.array_view->children[0], values)); @@ -235,6 +237,14 @@ void StatementTest::TestSqlIngestInt64() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_INT64)); } +void StatementTest::TestSqlIngestFloat16() { + if (!quirks()->supports_ingest_float16()) { + GTEST_SKIP(); + } + + ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_HALF_FLOAT)); +} + void StatementTest::TestSqlIngestFloat32() { 
ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_FLOAT)); } @@ -253,6 +263,16 @@ void StatementTest::TestSqlIngestLargeString() { NANOARROW_TYPE_LARGE_STRING, {std::nullopt, "", "", "1234", "例"}, false)); } +void StatementTest::TestSqlIngestStringView() { + if (!quirks()->supports_ingest_view_types()) { + GTEST_SKIP(); + } + + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( + NANOARROW_TYPE_STRING_VIEW, {std::nullopt, "", "", "longer than 12 bytes", "例"}, + false)); +} + void StatementTest::TestSqlIngestBinary() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( NANOARROW_TYPE_BINARY, @@ -264,6 +284,38 @@ void StatementTest::TestSqlIngestBinary() { false)); } +void StatementTest::TestSqlIngestLargeBinary() { + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( + NANOARROW_TYPE_LARGE_BINARY, + {std::nullopt, std::vector{}, + std::vector{std::byte{0x00}, std::byte{0x01}}, + std::vector{std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, + std::byte{0x04}}, + std::vector{std::byte{0xfe}, std::byte{0xff}}}, + false)); +} + +void StatementTest::TestSqlIngestFixedSizeBinary() { + SchemaField field = SchemaField::FixedSize("col", NANOARROW_TYPE_FIXED_SIZE_BINARY, 4); + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( + field, {std::nullopt, "abcd", "efgh", "ijkl", "mnop"}, false)); +} + +void StatementTest::TestSqlIngestBinaryView() { + if (!quirks()->supports_ingest_view_types()) { + GTEST_SKIP(); + } + + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( + NANOARROW_TYPE_LARGE_BINARY, + {std::nullopt, std::vector{}, + std::vector{std::byte{0x00}, std::byte{0x01}}, + std::vector{std::byte{0x01}, std::byte{0x02}, std::byte{0x03}, + std::byte{0x04}}, + std::vector{std::byte{0xfe}, std::byte{0xff}}}, + false)); +} + void StatementTest::TestSqlIngestDate32() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType(NANOARROW_TYPE_DATE32)); } @@ -491,6 +543,24 @@ void StatementTest::TestSqlIngestStringDictionary() { /*dictionary_encode*/ true)); } +void StatementTest::TestSqlIngestListOfInt32() { + SchemaField field = + SchemaField::Nested("col", NANOARROW_TYPE_LIST, {{"item", NANOARROW_TYPE_INT32}}); + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( + field, {std::nullopt, std::vector{1, 2, 3}, std::vector{4, 5}}, + /*dictionary_encode*/ false)); +} + +void StatementTest::TestSqlIngestListOfString() { + SchemaField field = + SchemaField::Nested("col", NANOARROW_TYPE_LIST, {{"item", NANOARROW_TYPE_STRING}}); + ASSERT_NO_FATAL_FAILURE(TestSqlIngestType>( + field, + {std::nullopt, std::vector{"abc", "defg"}, + std::vector{"hijk"}}, + /*dictionary_encode*/ false)); +} + void StatementTest::TestSqlIngestStreamZeroArrays() { if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) { GTEST_SKIP(); @@ -2108,6 +2178,71 @@ void StatementTest::TestSqlPrepareErrorParamCountMismatch() { ::testing::Not(IsOkStatus(&error))); } +void StatementTest::TestSqlBind() { + if (!quirks()->supports_dynamic_parameter_binding()) { + GTEST_SKIP(); + } + + ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); + + ASSERT_THAT(quirks()->DropTable(&connection, "bindtest", &error), IsOkStatus(&error)); + + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement, "CREATE TABLE bindtest (col1 INTEGER, col2 TEXT)", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), + IsOkStatus(&error)); + + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, + {{"", NANOARROW_TYPE_INT32}, {"", NANOARROW_TYPE_STRING}}), + 
IsOkErrno()); + + std::vector> int_values{std::nullopt, -123, 123}; + std::vector> string_values{"abc", std::nullopt, "defg"}; + + int batch_result = MakeBatch( + &schema.value, &array.value, &na_error, int_values, string_values); + ASSERT_THAT(batch_result, IsOkErrno()); + + auto insert_query = std::string("INSERT INTO bindtest VALUES (") + + quirks()->BindParameter(0) + ", " + quirks()->BindParameter(1) + + ")"; + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, insert_query.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementPrepare(&statement, &error), IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error), + IsOkStatus(&error)); + int64_t rows_affected = -10; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(rows_affected, ::testing::AnyOf(::testing::Eq(-1), ::testing::Eq(3))); + + ASSERT_THAT( + AdbcStatementSetSqlQuery( + &statement, "SELECT * FROM bindtest ORDER BY col1 ASC NULLS FIRST", &error), + IsOkStatus(&error)); + { + StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, + &reader.rows_affected, &error), + IsOkStatus(&error)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(3), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->length, 3); + CompareArray(reader.array_view->children[0], int_values); + CompareArray(reader.array_view->children[1], string_values); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(reader.array->release, nullptr); + } +} + void StatementTest::TestSqlQueryEmpty() { ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); diff --git a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.cc b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.cc index 54c18cc..7d97ad7 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.cc +++ b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.cc @@ -165,16 +165,53 @@ ::testing::Matcher IsStatus(AdbcStatusCode code, } \ } while (false); +static int MakeSchemaColumnImpl(struct ArrowSchema* column, const SchemaField& field) { + switch (field.type) { + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + case NANOARROW_TYPE_FIXED_SIZE_LIST: + CHECK_ERRNO(ArrowSchemaSetTypeFixedSize(column, field.type, field.fixed_size)); + break; + default: + CHECK_ERRNO(ArrowSchemaSetType(column, field.type)); + break; + } + + CHECK_ERRNO(ArrowSchemaSetName(column, field.name.c_str())); + + if (!field.nullable) { + column->flags &= ~ARROW_FLAG_NULLABLE; + } + + if (static_cast(column->n_children) != field.children.size()) { + return EINVAL; + } + + switch (field.type) { + // SetType for a list will allocate and initialize children + case NANOARROW_TYPE_LIST: + case NANOARROW_TYPE_LARGE_LIST: + case NANOARROW_TYPE_FIXED_SIZE_LIST: + case NANOARROW_TYPE_MAP: { + size_t i = 0; + for (const SchemaField& child : field.children) { + CHECK_ERRNO(MakeSchemaColumnImpl(column->children[i], child)); + ++i; + } + break; + } + default: + break; + } + + return 0; +} + int MakeSchema(struct ArrowSchema* schema, const std::vector& fields) { ArrowSchemaInit(schema); CHECK_ERRNO(ArrowSchemaSetTypeStruct(schema, fields.size())); size_t i = 0; for (const SchemaField& field : fields) { - CHECK_ERRNO(ArrowSchemaSetType(schema->children[i], field.type)); - 
CHECK_ERRNO(ArrowSchemaSetName(schema->children[i], field.name.c_str())); - if (!field.nullable) { - schema->children[i]->flags &= ~ARROW_FLAG_NULLABLE; - } + CHECK_ERRNO(MakeSchemaColumnImpl(schema->children[i], field)); i++; } return 0; @@ -244,9 +281,7 @@ void MakeStream(struct ArrowArrayStream* stream, struct ArrowSchema* schema, stream->private_data = new ConstantArrayStream(schema, std::move(batches)); } -void CompareSchema( - struct ArrowSchema* schema, - const std::vector, ArrowType, bool>>& fields) { +void CompareSchema(struct ArrowSchema* schema, const std::vector& fields) { struct ArrowError na_error; struct ArrowSchemaView view; @@ -261,12 +296,11 @@ void CompareSchema( struct ArrowSchemaView field_view; ASSERT_THAT(ArrowSchemaViewInit(&field_view, schema->children[i], &na_error), IsOkErrno(&na_error)); - ASSERT_EQ(std::get<1>(fields[i]), field_view.type); - ASSERT_EQ(std::get<2>(fields[i]), - (schema->children[i]->flags & ARROW_FLAG_NULLABLE) != 0) + ASSERT_EQ(fields[i].type, field_view.type); + ASSERT_EQ(fields[i].nullable, (schema->children[i]->flags & ARROW_FLAG_NULLABLE) != 0) << "Nullability mismatch"; - if (std::get<0>(fields[i]).has_value()) { - ASSERT_STRCASEEQ(std::get<0>(fields[i])->c_str(), schema->children[i]->name); + if (fields[i].name != "") { + ASSERT_STRCASEEQ(fields[i].name.c_str(), schema->children[i]->name); } } } diff --git a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.h b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.h index 21eca52..b4f5d6f 100644 --- a/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.h +++ b/3rd_party/apache-arrow-adbc/c/validation/adbc_validation_util.h @@ -256,13 +256,29 @@ struct GetObjectsReader { struct SchemaField { std::string name; ArrowType type = NANOARROW_TYPE_UNINITIALIZED; + int32_t fixed_size = 0; bool nullable = true; + std::vector children; SchemaField(std::string name, ArrowType type, bool nullable) : name(std::move(name)), type(type), nullable(nullable) {} SchemaField(std::string name, ArrowType type) : SchemaField(std::move(name), type, /*nullable=*/true) {} + + static SchemaField Nested(std::string name, ArrowType type, + std::vector children) { + SchemaField out(name, type); + out.children = std::move(children); + return out; + } + + static SchemaField FixedSize(std::string name, ArrowType type, int32_t fixed_size, + std::vector children = {}) { + SchemaField out = Nested(name, type, std::move(children)); + out.fixed_size = fixed_size; + return out; + } }; /// \brief Make a schema from a vector of (name, type, nullable) tuples. 
@@ -303,6 +319,29 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray* array, CHECK_OK(ArrowArrayAppendInterval(array, *v)); } else if constexpr (std::is_same::value) { CHECK_OK(ArrowArrayAppendDecimal(array, *v)); + } else if constexpr ( + // Possibly a more effective way to do this using template magic + // Not included but possible are the std::optional<> variants of this + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>>::value) { + using child_t = typename T::value_type; + std::vector> value_nullable; + for (const auto& child_value : *v) { + value_nullable.push_back(child_value); + } + CHECK_OK(MakeArray(array, array->children[0], value_nullable)); + CHECK_OK(ArrowArrayFinishElement(array)); } else { static_assert(!sizeof(T), "Not yet implemented"); return ENOTSUP; @@ -359,49 +398,33 @@ void MakeStream(struct ArrowArrayStream* stream, struct ArrowSchema* schema, /// \brief Compare an array for equality against a vector of values. template void CompareArray(struct ArrowArrayView* array, - const std::vector>& values) { - ASSERT_EQ(static_cast(values.size()), array->array->length); - int64_t i = 0; + const std::vector>& values, int64_t offset = 0, + int64_t length = -1) { + if (length == -1) { + length = array->length; + } + ASSERT_EQ(static_cast(values.size()), length); + int64_t i = offset; for (const auto& v : values) { SCOPED_TRACE("Array index " + std::to_string(i)); if (v.has_value()) { ASSERT_FALSE(ArrowArrayViewIsNull(array, i)); - if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_double[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_float[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, ArrowBitGet(array->buffer_views[1].data.as_uint8, i)); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int8[i]); - } else if constexpr (std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int16[i]); - } else if constexpr (std::is_same::value) { + ASSERT_EQ(ArrowArrayViewGetDoubleUnsafe(array, i), *v); + } else if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int32[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_int64[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint8[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, 
nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint16[i]); - } else if constexpr (std::is_same::value) { - ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint32[i]); - } else if constexpr (std::is_same::value) { + ASSERT_EQ(ArrowArrayViewGetIntUnsafe(array, i), *v); + } else if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value || + std::is_same::value) { ASSERT_NE(array->buffer_views[1].data.data, nullptr); - ASSERT_EQ(*v, array->buffer_views[1].data.as_uint64[i]); + ASSERT_EQ(ArrowArrayViewGetUIntUnsafe(array, i), *v); } else if constexpr (std::is_same::value) { struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i); std::string str(view.data, view.size_bytes); @@ -421,6 +444,34 @@ void CompareArray(struct ArrowArrayView* array, ASSERT_EQ(interval.months, (*v)->months); ASSERT_EQ(interval.days, (*v)->days); ASSERT_EQ(interval.ns, (*v)->ns); + + } else if constexpr ( + // Possibly a more effective way to do this using template magic + // Not included but possible are the std::optional<> variants of this + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>::value || + std::is_same>>::value) { + using child_t = typename T::value_type; + std::vector> value_nullable; + for (const auto& child_value : *v) { + value_nullable.push_back(child_value); + } + + SCOPED_TRACE("List item"); + int64_t child_offset = ArrowArrayViewListChildOffset(array, i); + int64_t child_length = ArrowArrayViewListChildOffset(array, i + 1) - child_offset; + CompareArray(array->children[0], value_nullable, child_offset, + child_length); } else { static_assert(!sizeof(T), "Not yet implemented"); } @@ -433,9 +484,7 @@ void CompareArray(struct ArrowArrayView* array, /// \brief Compare a schema for equality against a vector of (name, /// type, nullable) tuples. 
-void CompareSchema( - struct ArrowSchema* schema, - const std::vector, ArrowType, bool>>& fields); +void CompareSchema(struct ArrowSchema* schema, const std::vector& fields); /// \brief Helper method to get the vendor version of a driver std::string GetDriverVendorVersion(struct AdbcConnection* connection); diff --git a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.c b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.c index 9677a0e..8f26598 100644 --- a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.c +++ b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.c @@ -66,6 +66,7 @@ void ArrowLayoutInit(struct ArrowLayout* layout, enum ArrowType storage_type) { switch (storage_type) { case NANOARROW_TYPE_UNINITIALIZED: case NANOARROW_TYPE_NA: + case NANOARROW_TYPE_RUN_END_ENCODED: layout->buffer_type[0] = NANOARROW_BUFFER_TYPE_NONE; layout->buffer_data_type[0] = NANOARROW_TYPE_UNINITIALIZED; layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_NONE; @@ -178,6 +179,16 @@ void ArrowLayoutInit(struct ArrowLayout* layout, enum ArrowType storage_type) { layout->buffer_data_type[2] = NANOARROW_TYPE_BINARY; break; + case NANOARROW_TYPE_BINARY_VIEW: + layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[1] = NANOARROW_TYPE_BINARY_VIEW; + layout->element_size_bits[1] = 128; + break; + case NANOARROW_TYPE_STRING_VIEW: + layout->buffer_type[1] = NANOARROW_BUFFER_TYPE_DATA; + layout->buffer_data_type[1] = NANOARROW_TYPE_STRING_VIEW; + layout->element_size_bits[1] = 128; + default: break; } @@ -345,6 +356,7 @@ ArrowErrorCode ArrowDecimalSetDigits(struct ArrowDecimal* decimal, // https://github.com/apache/arrow/blob/cd3321b28b0c9703e5d7105d6146c1270bbadd7f/cpp/src/arrow/util/decimal.cc#L365 ArrowErrorCode ArrowDecimalAppendDigitsToBuffer(const struct ArrowDecimal* decimal, struct ArrowBuffer* buffer) { + NANOARROW_DCHECK(decimal->n_words == 2 || decimal->n_words == 4); int is_negative = ArrowDecimalSign(decimal) < 0; uint64_t words_little_endian[4]; @@ -468,6 +480,7 @@ ArrowErrorCode ArrowDecimalAppendDigitsToBuffer(const struct ArrowDecimal* decim // under the License. 
#include +#include #include #include #include @@ -552,8 +565,12 @@ static const char* ArrowSchemaFormatTemplate(enum ArrowType type) { return "u"; case NANOARROW_TYPE_LARGE_STRING: return "U"; + case NANOARROW_TYPE_STRING_VIEW: + return "vu"; case NANOARROW_TYPE_BINARY: return "z"; + case NANOARROW_TYPE_BINARY_VIEW: + return "vz"; case NANOARROW_TYPE_LARGE_BINARY: return "Z"; @@ -576,6 +593,8 @@ static const char* ArrowSchemaFormatTemplate(enum ArrowType type) { return "+s"; case NANOARROW_TYPE_MAP: return "+m"; + case NANOARROW_TYPE_RUN_END_ENCODED: + return "+r"; default: return NULL; @@ -607,6 +626,13 @@ static int ArrowSchemaInitChildrenIfNeeded(struct ArrowSchema* schema, NANOARROW_RETURN_NOT_OK( ArrowSchemaSetName(schema->children[0]->children[1], "value")); break; + case NANOARROW_TYPE_RUN_END_ENCODED: + NANOARROW_RETURN_NOT_OK(ArrowSchemaAllocateChildren(schema, 2)); + ArrowSchemaInit(schema->children[0]); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetName(schema->children[0], "run_ends")); + schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + ArrowSchemaInit(schema->children[1]); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetName(schema->children[1], "values")); default: break; } @@ -676,10 +702,10 @@ ArrowErrorCode ArrowSchemaSetTypeFixedSize(struct ArrowSchema* schema, int n_chars; switch (type) { case NANOARROW_TYPE_FIXED_SIZE_BINARY: - n_chars = snprintf(buffer, sizeof(buffer), "w:%d", (int)fixed_size); + n_chars = snprintf(buffer, sizeof(buffer), "w:%" PRId32, fixed_size); break; case NANOARROW_TYPE_FIXED_SIZE_LIST: - n_chars = snprintf(buffer, sizeof(buffer), "+w:%d", (int)fixed_size); + n_chars = snprintf(buffer, sizeof(buffer), "+w:%" PRId32, fixed_size); break; default: return EINVAL; @@ -729,6 +755,28 @@ ArrowErrorCode ArrowSchemaSetTypeDecimal(struct ArrowSchema* schema, enum ArrowT return ArrowSchemaSetFormat(schema, buffer); } +ArrowErrorCode ArrowSchemaSetTypeRunEndEncoded(struct ArrowSchema* schema, + enum ArrowType run_end_type) { + switch (run_end_type) { + case NANOARROW_TYPE_INT16: + case NANOARROW_TYPE_INT32: + case NANOARROW_TYPE_INT64: + break; + default: + return EINVAL; + } + + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetFormat( + schema, ArrowSchemaFormatTemplate(NANOARROW_TYPE_RUN_END_ENCODED))); + NANOARROW_RETURN_NOT_OK( + ArrowSchemaInitChildrenIfNeeded(schema, NANOARROW_TYPE_RUN_END_ENCODED)); + NANOARROW_RETURN_NOT_OK(ArrowSchemaSetType(schema->children[0], run_end_type)); + NANOARROW_RETURN_NOT_OK( + ArrowSchemaSetType(schema->children[1], NANOARROW_TYPE_UNINITIALIZED)); + + return NANOARROW_OK; +} + static const char* ArrowTimeUnitFormatString(enum ArrowTimeUnit time_unit) { switch (time_unit) { case NANOARROW_TIME_UNIT_SECOND: @@ -850,7 +898,7 @@ ArrowErrorCode ArrowSchemaSetTypeUnion(struct ArrowSchema* schema, enum ArrowTyp format_out_size -= n_chars; for (int64_t i = 1; i < n_children; i++) { - n_chars = snprintf(format_cursor, format_out_size, ",%d", (int)i); + n_chars = snprintf(format_cursor, format_out_size, ",%" PRId64, i); format_cursor += n_chars; format_out_size -= n_chars; } @@ -1144,8 +1192,9 @@ static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, ArrowSchemaViewSetPrimitive(schema_view, NANOARROW_TYPE_DECIMAL256); return NANOARROW_OK; default: - ArrowErrorSet(error, "Expected decimal bitwidth of 128 or 256 but found %d", - (int)schema_view->decimal_bitwidth); + ArrowErrorSet(error, + "Expected decimal bitwidth of 128 or 256 but found %" PRId32, + schema_view->decimal_bitwidth); return EINVAL; } @@ -1202,6 +1251,13 @@ 
static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, *format_end_out = format + 2; return NANOARROW_OK; + // run end encoded has no buffer at all + case 'r': + schema_view->storage_type = NANOARROW_TYPE_RUN_END_ENCODED; + schema_view->type = NANOARROW_TYPE_RUN_END_ENCODED; + *format_end_out = format + 2; + return NANOARROW_OK; + // just validity buffer case 'w': if (format[2] != ':' || format[3] == '\0') { @@ -1249,11 +1305,10 @@ static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, int64_t n_type_ids = _ArrowParseUnionTypeIds(schema_view->union_type_ids, NULL); if (n_type_ids != schema_view->schema->n_children) { - ArrowErrorSet( - error, - "Expected union type_ids parameter to be a comma-separated list of %ld " - "values between 0 and 127 but found '%s'", - (long)schema_view->schema->n_children, schema_view->union_type_ids); + ArrowErrorSet(error, + "Expected union type_ids parameter to be a comma-separated " + "list of %" PRId64 " values between 0 and 127 but found '%s'", + schema_view->schema->n_children, schema_view->union_type_ids); return EINVAL; } *format_end_out = format + strlen(format); @@ -1432,6 +1487,24 @@ static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, return EINVAL; } + // view types + case 'v': { + switch (format[1]) { + case 'u': + ArrowSchemaViewSetPrimitive(schema_view, NANOARROW_TYPE_STRING_VIEW); + *format_end_out = format + 2; + return NANOARROW_OK; + case 'z': + ArrowSchemaViewSetPrimitive(schema_view, NANOARROW_TYPE_BINARY_VIEW); + *format_end_out = format + 2; + return NANOARROW_OK; + default: + ArrowErrorSet(error, "Expected 'u', or 'z' following 'v' but found '%s'", + format + 1); + return EINVAL; + } + } + default: ArrowErrorSet(error, "Unknown format: '%s'", format); return EINVAL; @@ -1441,8 +1514,9 @@ static ArrowErrorCode ArrowSchemaViewParse(struct ArrowSchemaView* schema_view, static ArrowErrorCode ArrowSchemaViewValidateNChildren( struct ArrowSchemaView* schema_view, int64_t n_children, struct ArrowError* error) { if (n_children != -1 && schema_view->schema->n_children != n_children) { - ArrowErrorSet(error, "Expected schema with %d children but found %d children", - (int)n_children, (int)schema_view->schema->n_children); + ArrowErrorSet( + error, "Expected schema with %" PRId64 " children but found %" PRId64 " children", + n_children, schema_view->schema->n_children); return EINVAL; } @@ -1452,15 +1526,15 @@ static ArrowErrorCode ArrowSchemaViewValidateNChildren( for (int64_t i = 0; i < schema_view->schema->n_children; i++) { child = schema_view->schema->children[i]; if (child == NULL) { - ArrowErrorSet(error, - "Expected valid schema at schema->children[%ld] but found NULL", - (long)i); + ArrowErrorSet( + error, "Expected valid schema at schema->children[%" PRId64 "] but found NULL", + i); return EINVAL; } else if (child->release == NULL) { - ArrowErrorSet( - error, - "Expected valid schema at schema->children[%ld] but found a released schema", - (long)i); + ArrowErrorSet(error, + "Expected valid schema at schema->children[%" PRId64 + "] but found a released schema", + i); return EINVAL; } } @@ -1478,8 +1552,9 @@ static ArrowErrorCode ArrowSchemaViewValidateMap(struct ArrowSchemaView* schema_ NANOARROW_RETURN_NOT_OK(ArrowSchemaViewValidateNChildren(schema_view, 1, error)); if (schema_view->schema->children[0]->n_children != 2) { - ArrowErrorSet(error, "Expected child of map type to have 2 children but found %d", - (int)schema_view->schema->children[0]->n_children); + 
ArrowErrorSet(error, + "Expected child of map type to have 2 children but found %" PRId64, + schema_view->schema->children[0]->n_children); return EINVAL; } @@ -1561,6 +1636,8 @@ static ArrowErrorCode ArrowSchemaViewValidate(struct ArrowSchemaView* schema_vie case NANOARROW_TYPE_TIME32: case NANOARROW_TYPE_TIME64: case NANOARROW_TYPE_DURATION: + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: return ArrowSchemaViewValidateNChildren(schema_view, 0, error); case NANOARROW_TYPE_FIXED_SIZE_BINARY: @@ -1576,6 +1653,9 @@ static ArrowErrorCode ArrowSchemaViewValidate(struct ArrowSchemaView* schema_vie case NANOARROW_TYPE_FIXED_SIZE_LIST: return ArrowSchemaViewValidateNChildren(schema_view, 1, error); + case NANOARROW_TYPE_RUN_END_ENCODED: + return ArrowSchemaViewValidateNChildren(schema_view, 2, error); + case NANOARROW_TYPE_STRUCT: return ArrowSchemaViewValidateNChildren(schema_view, -1, error); @@ -1591,7 +1671,7 @@ static ArrowErrorCode ArrowSchemaViewValidate(struct ArrowSchemaView* schema_vie default: ArrowErrorSet(error, "Expected a valid enum ArrowType value but found %d", - (int)schema_view->type); + schema_view->type); return EINVAL; } @@ -1641,8 +1721,8 @@ ArrowErrorCode ArrowSchemaViewInit(struct ArrowSchemaView* schema_view, } if ((format + format_len) != format_end_out) { - ArrowErrorSet(error, "Error parsing schema->format '%s': parsed %d/%d characters", - format, (int)(format_end_out - format), (int)(format_len)); + ArrowErrorSet(error, "Error parsing schema->format '%s': parsed %d/%zu characters", + format, (int)(format_end_out - format), format_len); return EINVAL; } @@ -1702,9 +1782,8 @@ static int64_t ArrowSchemaTypeToStringInternal(struct ArrowSchemaView* schema_vi switch (schema_view->type) { case NANOARROW_TYPE_DECIMAL128: case NANOARROW_TYPE_DECIMAL256: - return snprintf(out, n, "%s(%d, %d)", type_string, - (int)schema_view->decimal_precision, - (int)schema_view->decimal_scale); + return snprintf(out, n, "%s(%" PRId32 ", %" PRId32 ")", type_string, + schema_view->decimal_precision, schema_view->decimal_scale); case NANOARROW_TYPE_TIMESTAMP: return snprintf(out, n, "%s('%s', '%s')", type_string, ArrowTimeUnitString(schema_view->time_unit), schema_view->timezone); @@ -1715,7 +1794,7 @@ static int64_t ArrowSchemaTypeToStringInternal(struct ArrowSchemaView* schema_vi ArrowTimeUnitString(schema_view->time_unit)); case NANOARROW_TYPE_FIXED_SIZE_BINARY: case NANOARROW_TYPE_FIXED_SIZE_LIST: - return snprintf(out, n, "%s(%ld)", type_string, (long)schema_view->fixed_size); + return snprintf(out, n, "%s(%" PRId32 ")", type_string, schema_view->fixed_size); case NANOARROW_TYPE_SPARSE_UNION: case NANOARROW_TYPE_DENSE_UNION: return snprintf(out, n, "%s([%s])", type_string, schema_view->union_type_ids); @@ -1731,7 +1810,7 @@ static inline void ArrowToStringLogChars(char** out, int64_t n_chars_last, // In the unlikely snprintf() returning a negative value (encoding error), // ensure the result won't cause an out-of-bounds access. if (n_chars_last < 0) { - n_chars = 0; + n_chars_last = 0; } *n_chars += n_chars_last; @@ -2070,6 +2149,10 @@ ArrowErrorCode ArrowMetadataBuilderRemove(struct ArrowBuffer* buffer, // under the License. 
#include +#include +#include +#include +#include #include #include @@ -2083,6 +2166,12 @@ static void ArrowArrayReleaseInternal(struct ArrowArray* array) { ArrowBitmapReset(&private_data->bitmap); ArrowBufferReset(&private_data->buffers[0]); ArrowBufferReset(&private_data->buffers[1]); + ArrowFree(private_data->buffer_data); + for (int32_t i = 0; i < private_data->n_variadic_buffers; ++i) { + ArrowBufferReset(&private_data->variadic_buffers[i]); + } + ArrowFree(private_data->variadic_buffers); + ArrowFree(private_data->variadic_buffer_sizes); ArrowFree(private_data); } @@ -2123,6 +2212,7 @@ static ArrowErrorCode ArrowArraySetStorageType(struct ArrowArray* array, switch (storage_type) { case NANOARROW_TYPE_UNINITIALIZED: case NANOARROW_TYPE_NA: + case NANOARROW_TYPE_RUN_END_ENCODED: array->n_buffers = 0; break; @@ -2156,7 +2246,10 @@ static ArrowErrorCode ArrowArraySetStorageType(struct ArrowArray* array, case NANOARROW_TYPE_DENSE_UNION: array->n_buffers = 2; break; - + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + array->n_buffers = NANOARROW_BINARY_VIEW_FIXED_BUFFERS + 1; + break; case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: case NANOARROW_TYPE_BINARY: @@ -2199,12 +2292,36 @@ ArrowErrorCode ArrowArrayInitFromType(struct ArrowArray* array, ArrowBitmapInit(&private_data->bitmap); ArrowBufferInit(&private_data->buffers[0]); ArrowBufferInit(&private_data->buffers[1]); - private_data->buffer_data[0] = NULL; - private_data->buffer_data[1] = NULL; - private_data->buffer_data[2] = NULL; + private_data->buffer_data = + (const void**)ArrowMalloc(sizeof(void*) * NANOARROW_MAX_FIXED_BUFFERS); + for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; ++i) { + private_data->buffer_data[i] = NULL; + } + private_data->n_variadic_buffers = 0; + private_data->variadic_buffers = NULL; + private_data->variadic_buffer_sizes = NULL; array->private_data = private_data; - array->buffers = (const void**)(&private_data->buffer_data); + array->buffers = (const void**)(private_data->buffer_data); + + // These are not technically "storage" in the sense that they do not appear + // in the ArrowSchemaView's storage_type member; however, allowing them here + // is helpful to maximize the number of types that can avoid going through + // ArrowArrayInitFromSchema(). + switch (storage_type) { + case NANOARROW_TYPE_DURATION: + case NANOARROW_TYPE_TIMESTAMP: + case NANOARROW_TYPE_TIME64: + case NANOARROW_TYPE_DATE64: + storage_type = NANOARROW_TYPE_INT64; + break; + case NANOARROW_TYPE_TIME32: + case NANOARROW_TYPE_DATE32: + storage_type = NANOARROW_TYPE_INT32; + break; + default: + break; + } int result = ArrowArraySetStorageType(array, storage_type); if (result != NANOARROW_OK) { @@ -2488,10 +2605,26 @@ static void ArrowArrayFlushInternalPointers(struct ArrowArray* array) { struct ArrowArrayPrivateData* private_data = (struct ArrowArrayPrivateData*)array->private_data; - for (int64_t i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { + const bool is_binary_view = private_data->storage_type == NANOARROW_TYPE_STRING_VIEW || + private_data->storage_type == NANOARROW_TYPE_BINARY_VIEW; + const int32_t nfixed_buf = is_binary_view ? 
2 : NANOARROW_MAX_FIXED_BUFFERS; + + for (int32_t i = 0; i < nfixed_buf; i++) { private_data->buffer_data[i] = ArrowArrayBuffer(array, i)->data; } + if (is_binary_view) { + const int32_t nvirt_buf = private_data->n_variadic_buffers; + private_data->buffer_data = (const void**)ArrowRealloc( + private_data->buffer_data, sizeof(void*) * (nfixed_buf + nvirt_buf + 1)); + for (int32_t i = 0; i < nvirt_buf; i++) { + private_data->buffer_data[nfixed_buf + i] = private_data->variadic_buffers[i].data; + } + private_data->buffer_data[nfixed_buf + nvirt_buf] = + private_data->variadic_buffer_sizes; + array->buffers = (const void**)(private_data->buffer_data); + } + for (int64_t i = 0; i < array->n_children; i++) { ArrowArrayFlushInternalPointers(array->children[i]); } @@ -2547,6 +2680,11 @@ ArrowErrorCode ArrowArrayViewAllocateChildren(struct ArrowArrayView* array_view, return EINVAL; } + if (n_children == 0) { + array_view->n_children = 0; + return NANOARROW_OK; + } + array_view->children = (struct ArrowArrayView**)ArrowMalloc(n_children * sizeof(struct ArrowArrayView*)); if (array_view->children == NULL) { @@ -2695,6 +2833,8 @@ void ArrowArrayViewSetLength(struct ArrowArrayView* array_view, int64_t length) case NANOARROW_BUFFER_TYPE_UNION_OFFSET: array_view->buffer_views[i].size_bytes = element_size_bytes * length; continue; + case NANOARROW_BUFFER_TYPE_VARIADIC_DATA: + case NANOARROW_BUFFER_TYPE_VARIADIC_SIZE: case NANOARROW_BUFFER_TYPE_NONE: array_view->buffer_views[i].size_bytes = 0; continue; @@ -2727,9 +2867,16 @@ static int ArrowArrayViewSetArrayInternal(struct ArrowArrayView* array_view, array_view->offset = array->offset; array_view->length = array->length; array_view->null_count = array->null_count; + array_view->variadic_buffer_sizes = NULL; + array_view->variadic_buffers = NULL; + array_view->n_variadic_buffers = 0; int64_t buffers_required = 0; - for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { + const int nfixed_buf = array_view->storage_type == NANOARROW_TYPE_STRING_VIEW || + array_view->storage_type == NANOARROW_TYPE_BINARY_VIEW + ? 
NANOARROW_BINARY_VIEW_FIXED_BUFFERS + : NANOARROW_MAX_FIXED_BUFFERS; + for (int i = 0; i < nfixed_buf; i++) { if (array_view->layout.buffer_type[i] == NANOARROW_BUFFER_TYPE_NONE) { break; } @@ -2747,17 +2894,30 @@ static int ArrowArrayViewSetArrayInternal(struct ArrowArrayView* array_view, } } - // Check the number of buffers + if (array_view->storage_type == NANOARROW_TYPE_STRING_VIEW || + array_view->storage_type == NANOARROW_TYPE_BINARY_VIEW) { + const int64_t n_buffers = array->n_buffers; + const int32_t nfixed_buf = NANOARROW_BINARY_VIEW_FIXED_BUFFERS; + + const int32_t nvariadic_buf = (int32_t)(n_buffers - nfixed_buf - 1); + array_view->n_variadic_buffers = nvariadic_buf; + buffers_required += nvariadic_buf + 1; + array_view->variadic_buffers = array->buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS; + array_view->variadic_buffer_sizes = (int64_t*)array->buffers[n_buffers - 1]; + } + if (buffers_required != array->n_buffers) { - ArrowErrorSet(error, "Expected array with %d buffer(s) but found %d buffer(s)", - (int)buffers_required, (int)array->n_buffers); + ArrowErrorSet(error, + "Expected array with %" PRId64 " buffer(s) but found %" PRId64 + " buffer(s)", + buffers_required, array->n_buffers); return EINVAL; } // Check number of children if (array_view->n_children != array->n_children) { - ArrowErrorSet(error, "Expected %ld children but found %ld children", - (long)array_view->n_children, (long)array->n_children); + ArrowErrorSet(error, "Expected %" PRId64 " children but found %" PRId64 " children", + array_view->n_children, array->n_children); return EINVAL; } @@ -2789,14 +2949,20 @@ static int ArrowArrayViewSetArrayInternal(struct ArrowArrayView* array_view, static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, struct ArrowError* error) { if (array_view->length < 0) { - ArrowErrorSet(error, "Expected length >= 0 but found length %ld", - (long)array_view->length); + ArrowErrorSet(error, "Expected length >= 0 but found length %" PRId64, + array_view->length); return EINVAL; } if (array_view->offset < 0) { - ArrowErrorSet(error, "Expected offset >= 0 but found offset %ld", - (long)array_view->offset); + ArrowErrorSet(error, "Expected offset >= 0 but found offset %" PRId64, + array_view->offset); + return EINVAL; + } + + // Ensure that offset + length fits within an int64 before a possible overflow + if ((uint64_t)array_view->offset + (uint64_t)array_view->length > (uint64_t)INT64_MAX) { + ArrowErrorSet(error, "Offset + length is > INT64_MAX"); return EINVAL; } @@ -2809,7 +2975,9 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, for (int i = 0; i < 2; i++) { int64_t element_size_bytes = array_view->layout.element_size_bits[i] / 8; // Initialize with a value that will cause an error if accidentally used uninitialized - int64_t min_buffer_size_bytes = array_view->buffer_views[i].size_bytes + 1; + // Need to suppress the clang-tidy warning because gcc warns for possible use + int64_t min_buffer_size_bytes = // NOLINT(clang-analyzer-deadcode.DeadStores) + array_view->buffer_views[i].size_bytes + 1; switch (array_view->layout.buffer_type[i]) { case NANOARROW_BUFFER_TYPE_VALIDITY: @@ -2835,6 +3003,8 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, case NANOARROW_BUFFER_TYPE_UNION_OFFSET: min_buffer_size_bytes = element_size_bytes * offset_plus_length; break; + case NANOARROW_BUFFER_TYPE_VARIADIC_DATA: + case NANOARROW_BUFFER_TYPE_VARIADIC_SIZE: case NANOARROW_BUFFER_TYPE_NONE: continue; } @@ -2844,11 +3014,11 @@ 
static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, array_view->buffer_views[i].size_bytes = min_buffer_size_bytes; } else if (array_view->buffer_views[i].size_bytes < min_buffer_size_bytes) { ArrowErrorSet(error, - "Expected %s array buffer %d to have size >= %ld bytes but found " - "buffer with %ld bytes", - ArrowTypeString(array_view->storage_type), (int)i, - (long)min_buffer_size_bytes, - (long)array_view->buffer_views[i].size_bytes); + "Expected %s array buffer %d to have size >= %" PRId64 + " bytes but found " + "buffer with %" PRId64 " bytes", + ArrowTypeString(array_view->storage_type), i, min_buffer_size_bytes, + array_view->buffer_views[i].size_bytes); return EINVAL; } } @@ -2860,11 +3030,20 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, case NANOARROW_TYPE_FIXED_SIZE_LIST: case NANOARROW_TYPE_MAP: if (array_view->n_children != 1) { - ArrowErrorSet(error, "Expected 1 child of %s array but found %ld child arrays", - ArrowTypeString(array_view->storage_type), - (long)array_view->n_children); + ArrowErrorSet(error, + "Expected 1 child of %s array but found %" PRId64 " child arrays", + ArrowTypeString(array_view->storage_type), array_view->n_children); return EINVAL; } + break; + case NANOARROW_TYPE_RUN_END_ENCODED: + if (array_view->n_children != 2) { + ArrowErrorSet( + error, "Expected 2 children for %s array but found %" PRId64 " child arrays", + ArrowTypeString(array_view->storage_type), array_view->n_children); + return EINVAL; + } + break; default: break; } @@ -2878,12 +3057,11 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, child_min_length = (array_view->offset + array_view->length); for (int64_t i = 0; i < array_view->n_children; i++) { if (array_view->children[i]->length < child_min_length) { - ArrowErrorSet( - error, - "Expected struct child %d to have length >= %ld but found child with " - "length %ld", - (int)(i + 1), (long)(child_min_length), - (long)array_view->children[i]->length); + ArrowErrorSet(error, + "Expected struct child %" PRId64 " to have length >= %" PRId64 + " but found child with " + "length %" PRId64, + i + 1, child_min_length, array_view->children[i]->length); return EINVAL; } } @@ -2894,12 +3072,78 @@ static int ArrowArrayViewValidateMinimal(struct ArrowArrayView* array_view, array_view->layout.child_size_elements; if (array_view->children[0]->length < child_min_length) { ArrowErrorSet(error, - "Expected child of fixed_size_list array to have length >= %ld but " - "found array with length %ld", - (long)child_min_length, (long)array_view->children[0]->length); + "Expected child of fixed_size_list array to have length >= %" PRId64 + " but " + "found array with length %" PRId64, + child_min_length, array_view->children[0]->length); return EINVAL; } break; + + case NANOARROW_TYPE_RUN_END_ENCODED: { + if (array_view->n_children != 2) { + ArrowErrorSet(error, + "Expected 2 children for run-end encoded array but found %" PRId64, + array_view->n_children); + return EINVAL; + } + struct ArrowArrayView* run_ends_view = array_view->children[0]; + struct ArrowArrayView* values_view = array_view->children[1]; + int64_t max_length; + switch (run_ends_view->storage_type) { + case NANOARROW_TYPE_INT16: + max_length = INT16_MAX; + break; + case NANOARROW_TYPE_INT32: + max_length = INT32_MAX; + break; + case NANOARROW_TYPE_INT64: + max_length = INT64_MAX; + break; + default: + ArrowErrorSet( + error, + "Run-end encoded array only supports INT16, INT32 or INT64 run-ends " + "but found 
run-ends type %s", + ArrowTypeString(run_ends_view->storage_type)); + return EINVAL; + } + + // There is already a check above that offset_plus_length < INT64_MAX + if (offset_plus_length > max_length) { + ArrowErrorSet(error, + "Offset + length of a run-end encoded array must fit in a value" + " of the run end type %s but is %" PRId64 " + %" PRId64, + ArrowTypeString(run_ends_view->storage_type), array_view->offset, + array_view->length); + return EINVAL; + } + + if (run_ends_view->length > values_view->length) { + ArrowErrorSet(error, + "Length of run_ends is greater than the length of values: %" PRId64 + " > %" PRId64, + run_ends_view->length, values_view->length); + return EINVAL; + } + + if (run_ends_view->length == 0 && values_view->length != 0) { + ArrowErrorSet(error, + "Run-end encoded array has zero length %" PRId64 + ", but values array has " + "non-zero length", + values_view->length); + return EINVAL; + } + + if (run_ends_view->null_count != 0) { + ArrowErrorSet(error, "Null count must be 0 for run ends array, but is %" PRId64, + run_ends_view->null_count); + return EINVAL; + } + break; + } + default: break; } @@ -2935,24 +3179,30 @@ static int ArrowArrayViewValidateDefault(struct ArrowArrayView* array_view, case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_BINARY: if (array_view->buffer_views[1].size_bytes != 0) { - first_offset = array_view->buffer_views[1].data.as_int32[0]; + first_offset = array_view->buffer_views[1].data.as_int32[array_view->offset]; if (first_offset < 0) { - ArrowErrorSet(error, "Expected first offset >= 0 but found %ld", - (long)first_offset); + ArrowErrorSet(error, "Expected first offset >= 0 but found %" PRId64, + first_offset); return EINVAL; } last_offset = array_view->buffer_views[1].data.as_int32[offset_plus_length]; + if (last_offset < 0) { + ArrowErrorSet(error, "Expected last offset >= 0 but found %" PRId64, + last_offset); + return EINVAL; + } // If the data buffer size is unknown, assign it; otherwise, check it if (array_view->buffer_views[2].size_bytes == -1) { array_view->buffer_views[2].size_bytes = last_offset; } else if (array_view->buffer_views[2].size_bytes < last_offset) { ArrowErrorSet(error, - "Expected %s array buffer 2 to have size >= %ld bytes but found " - "buffer with %ld bytes", - ArrowTypeString(array_view->storage_type), (long)last_offset, - (long)array_view->buffer_views[2].size_bytes); + "Expected %s array buffer 2 to have size >= %" PRId64 + " bytes but found " + "buffer with %" PRId64 " bytes", + ArrowTypeString(array_view->storage_type), last_offset, + array_view->buffer_views[2].size_bytes); return EINVAL; } } else if (array_view->buffer_views[2].size_bytes == -1) { @@ -2965,24 +3215,30 @@ static int ArrowArrayViewValidateDefault(struct ArrowArrayView* array_view, case NANOARROW_TYPE_LARGE_STRING: case NANOARROW_TYPE_LARGE_BINARY: if (array_view->buffer_views[1].size_bytes != 0) { - first_offset = array_view->buffer_views[1].data.as_int64[0]; + first_offset = array_view->buffer_views[1].data.as_int64[array_view->offset]; if (first_offset < 0) { - ArrowErrorSet(error, "Expected first offset >= 0 but found %ld", - (long)first_offset); + ArrowErrorSet(error, "Expected first offset >= 0 but found %" PRId64, + first_offset); return EINVAL; } last_offset = array_view->buffer_views[1].data.as_int64[offset_plus_length]; + if (last_offset < 0) { + ArrowErrorSet(error, "Expected last offset >= 0 but found %" PRId64, + last_offset); + return EINVAL; + } // If the data buffer size is unknown, assign it; otherwise, check it if 
(array_view->buffer_views[2].size_bytes == -1) { array_view->buffer_views[2].size_bytes = last_offset; } else if (array_view->buffer_views[2].size_bytes < last_offset) { ArrowErrorSet(error, - "Expected %s array buffer 2 to have size >= %ld bytes but found " - "buffer with %ld bytes", - ArrowTypeString(array_view->storage_type), (long)last_offset, - (long)array_view->buffer_views[2].size_bytes); + "Expected %s array buffer 2 to have size >= %" PRId64 + " bytes but found " + "buffer with %" PRId64 " bytes", + ArrowTypeString(array_view->storage_type), last_offset, + array_view->buffer_views[2].size_bytes); return EINVAL; } } else if (array_view->buffer_views[2].size_bytes == -1) { @@ -2995,12 +3251,11 @@ static int ArrowArrayViewValidateDefault(struct ArrowArrayView* array_view, case NANOARROW_TYPE_STRUCT: for (int64_t i = 0; i < array_view->n_children; i++) { if (array_view->children[i]->length < offset_plus_length) { - ArrowErrorSet( - error, - "Expected struct child %d to have length >= %ld but found child with " - "length %ld", - (int)(i + 1), (long)offset_plus_length, - (long)array_view->children[i]->length); + ArrowErrorSet(error, + "Expected struct child %" PRId64 " to have length >= %" PRId64 + " but found child with " + "length %" PRId64, + i + 1, offset_plus_length, array_view->children[i]->length); return EINVAL; } } @@ -3009,21 +3264,27 @@ static int ArrowArrayViewValidateDefault(struct ArrowArrayView* array_view, case NANOARROW_TYPE_LIST: case NANOARROW_TYPE_MAP: if (array_view->buffer_views[1].size_bytes != 0) { - first_offset = array_view->buffer_views[1].data.as_int32[0]; + first_offset = array_view->buffer_views[1].data.as_int32[array_view->offset]; if (first_offset < 0) { - ArrowErrorSet(error, "Expected first offset >= 0 but found %ld", - (long)first_offset); + ArrowErrorSet(error, "Expected first offset >= 0 but found %" PRId64, + first_offset); return EINVAL; } last_offset = array_view->buffer_views[1].data.as_int32[offset_plus_length]; + if (last_offset < 0) { + ArrowErrorSet(error, "Expected last offset >= 0 but found %" PRId64, + last_offset); + return EINVAL; + } + if (array_view->children[0]->length < last_offset) { - ArrowErrorSet( - error, - "Expected child of %s array to have length >= %ld but found array with " - "length %ld", - ArrowTypeString(array_view->storage_type), (long)last_offset, - (long)array_view->children[0]->length); + ArrowErrorSet(error, + "Expected child of %s array to have length >= %" PRId64 + " but found array with " + "length %" PRId64, + ArrowTypeString(array_view->storage_type), last_offset, + array_view->children[0]->length); return EINVAL; } } @@ -3031,24 +3292,58 @@ static int ArrowArrayViewValidateDefault(struct ArrowArrayView* array_view, case NANOARROW_TYPE_LARGE_LIST: if (array_view->buffer_views[1].size_bytes != 0) { - first_offset = array_view->buffer_views[1].data.as_int64[0]; + first_offset = array_view->buffer_views[1].data.as_int64[array_view->offset]; if (first_offset < 0) { - ArrowErrorSet(error, "Expected first offset >= 0 but found %ld", - (long)first_offset); + ArrowErrorSet(error, "Expected first offset >= 0 but found %" PRId64, + first_offset); return EINVAL; } last_offset = array_view->buffer_views[1].data.as_int64[offset_plus_length]; + if (last_offset < 0) { + ArrowErrorSet(error, "Expected last offset >= 0 but found %" PRId64, + last_offset); + return EINVAL; + } + if (array_view->children[0]->length < last_offset) { - ArrowErrorSet( - error, - "Expected child of large list array to have length >= %ld but found 
array " - "with length %ld", - (long)last_offset, (long)array_view->children[0]->length); + ArrowErrorSet(error, + "Expected child of large list array to have length >= %" PRId64 + " but found array " + "with length %" PRId64, + last_offset, array_view->children[0]->length); return EINVAL; } } break; + + case NANOARROW_TYPE_RUN_END_ENCODED: { + struct ArrowArrayView* run_ends_view = array_view->children[0]; + if (run_ends_view->length == 0) { + break; + } + + int64_t first_run_end = ArrowArrayViewGetIntUnsafe(run_ends_view, 0); + if (first_run_end < 1) { + ArrowErrorSet( + error, + "All run ends must be greater than 0 but the first run end is %" PRId64, + first_run_end); + return EINVAL; + } + + // offset + length < INT64_MAX is checked in ArrowArrayViewValidateMinimal() + int64_t last_run_end = + ArrowArrayViewGetIntUnsafe(run_ends_view, run_ends_view->length - 1); + if (last_run_end < offset_plus_length) { + ArrowErrorSet(error, + "Last run end is %" PRId64 " but it should be >= (%" PRId64 + " + %" PRId64 ")", + last_run_end, array_view->offset, array_view->length); + return EINVAL; + } + break; + } default: break; } @@ -3101,7 +3396,7 @@ static int ArrowAssertIncreasingInt32(struct ArrowBufferView view, for (int64_t i = 1; i < view.size_bytes / (int64_t)sizeof(int32_t); i++) { if (view.data.as_int32[i] < view.data.as_int32[i - 1]) { - ArrowErrorSet(error, "[%ld] Expected element size >= 0", (long)i); + ArrowErrorSet(error, "[%" PRId64 "] Expected element size >= 0", i); return EINVAL; } } @@ -3117,7 +3412,7 @@ static int ArrowAssertIncreasingInt64(struct ArrowBufferView view, for (int64_t i = 1; i < view.size_bytes / (int64_t)sizeof(int64_t); i++) { if (view.data.as_int64[i] < view.data.as_int64[i - 1]) { - ArrowErrorSet(error, "[%ld] Expected element size >= 0", (long)i); + ArrowErrorSet(error, "[%" PRId64 "] Expected element size >= 0", i); return EINVAL; } } @@ -3130,8 +3425,9 @@ static int ArrowAssertRangeInt8(struct ArrowBufferView view, int8_t min_value, for (int64_t i = 0; i < view.size_bytes; i++) { if (view.data.as_int8[i] < min_value || view.data.as_int8[i] > max_value) { ArrowErrorSet(error, - "[%ld] Expected buffer value between %d and %d but found value %d", - (long)i, (int)min_value, (int)max_value, (int)view.data.as_int8[i]); + "[%" PRId64 "] Expected buffer value between %" PRId8 " and %" PRId8 + " but found value %" PRId8, + i, min_value, max_value, view.data.as_int8[i]); return EINVAL; } } @@ -3151,8 +3447,8 @@ static int ArrowAssertInt8In(struct ArrowBufferView view, const int8_t* values, } if (!item_found) { - ArrowErrorSet(error, "[%ld] Unexpected buffer value %d", (long)i, - (int)view.data.as_int8[i]); + ArrowErrorSet(error, "[%" PRId64 "] Unexpected buffer value %" PRId8, i, + view.data.as_int8[i]); return EINVAL; } } @@ -3164,13 +3460,24 @@ static int ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, struct ArrowError* error) { for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { switch (array_view->layout.buffer_type[i]) { + // Only validate the portion of the buffer that is strictly required, + // which includes not validating the offset buffer of a zero-length array. 
case NANOARROW_BUFFER_TYPE_DATA_OFFSET: + if (array_view->length == 0) { + continue; + } if (array_view->layout.element_size_bits[i] == 32) { - NANOARROW_RETURN_NOT_OK( - ArrowAssertIncreasingInt32(array_view->buffer_views[i], error)); + struct ArrowBufferView sliced_offsets; + sliced_offsets.data.as_int32 = + array_view->buffer_views[i].data.as_int32 + array_view->offset; + sliced_offsets.size_bytes = (array_view->length + 1) * sizeof(int32_t); + NANOARROW_RETURN_NOT_OK(ArrowAssertIncreasingInt32(sliced_offsets, error)); } else { - NANOARROW_RETURN_NOT_OK( - ArrowAssertIncreasingInt64(array_view->buffer_views[i], error)); + struct ArrowBufferView sliced_offsets; + sliced_offsets.data.as_int64 = + array_view->buffer_views[i].data.as_int64 + array_view->offset; + sliced_offsets.size_bytes = (array_view->length + 1) * sizeof(int64_t); + NANOARROW_RETURN_NOT_OK(ArrowAssertIncreasingInt64(sliced_offsets, error)); } break; default: @@ -3180,6 +3487,15 @@ static int ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, if (array_view->storage_type == NANOARROW_TYPE_DENSE_UNION || array_view->storage_type == NANOARROW_TYPE_SPARSE_UNION) { + struct ArrowBufferView sliced_type_ids; + sliced_type_ids.size_bytes = array_view->length * sizeof(int8_t); + if (array_view->length > 0) { + sliced_type_ids.data.as_int8 = + array_view->buffer_views[0].data.as_int8 + array_view->offset; + } else { + sliced_type_ids.data.as_int8 = NULL; + } + if (array_view->union_type_id_map == NULL) { // If the union_type_id map is NULL (e.g., when using ArrowArrayInitFromType() + // ArrowArrayAllocateChildren() + ArrowArrayFinishBuilding()), we don't have enough @@ -3191,9 +3507,9 @@ static int ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, array_view->union_type_id_map, array_view->n_children, array_view->n_children)) { NANOARROW_RETURN_NOT_OK(ArrowAssertRangeInt8( - array_view->buffer_views[0], 0, (int8_t)(array_view->n_children - 1), error)); + sliced_type_ids, 0, (int8_t)(array_view->n_children - 1), error)); } else { - NANOARROW_RETURN_NOT_OK(ArrowAssertInt8In(array_view->buffer_views[0], + NANOARROW_RETURN_NOT_OK(ArrowAssertInt8In(sliced_type_ids, array_view->union_type_id_map + 128, array_view->n_children, error)); } @@ -3207,16 +3523,37 @@ static int ArrowArrayViewValidateFull(struct ArrowArrayView* array_view, int64_t offset = ArrowArrayViewUnionChildOffset(array_view, i); int64_t child_length = array_view->children[child_id]->length; if (offset < 0 || offset > child_length) { - ArrowErrorSet( - error, - "[%ld] Expected union offset for child id %d to be between 0 and %ld but " - "found offset value %ld", - (long)i, (int)child_id, (long)child_length, (long)offset); + ArrowErrorSet(error, + "[%" PRId64 "] Expected union offset for child id %" PRId8 + " to be between 0 and %" PRId64 + " but " + "found offset value %" PRId64, + i, child_id, child_length, offset); return EINVAL; } } } + if (array_view->storage_type == NANOARROW_TYPE_RUN_END_ENCODED) { + struct ArrowArrayView* run_ends_view = array_view->children[0]; + if (run_ends_view->length > 0) { + int64_t last_run_end = ArrowArrayViewGetIntUnsafe(run_ends_view, 0); + for (int64_t i = 1; i < run_ends_view->length; i++) { + const int64_t run_end = ArrowArrayViewGetIntUnsafe(run_ends_view, i); + if (run_end <= last_run_end) { + ArrowErrorSet( + error, + "Every run end must be strictly greater than the previous run end, " + "but run_ends[%" PRId64 " is %" PRId64 " and run_ends[%" PRId64 + "] is %" PRId64, + i, run_end, i - 1, last_run_end); + 
return EINVAL; + } + last_run_end = run_end; + } + } + } + // Recurse for children for (int64_t i = 0; i < array_view->n_children; i++) { NANOARROW_RETURN_NOT_OK(ArrowArrayViewValidateFull(array_view->children[i], error)); @@ -3249,6 +3586,136 @@ ArrowErrorCode ArrowArrayViewValidate(struct ArrowArrayView* array_view, ArrowErrorSet(error, "validation_level not recognized"); return EINVAL; } + +struct ArrowComparisonInternalState { + enum ArrowCompareLevel level; + int is_equal; + struct ArrowError* reason; +}; + +NANOARROW_CHECK_PRINTF_ATTRIBUTE static void ArrowComparePrependPath( + struct ArrowError* out, const char* fmt, ...) { + if (out == NULL) { + return; + } + + char prefix[128]; + prefix[0] = '\0'; + va_list args; + va_start(args, fmt); + int prefix_len = vsnprintf(prefix, sizeof(prefix), fmt, args); + va_end(args); + + if (prefix_len <= 0) { + return; + } + + size_t out_len = strlen(out->message); + size_t out_len_to_move = sizeof(struct ArrowError) - prefix_len - 1; + if (out_len_to_move > out_len) { + out_len_to_move = out_len; + } + + memmove(out->message + prefix_len, out->message, out_len_to_move); + memcpy(out->message, prefix, prefix_len); + out->message[out_len + prefix_len] = '\0'; +} + +#define SET_NOT_EQUAL_AND_RETURN_IF_IMPL(cond_, state_, reason_) \ + do { \ + if (cond_) { \ + ArrowErrorSet(state_->reason, ": %s", reason_); \ + state_->is_equal = 0; \ + return; \ + } \ + } while (0) + +#define SET_NOT_EQUAL_AND_RETURN_IF(condition_, state_) \ + SET_NOT_EQUAL_AND_RETURN_IF_IMPL(condition_, state_, #condition_) + +static void ArrowArrayViewCompareBuffer(const struct ArrowArrayView* actual, + const struct ArrowArrayView* expected, int i, + struct ArrowComparisonInternalState* state) { + SET_NOT_EQUAL_AND_RETURN_IF( + actual->buffer_views[i].size_bytes != expected->buffer_views[i].size_bytes, state); + + int64_t buffer_size = actual->buffer_views[i].size_bytes; + if (buffer_size > 0) { + SET_NOT_EQUAL_AND_RETURN_IF( + memcmp(actual->buffer_views[i].data.data, expected->buffer_views[i].data.data, + buffer_size) != 0, + state); + } +} + +static void ArrowArrayViewCompareIdentical(const struct ArrowArrayView* actual, + const struct ArrowArrayView* expected, + struct ArrowComparisonInternalState* state) { + SET_NOT_EQUAL_AND_RETURN_IF(actual->storage_type != expected->storage_type, state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->n_children != expected->n_children, state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->dictionary == NULL && expected->dictionary != NULL, + state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->dictionary != NULL && expected->dictionary == NULL, + state); + + SET_NOT_EQUAL_AND_RETURN_IF(actual->length != expected->length, state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->offset != expected->offset, state); + SET_NOT_EQUAL_AND_RETURN_IF(actual->null_count != expected->null_count, state); + + for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { + ArrowArrayViewCompareBuffer(actual, expected, i, state); + if (!state->is_equal) { + ArrowComparePrependPath(state->reason, ".buffers[%d]", i); + return; + } + } + + for (int64_t i = 0; i < actual->n_children; i++) { + ArrowArrayViewCompareIdentical(actual->children[i], expected->children[i], state); + if (!state->is_equal) { + ArrowComparePrependPath(state->reason, ".children[%" PRId64 "]", i); + return; + } + } + + if (actual->dictionary != NULL) { + ArrowArrayViewCompareIdentical(actual->dictionary, expected->dictionary, state); + if (!state->is_equal) { + ArrowComparePrependPath(state->reason, ".dictionary"); + return; + 
} + } +} + +// Top-level entry point to take care of creating, cleaning up, and +// propagating the ArrowComparisonInternalState to the caller +ArrowErrorCode ArrowArrayViewCompare(const struct ArrowArrayView* actual, + const struct ArrowArrayView* expected, + enum ArrowCompareLevel level, int* out, + struct ArrowError* reason) { + struct ArrowComparisonInternalState state; + state.level = level; + state.is_equal = 1; + state.reason = reason; + + switch (level) { + case NANOARROW_COMPARE_IDENTICAL: + ArrowArrayViewCompareIdentical(actual, expected, &state); + break; + default: + return EINVAL; + } + + *out = state.is_equal; + if (!state.is_equal) { + ArrowComparePrependPath(state.reason, "root"); + } + + return NANOARROW_OK; +} + +#undef SET_NOT_EQUAL_AND_RETURN_IF +#undef SET_NOT_EQUAL_AND_RETURN_IF_IMPL // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information diff --git a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.h b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.h index e845d0a..264aad5 100644 --- a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.h +++ b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.h @@ -19,9 +19,9 @@ #define NANOARROW_BUILD_ID_H_INCLUDED #define NANOARROW_VERSION_MAJOR 0 -#define NANOARROW_VERSION_MINOR 5 +#define NANOARROW_VERSION_MINOR 6 #define NANOARROW_VERSION_PATCH 0 -#define NANOARROW_VERSION "0.5.0" +#define NANOARROW_VERSION "0.6.0" #define NANOARROW_VERSION_INT \ (NANOARROW_VERSION_MAJOR * 10000 + NANOARROW_VERSION_MINOR * 100 + \ @@ -181,14 +181,14 @@ struct ArrowArrayStream { NANOARROW_RETURN_NOT_OK((x_ <= max_) ? NANOARROW_OK : EINVAL) #if defined(NANOARROW_DEBUG) -#define _NANOARROW_RETURN_NOT_OK_WITH_ERROR_IMPL(NAME, EXPR, ERROR_PTR_EXPR, EXPR_STR) \ - do { \ - const int NAME = (EXPR); \ - if (NAME) { \ - ArrowErrorSet((ERROR_PTR_EXPR), "%s failed with errno %d\n* %s:%d", EXPR_STR, \ - NAME, __FILE__, __LINE__); \ - return NAME; \ - } \ +#define _NANOARROW_RETURN_NOT_OK_WITH_ERROR_IMPL(NAME, EXPR, ERROR_PTR_EXPR, EXPR_STR) \ + do { \ + const int NAME = (EXPR); \ + if (NAME) { \ + ArrowErrorSet((ERROR_PTR_EXPR), "%s failed with errno %d(%s)\n* %s:%d", EXPR_STR, \ + NAME, strerror(NAME), __FILE__, __LINE__); \ + return NAME; \ + } \ } while (0) #else #define _NANOARROW_RETURN_NOT_OK_WITH_ERROR_IMPL(NAME, EXPR, ERROR_PTR_EXPR, EXPR_STR) \ @@ -482,7 +482,10 @@ enum ArrowType { NANOARROW_TYPE_LARGE_STRING, NANOARROW_TYPE_LARGE_BINARY, NANOARROW_TYPE_LARGE_LIST, - NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO + NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO, + NANOARROW_TYPE_RUN_END_ENCODED, + NANOARROW_TYPE_BINARY_VIEW, + NANOARROW_TYPE_STRING_VIEW }; /// \brief Get a string value of an enum ArrowType value @@ -569,6 +572,12 @@ static inline const char* ArrowTypeString(enum ArrowType type) { return "large_list"; case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: return "interval_month_day_nano"; + case NANOARROW_TYPE_RUN_END_ENCODED: + return "run_end_encoded"; + case NANOARROW_TYPE_BINARY_VIEW: + return "binary_view"; + case NANOARROW_TYPE_STRING_VIEW: + return "string_view"; default: return NULL; } @@ -605,6 +614,17 @@ enum ArrowValidationLevel { NANOARROW_VALIDATION_LEVEL_FULL = 3 }; +/// \brief Comparison level enumerator +/// \ingroup nanoarrow-utils +enum ArrowCompareLevel { + /// \brief Consider arrays equal if buffers contain identical content + /// and have identical offset, null count, and length. 
Note that this is + /// a much stricter check than logical equality, which would take into + /// account potentially different content of null slots, arrays with a + /// non-zero offset, and other considerations. + NANOARROW_COMPARE_IDENTICAL, +}; + /// \brief Get a string value of an enum ArrowTimeUnit value /// \ingroup nanoarrow-utils /// @@ -634,15 +654,13 @@ enum ArrowBufferType { NANOARROW_BUFFER_TYPE_TYPE_ID, NANOARROW_BUFFER_TYPE_UNION_OFFSET, NANOARROW_BUFFER_TYPE_DATA_OFFSET, - NANOARROW_BUFFER_TYPE_DATA + NANOARROW_BUFFER_TYPE_DATA, + NANOARROW_BUFFER_TYPE_VARIADIC_DATA, + NANOARROW_BUFFER_TYPE_VARIADIC_SIZE }; -/// \brief The maximum number of buffers in an ArrowArrayView or ArrowLayout +/// \brief The maximum number of fixed buffers in an ArrowArrayView or ArrowLayout /// \ingroup nanoarrow-array-view -/// -/// All currently supported types have 3 buffers or fewer; however, future types -/// may involve a variable number of buffers (e.g., string view). These buffers -/// will be represented by separate members of the ArrowArrayView or ArrowLayout. #define NANOARROW_MAX_FIXED_BUFFERS 3 /// \brief An non-owning view of a string @@ -689,6 +707,7 @@ union ArrowBufferViewData { const double* as_double; const float* as_float; const char* as_char; + const union ArrowBinaryView* as_binary_view; }; /// \brief An non-owning view of a buffer @@ -826,6 +845,15 @@ struct ArrowArrayView { /// type_id == union_type_id_map[128 + child_index]. This value may be /// NULL in the case where child_id == type_id. int8_t* union_type_id_map; + + /// \brief Number of variadic buffers + int32_t n_variadic_buffers; + + /// \brief Pointers to variadic buffers of binary/string_view arrays + const void** variadic_buffers; + + /// \brief Size of each variadic buffer + int64_t* variadic_buffer_sizes; }; // Used as the private data member for ArrowArrays allocated here and accessed @@ -840,8 +868,8 @@ struct ArrowArrayPrivateData { // The array of pointers to buffers. This must be updated after a sequence // of appends to synchronize its values with the actual buffer addresses - // (which may have ben reallocated uring that time) - const void* buffer_data[NANOARROW_MAX_FIXED_BUFFERS]; + // (which may have been reallocated during that time) + const void** buffer_data; // The storage data type, or NANOARROW_TYPE_UNINITIALIZED if unknown enum ArrowType storage_type; @@ -853,6 +881,15 @@ struct ArrowArrayPrivateData { // In the future this could be replaced with a type id<->child mapping // to support constructing unions in append mode where type_id != child_index int8_t union_type_id_is_child_index; + + // Number of variadic buffers for binary view types + int32_t n_variadic_buffers; + + // Variadic buffers for binary view types + struct ArrowBuffer* variadic_buffers; + + // Size of each variadic buffer in bytes + int64_t* variadic_buffer_sizes; }; /// \brief A representation of an interval. 
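
With the NANOARROW_BUFFER_TYPE_VARIADIC_DATA/VARIADIC_SIZE buffer types and the variadic-buffer members added to ArrowArrayView above, the number of buffers behind a string view or binary view array is no longer bounded by NANOARROW_MAX_FIXED_BUFFERS: it is two fixed buffers, then n_variadic_buffers data buffers, then one trailing buffer of sizes. A small sketch of walking them with the accessors declared further down in this header; this is not from the vendored sources, the helper name PrintBufferSizes and the "nanoarrow.h" include path are illustrative, and the view is assumed to have been populated via ArrowArrayViewSetArray().

// Sketch only, not part of the vendored sources: enumerate every buffer
// (fixed and variadic) referenced by an already-populated ArrowArrayView.
#include <inttypes.h>
#include <stdio.h>
#include "nanoarrow.h"

static void PrintBufferSizes(struct ArrowArrayView* view) {
  int64_t n_buffers = ArrowArrayViewGetNumBuffers(view);
  for (int64_t i = 0; i < n_buffers; i++) {
    struct ArrowBufferView buffer = ArrowArrayViewGetBufferView(view, i);
    printf("buffer %" PRId64 ": buffer_type=%d, size=%" PRId64 " bytes\n", i,
           (int)ArrowArrayViewGetBufferType(view, i), buffer.size_bytes);
  }
}
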
@@ -911,7 +948,7 @@ static inline void ArrowDecimalInit(struct ArrowDecimal* decimal, int32_t bitwid memset(decimal->words, 0, sizeof(decimal->words)); decimal->precision = precision; decimal->scale = scale; - decimal->n_words = bitwidth / 8 / sizeof(uint64_t); + decimal->n_words = (int)(bitwidth / 8 / sizeof(uint64_t)); if (_ArrowIsLittleEndian()) { decimal->low_word_index = 0; @@ -1052,6 +1089,8 @@ static inline void ArrowDecimalSetBytes(struct ArrowDecimal* decimal, NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowSchemaSetTypeFixedSize) #define ArrowSchemaSetTypeDecimal \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowSchemaSetTypeDecimal) +#define ArrowSchemaSetTypeRunEndEncoded \ + NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowSchemaSetTypeRunEndEncoded) #define ArrowSchemaSetTypeDateTime \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowSchemaSetTypeDateTime) #define ArrowSchemaSetTypeUnion \ @@ -1118,6 +1157,7 @@ static inline void ArrowDecimalSetBytes(struct ArrowDecimal* decimal, NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewSetArrayMinimal) #define ArrowArrayViewValidate \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewValidate) +#define ArrowArrayViewCompare NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewCompare) #define ArrowArrayViewReset NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowArrayViewReset) #define ArrowBasicArrayStreamInit \ NANOARROW_SYMBOL(NANOARROW_NAMESPACE, ArrowBasicArrayStreamInit) @@ -1281,6 +1321,12 @@ ArrowErrorCode ArrowDecimalSetDigits(struct ArrowDecimal* decimal, ArrowErrorCode ArrowDecimalAppendDigitsToBuffer(const struct ArrowDecimal* decimal, struct ArrowBuffer* buffer); +/// \brief Get the half float value of a float +static inline uint16_t ArrowFloatToHalfFloat(float value); + +/// \brief Get the float value of a half float +static inline float ArrowHalfFloatToFloat(uint16_t value); + /// \brief Resolve a chunk index from increasing int64_t offsets /// /// Given a buffer of increasing int64_t offsets that begin with 0 (e.g., offset buffer @@ -1358,6 +1404,17 @@ ArrowErrorCode ArrowSchemaSetTypeDecimal(struct ArrowSchema* schema, enum ArrowT int32_t decimal_precision, int32_t decimal_scale); +/// \brief Set the format field of a run-end encoded schema +/// +/// Returns EINVAL for run_end_type that is not +/// NANOARROW_TYPE_INT16, NANOARROW_TYPE_INT32 or NANOARROW_TYPE_INT64. +/// Schema must have been initialized using ArrowSchemaInit() or ArrowSchemaDeepCopy(). +/// The caller must call `ArrowSchemaSetTypeXXX(schema->children[1])` to +/// set the value type. Note that when building arrays using the `ArrowArrayAppendXXX()` +/// functions, the run-end encoded array's logical length must be updated manually. +ArrowErrorCode ArrowSchemaSetTypeRunEndEncoded(struct ArrowSchema* schema, + enum ArrowType run_end_type); + /// \brief Set the format field of a time, timestamp, or duration schema /// /// Returns EINVAL for type that is not @@ -2025,6 +2082,48 @@ ArrowErrorCode ArrowArrayViewSetArrayMinimal(struct ArrowArrayView* array_view, const struct ArrowArray* array, struct ArrowError* error); +/// \brief Get the number of buffers +/// +/// The number of buffers referred to by this ArrowArrayView. In may cases this can also +/// be calculated from the ArrowLayout member of the ArrowArrayView or ArrowSchemaView; +/// however, for binary view and string view types, the number of total buffers depends on +/// the number of variadic buffers. 
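
ArrowSchemaSetTypeRunEndEncoded(), documented above, configures the format string and the run-ends child but leaves the values child (children[1]) to the caller. A minimal sketch of that documented usage, not from the vendored sources; the include path "nanoarrow.h" and the helper name MakeReeStringSchema are illustrative assumptions.

// Sketch only, not part of the vendored sources: build a run-end encoded
// schema with int32 run ends and string values.
#include "nanoarrow.h"

static ArrowErrorCode MakeReeStringSchema(struct ArrowSchema* schema) {
  ArrowSchemaInit(schema);
  NANOARROW_RETURN_NOT_OK(
      ArrowSchemaSetTypeRunEndEncoded(schema, NANOARROW_TYPE_INT32));
  // Per the doc comment, the caller sets the value type on children[1].
  NANOARROW_RETURN_NOT_OK(
      ArrowSchemaSetType(schema->children[1], NANOARROW_TYPE_STRING));
  return NANOARROW_OK;
}

As the doc comment notes, arrays built against such a schema with the ArrowArrayAppendXXX() helpers still need their logical length updated manually.
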
+static inline int64_t ArrowArrayViewGetNumBuffers(struct ArrowArrayView* array_view); + +/// \brief Get a view of a specific buffer from an ArrowArrayView +/// +/// This is the ArrowArrayView equivalent of ArrowArray::buffers[i] that includes +/// size information (if known). +static inline struct ArrowBufferView ArrowArrayViewGetBufferView( + struct ArrowArrayView* array_view, int64_t i); + +/// \brief Get the function of a specific buffer in an ArrowArrayView +/// +/// In may cases this can also be obtained from the ArrowLayout member of the +/// ArrowArrayView or ArrowSchemaView; however, for binary view and string view types, +/// the function of each buffer may be different between two arrays of the same type +/// depending on the number of variadic buffers. +static inline enum ArrowBufferType ArrowArrayViewGetBufferType( + struct ArrowArrayView* array_view, int64_t i); + +/// \brief Get the data type of a specific buffer in an ArrowArrayView +/// +/// In may cases this can also be obtained from the ArrowLayout member of the +/// ArrowArrayView or ArrowSchemaView; however, for binary view and string view types, +/// the data type of each buffer may be different between two arrays of the same type +/// depending on the number of variadic buffers. +static inline enum ArrowType ArrowArrayViewGetBufferDataType( + struct ArrowArrayView* array_view, int64_t i); + +/// \brief Get the element size (in bits) of a specific buffer in an ArrowArrayView +/// +/// In may cases this can also be obtained from the ArrowLayout member of the +/// ArrowArrayView or ArrowSchemaView; however, for binary view and string view types, +/// the element width of each buffer may be different between two arrays of the same type +/// depending on the number of variadic buffers. +static inline int64_t ArrowArrayViewGetBufferElementSizeBits( + struct ArrowArrayView* array_view, int64_t i); + /// \brief Performs checks on the content of an ArrowArrayView /// /// If using ArrowArrayViewSetArray() to back array_view with an ArrowArray, @@ -2037,6 +2136,19 @@ ArrowErrorCode ArrowArrayViewValidate(struct ArrowArrayView* array_view, enum ArrowValidationLevel validation_level, struct ArrowError* error); +/// \brief Compare two ArrowArrayView objects for equality +/// +/// Given two ArrowArrayView instances, place either 0 (not equal) and +/// 1 (equal) at the address pointed to by out. If the comparison determines +/// that actual and expected are not equal, a reason will be communicated via +/// error if error is non-NULL. +/// +/// Returns NANOARROW_OK if the comparison completed successfully. 
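
A short sketch of the ArrowArrayViewCompare() contract described above (its declaration follows immediately below); this is not from the vendored sources, and the helper name CheckIdentical plus the "nanoarrow.h" include path are illustrative assumptions.

// Sketch only, not part of the vendored sources: compare two populated
// ArrowArrayViews and surface the prefixed path/reason on mismatch.
#include <errno.h>
#include <stdio.h>
#include "nanoarrow.h"

static ArrowErrorCode CheckIdentical(const struct ArrowArrayView* actual,
                                     const struct ArrowArrayView* expected) {
  struct ArrowError reason;
  reason.message[0] = '\0';
  int is_equal = 0;
  NANOARROW_RETURN_NOT_OK(ArrowArrayViewCompare(
      actual, expected, NANOARROW_COMPARE_IDENTICAL, &is_equal, &reason));
  if (!is_equal) {
    // reason.message reads like "root.children[0].buffers[1]: <condition>"
    fprintf(stderr, "arrays are not identical: %s\n", reason.message);
    return EINVAL;
  }
  return NANOARROW_OK;
}
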
+ArrowErrorCode ArrowArrayViewCompare(const struct ArrowArrayView* actual, + const struct ArrowArrayView* expected, + enum ArrowCompareLevel level, int* out, + struct ArrowError* reason); + /// \brief Reset the contents of an ArrowArrayView and frees resources void ArrowArrayViewReset(struct ArrowArrayView* array_view); @@ -2044,6 +2156,10 @@ void ArrowArrayViewReset(struct ArrowArrayView* array_view); static inline int8_t ArrowArrayViewIsNull(const struct ArrowArrayView* array_view, int64_t i); +/// \brief Compute null count for an ArrowArrayView +static inline int64_t ArrowArrayViewComputeNullCount( + const struct ArrowArrayView* array_view); + /// \brief Get the type id of a union array element static inline int8_t ArrowArrayViewUnionTypeId(const struct ArrowArrayView* array_view, int64_t i); @@ -2233,6 +2349,57 @@ static inline int64_t _ArrowGrowByFactor(int64_t current_capacity, int64_t new_c } } +// float to half float conversion, adapted from Arrow Go +// https://github.com/apache/arrow/blob/main/go/arrow/float16/float16.go +static inline uint16_t ArrowFloatToHalfFloat(float value) { + union { + float f; + uint32_t b; + } u; + u.f = value; + + uint16_t sn = (uint16_t)((u.b >> 31) & 0x1); + uint16_t exp = (u.b >> 23) & 0xff; + int16_t res = (int16_t)(exp - 127 + 15); + uint16_t fc = (uint16_t)(u.b >> 13) & 0x3ff; + + if (exp == 0) { + res = 0; + } else if (exp == 0xff) { + res = 0x1f; + } else if (res > 0x1e) { + res = 0x1f; + fc = 0; + } else if (res < 0x01) { + res = 0; + fc = 0; + } + + return (uint16_t)((sn << 15) | (uint16_t)(res << 10) | fc); +} + +// half float to float conversion, adapted from Arrow Go +// https://github.com/apache/arrow/blob/main/go/arrow/float16/float16.go +static inline float ArrowHalfFloatToFloat(uint16_t value) { + uint32_t sn = (uint32_t)((value >> 15) & 0x1); + uint32_t exp = (value >> 10) & 0x1f; + uint32_t res = exp + 127 - 15; + uint32_t fc = value & 0x3ff; + + if (exp == 0) { + res = 0; + } else if (exp == 0x1f) { + res = 0xff; + } + + union { + float f; + uint32_t b; + } u; + u.b = (uint32_t)(sn << 31) | (uint32_t)(res << 23) | (uint32_t)(fc << 13); + return u.f; +} + static inline void ArrowBufferInit(struct ArrowBuffer* buffer) { buffer->data = NULL; buffer->size_bytes = 0; @@ -2316,6 +2483,7 @@ static inline ArrowErrorCode ArrowBufferReserve(struct ArrowBuffer* buffer, static inline void ArrowBufferAppendUnsafe(struct ArrowBuffer* buffer, const void* data, int64_t size_bytes) { if (size_bytes > 0) { + NANOARROW_DCHECK(buffer->data != NULL); memcpy(buffer->data + buffer->size_bytes, data, size_bytes); buffer->size_bytes += size_bytes; } @@ -2391,10 +2559,16 @@ static inline ArrowErrorCode ArrowBufferAppendBufferView(struct ArrowBuffer* buf static inline ArrowErrorCode ArrowBufferAppendFill(struct ArrowBuffer* buffer, uint8_t value, int64_t size_bytes) { + if (size_bytes == 0) { + return NANOARROW_OK; + } + NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, size_bytes)); + NANOARROW_DCHECK(buffer->data != NULL); // To help clang-tidy memset(buffer->data + buffer->size_bytes, value, size_bytes); buffer->size_bytes += size_bytes; + return NANOARROW_OK; } @@ -2511,6 +2685,8 @@ static inline void ArrowBitsUnpackInt32(const uint8_t* bits, int64_t start_offse return; } + NANOARROW_DCHECK(bits != NULL && out != NULL); + const int64_t i_begin = start_offset; const int64_t i_end = start_offset + length; const int64_t i_last_valid = i_end - 1; @@ -2553,12 +2729,18 @@ static inline void ArrowBitClear(uint8_t* bits, int64_t i) { } static inline void 
ArrowBitSetTo(uint8_t* bits, int64_t i, uint8_t bit_is_set) { - bits[i / 8] ^= - ((uint8_t)(-((uint8_t)(bit_is_set != 0)) ^ bits[i / 8])) & _ArrowkBitmask[i % 8]; + bits[i / 8] ^= (uint8_t)(((uint8_t)(-((uint8_t)(bit_is_set != 0)) ^ bits[i / 8])) & + _ArrowkBitmask[i % 8]); } static inline void ArrowBitsSetTo(uint8_t* bits, int64_t start_offset, int64_t length, uint8_t bits_are_set) { + if (length == 0) { + return; + } + + NANOARROW_DCHECK(bits != NULL); + const int64_t i_begin = start_offset; const int64_t i_end = start_offset + length; const uint8_t fill_byte = (uint8_t)(-bits_are_set); @@ -2602,6 +2784,8 @@ static inline int64_t ArrowBitCountSet(const uint8_t* bits, int64_t start_offset return 0; } + NANOARROW_DCHECK(bits != NULL); + const int64_t i_begin = start_offset; const int64_t i_end = start_offset + length; const int64_t i_last_valid = i_end - 1; @@ -3095,6 +3279,8 @@ static inline ArrowErrorCode _ArrowArrayAppendEmptyInternal(struct ArrowArray* a switch (private_data->layout.buffer_type[i]) { case NANOARROW_BUFFER_TYPE_NONE: + case NANOARROW_BUFFER_TYPE_VARIADIC_DATA: + case NANOARROW_BUFFER_TYPE_VARIADIC_SIZE: case NANOARROW_BUFFER_TYPE_VALIDITY: continue; case NANOARROW_BUFFER_TYPE_DATA_OFFSET: @@ -3173,6 +3359,10 @@ static inline ArrowErrorCode ArrowArrayAppendInt(struct ArrowArray* array, case NANOARROW_TYPE_FLOAT: NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, (float)value)); break; + case NANOARROW_TYPE_HALF_FLOAT: + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppendUInt16(data_buffer, ArrowFloatToHalfFloat((float)value))); + break; case NANOARROW_TYPE_BOOL: NANOARROW_RETURN_NOT_OK(_ArrowArrayAppendBits(array, 1, value != 0, 1)); break; @@ -3223,6 +3413,10 @@ static inline ArrowErrorCode ArrowArrayAppendUInt(struct ArrowArray* array, case NANOARROW_TYPE_FLOAT: NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, (float)value)); break; + case NANOARROW_TYPE_HALF_FLOAT: + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppendUInt16(data_buffer, ArrowFloatToHalfFloat((float)value))); + break; case NANOARROW_TYPE_BOOL: NANOARROW_RETURN_NOT_OK(_ArrowArrayAppendBits(array, 1, value != 0, 1)); break; @@ -3252,6 +3446,10 @@ static inline ArrowErrorCode ArrowArrayAppendDouble(struct ArrowArray* array, case NANOARROW_TYPE_FLOAT: NANOARROW_RETURN_NOT_OK(ArrowBufferAppendFloat(data_buffer, (float)value)); break; + case NANOARROW_TYPE_HALF_FLOAT: + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppendUInt16(data_buffer, ArrowFloatToHalfFloat((float)value))); + break; default: return EINVAL; } @@ -3264,52 +3462,151 @@ static inline ArrowErrorCode ArrowArrayAppendDouble(struct ArrowArray* array, return NANOARROW_OK; } +// Binary views only have two fixed buffers, but be aware that they must also +// always have more 1 buffer to store variadic buffer sizes (even if there are none) +#define NANOARROW_BINARY_VIEW_FIXED_BUFFERS 2 +#define NANOARROW_BINARY_VIEW_INLINE_SIZE 12 +#define NANOARROW_BINARY_VIEW_PREFIX_SIZE 4 +#define NANOARROW_BINARY_VIEW_BLOCK_SIZE (32 << 10) // 32KB + +// The Arrow C++ implementation uses anonymous structs as members +// of the ArrowBinaryView. 
For Cython support in this library, we define +// those structs outside of the ArrowBinaryView +struct ArrowBinaryViewInlined { + int32_t size; + uint8_t data[NANOARROW_BINARY_VIEW_INLINE_SIZE]; +}; + +struct ArrowBinaryViewRef { + int32_t size; + uint8_t prefix[NANOARROW_BINARY_VIEW_PREFIX_SIZE]; + int32_t buffer_index; + int32_t offset; +}; + +union ArrowBinaryView { + struct ArrowBinaryViewInlined inlined; + struct ArrowBinaryViewRef ref; + int64_t alignment_dummy; +}; + +static inline int32_t ArrowArrayVariadicBufferCount(struct ArrowArray* array) { + struct ArrowArrayPrivateData* private_data = + (struct ArrowArrayPrivateData*)array->private_data; + + return private_data->n_variadic_buffers; +} + +static inline ArrowErrorCode ArrowArrayAddVariadicBuffers(struct ArrowArray* array, + int32_t nbuffers) { + const int32_t n_current_bufs = ArrowArrayVariadicBufferCount(array); + const int32_t nvariadic_bufs_needed = n_current_bufs + nbuffers; + + struct ArrowArrayPrivateData* private_data = + (struct ArrowArrayPrivateData*)array->private_data; + + private_data->variadic_buffers = (struct ArrowBuffer*)ArrowRealloc( + private_data->variadic_buffers, sizeof(struct ArrowBuffer) * nvariadic_bufs_needed); + if (private_data->variadic_buffers == NULL) { + return ENOMEM; + } + private_data->variadic_buffer_sizes = (int64_t*)ArrowRealloc( + private_data->variadic_buffer_sizes, sizeof(int64_t) * nvariadic_bufs_needed); + if (private_data->variadic_buffer_sizes == NULL) { + return ENOMEM; + } + + for (int32_t i = n_current_bufs; i < nvariadic_bufs_needed; i++) { + ArrowBufferInit(&private_data->variadic_buffers[i]); + private_data->variadic_buffer_sizes[i] = 0; + } + private_data->n_variadic_buffers = nvariadic_bufs_needed; + array->n_buffers = NANOARROW_BINARY_VIEW_FIXED_BUFFERS + 1 + nvariadic_bufs_needed; + + return NANOARROW_OK; +} + static inline ArrowErrorCode ArrowArrayAppendBytes(struct ArrowArray* array, struct ArrowBufferView value) { struct ArrowArrayPrivateData* private_data = (struct ArrowArrayPrivateData*)array->private_data; - struct ArrowBuffer* offset_buffer = ArrowArrayBuffer(array, 1); - struct ArrowBuffer* data_buffer = ArrowArrayBuffer( - array, 1 + (private_data->storage_type != NANOARROW_TYPE_FIXED_SIZE_BINARY)); - int32_t offset; - int64_t large_offset; - int64_t fixed_size_bytes = private_data->layout.element_size_bits[1] / 8; + if (private_data->storage_type == NANOARROW_TYPE_STRING_VIEW || + private_data->storage_type == NANOARROW_TYPE_BINARY_VIEW) { + struct ArrowBuffer* data_buffer = ArrowArrayBuffer(array, 1); + union ArrowBinaryView bvt; + bvt.inlined.size = (int32_t)value.size_bytes; - switch (private_data->storage_type) { - case NANOARROW_TYPE_STRING: - case NANOARROW_TYPE_BINARY: - offset = ((int32_t*)offset_buffer->data)[array->length]; - if ((((int64_t)offset) + value.size_bytes) > INT32_MAX) { - return EOVERFLOW; + if (value.size_bytes <= NANOARROW_BINARY_VIEW_INLINE_SIZE) { + memcpy(bvt.inlined.data, value.data.as_char, value.size_bytes); + memset(bvt.inlined.data + bvt.inlined.size, 0, + NANOARROW_BINARY_VIEW_INLINE_SIZE - bvt.inlined.size); + } else { + int32_t current_n_vbufs = ArrowArrayVariadicBufferCount(array); + if (current_n_vbufs == 0 || + private_data->variadic_buffers[current_n_vbufs - 1].size_bytes + + value.size_bytes > + NANOARROW_BINARY_VIEW_BLOCK_SIZE) { + const int32_t additional_bufs_needed = 1; + NANOARROW_RETURN_NOT_OK( + ArrowArrayAddVariadicBuffers(array, additional_bufs_needed)); + current_n_vbufs += additional_bufs_needed; } - offset += 
(int32_t)value.size_bytes; - NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(offset_buffer, &offset, sizeof(int32_t))); + const int32_t buf_index = current_n_vbufs - 1; + struct ArrowBuffer* variadic_buf = &private_data->variadic_buffers[buf_index]; + memcpy(bvt.ref.prefix, value.data.as_char, NANOARROW_BINARY_VIEW_PREFIX_SIZE); + bvt.ref.buffer_index = (int32_t)buf_index; + bvt.ref.offset = (int32_t)variadic_buf->size_bytes; NANOARROW_RETURN_NOT_OK( - ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); - break; + ArrowBufferAppend(variadic_buf, value.data.as_char, value.size_bytes)); + private_data->variadic_buffer_sizes[buf_index] = variadic_buf->size_bytes; + } + NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_buffer, &bvt, sizeof(bvt))); + } else { + struct ArrowBuffer* offset_buffer = ArrowArrayBuffer(array, 1); + struct ArrowBuffer* data_buffer = ArrowArrayBuffer( + array, 1 + (private_data->storage_type != NANOARROW_TYPE_FIXED_SIZE_BINARY)); + int32_t offset; + int64_t large_offset; + int64_t fixed_size_bytes = private_data->layout.element_size_bits[1] / 8; + + switch (private_data->storage_type) { + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_BINARY: + offset = ((int32_t*)offset_buffer->data)[array->length]; + if ((((int64_t)offset) + value.size_bytes) > INT32_MAX) { + return EOVERFLOW; + } - case NANOARROW_TYPE_LARGE_STRING: - case NANOARROW_TYPE_LARGE_BINARY: - large_offset = ((int64_t*)offset_buffer->data)[array->length]; - large_offset += value.size_bytes; - NANOARROW_RETURN_NOT_OK( - ArrowBufferAppend(offset_buffer, &large_offset, sizeof(int64_t))); - NANOARROW_RETURN_NOT_OK( - ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); - break; + offset += (int32_t)value.size_bytes; + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(offset_buffer, &offset, sizeof(int32_t))); + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); + break; - case NANOARROW_TYPE_FIXED_SIZE_BINARY: - if (value.size_bytes != fixed_size_bytes) { - return EINVAL; - } + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_LARGE_BINARY: + large_offset = ((int64_t*)offset_buffer->data)[array->length]; + large_offset += value.size_bytes; + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(offset_buffer, &large_offset, sizeof(int64_t))); + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); + break; - NANOARROW_RETURN_NOT_OK( - ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); - break; - default: - return EINVAL; + case NANOARROW_TYPE_FIXED_SIZE_BINARY: + if (value.size_bytes != fixed_size_bytes) { + return EINVAL; + } + + NANOARROW_RETURN_NOT_OK( + ArrowBufferAppend(data_buffer, value.data.data, value.size_bytes)); + break; + default: + return EINVAL; + } } if (private_data->bitmap.buffer.data != NULL) { @@ -3332,8 +3629,10 @@ static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array, switch (private_data->storage_type) { case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_STRING_VIEW: case NANOARROW_TYPE_BINARY: case NANOARROW_TYPE_LARGE_BINARY: + case NANOARROW_TYPE_BINARY_VIEW: return ArrowArrayAppendBytes(array, buffer_view); default: return EINVAL; @@ -3520,6 +3819,132 @@ static inline void ArrowArrayViewMove(struct ArrowArrayView* src, ArrowArrayViewInitFromType(src, NANOARROW_TYPE_UNINITIALIZED); } +static inline int64_t ArrowArrayViewGetNumBuffers(struct ArrowArrayView* array_view) { + switch (array_view->storage_type) { + case 
NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + return NANOARROW_BINARY_VIEW_FIXED_BUFFERS + array_view->n_variadic_buffers + 1; + default: + break; + } + + int64_t n_buffers = 0; + for (int i = 0; i < NANOARROW_MAX_FIXED_BUFFERS; i++) { + if (array_view->layout.buffer_type[i] == NANOARROW_BUFFER_TYPE_NONE) { + break; + } + + n_buffers++; + } + + return n_buffers; +} + +static inline struct ArrowBufferView ArrowArrayViewGetBufferView( + struct ArrowArrayView* array_view, int64_t i) { + switch (array_view->storage_type) { + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + if (i < NANOARROW_BINARY_VIEW_FIXED_BUFFERS) { + return array_view->buffer_views[i]; + } else if (i >= + (array_view->n_variadic_buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS)) { + struct ArrowBufferView view; + view.data.as_int64 = array_view->variadic_buffer_sizes; + view.size_bytes = array_view->n_variadic_buffers * sizeof(double); + return view; + } else { + struct ArrowBufferView view; + view.data.data = + array_view->variadic_buffers[i - NANOARROW_BINARY_VIEW_FIXED_BUFFERS]; + view.size_bytes = + array_view->variadic_buffer_sizes[i - NANOARROW_BINARY_VIEW_FIXED_BUFFERS]; + return view; + } + default: + // We need this check to avoid -Warray-bounds from complaining + if (i >= NANOARROW_MAX_FIXED_BUFFERS) { + struct ArrowBufferView view; + view.data.data = NULL; + view.size_bytes = 0; + return view; + } else { + return array_view->buffer_views[i]; + } + } +} + +enum ArrowBufferType ArrowArrayViewGetBufferType(struct ArrowArrayView* array_view, + int64_t i) { + switch (array_view->storage_type) { + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + if (i < NANOARROW_BINARY_VIEW_FIXED_BUFFERS) { + return array_view->layout.buffer_type[i]; + } else if (i == + (array_view->n_variadic_buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS)) { + return NANOARROW_BUFFER_TYPE_VARIADIC_SIZE; + } else { + return NANOARROW_BUFFER_TYPE_VARIADIC_DATA; + } + default: + // We need this check to avoid -Warray-bounds from complaining + if (i >= NANOARROW_MAX_FIXED_BUFFERS) { + return NANOARROW_BUFFER_TYPE_NONE; + } else { + return array_view->layout.buffer_type[i]; + } + } +} + +static inline enum ArrowType ArrowArrayViewGetBufferDataType( + struct ArrowArrayView* array_view, int64_t i) { + switch (array_view->storage_type) { + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + if (i < NANOARROW_BINARY_VIEW_FIXED_BUFFERS) { + return array_view->layout.buffer_data_type[i]; + } else if (i >= + (array_view->n_variadic_buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS)) { + return NANOARROW_TYPE_INT64; + } else if (array_view->storage_type == NANOARROW_TYPE_BINARY_VIEW) { + return NANOARROW_TYPE_BINARY; + } else { + return NANOARROW_TYPE_STRING; + } + default: + // We need this check to avoid -Warray-bounds from complaining + if (i >= NANOARROW_MAX_FIXED_BUFFERS) { + return NANOARROW_TYPE_UNINITIALIZED; + } else { + return array_view->layout.buffer_data_type[i]; + } + } +} + +static inline int64_t ArrowArrayViewGetBufferElementSizeBits( + struct ArrowArrayView* array_view, int64_t i) { + switch (array_view->storage_type) { + case NANOARROW_TYPE_BINARY_VIEW: + case NANOARROW_TYPE_STRING_VIEW: + if (i < NANOARROW_BINARY_VIEW_FIXED_BUFFERS) { + return array_view->layout.element_size_bits[i]; + } else if (i >= + (array_view->n_variadic_buffers + NANOARROW_BINARY_VIEW_FIXED_BUFFERS)) { + return sizeof(int64_t) * 8; + } else { + return 0; + } + default: + // We need this check to 
avoid -Warray-bounds from complaining + if (i >= NANOARROW_MAX_FIXED_BUFFERS) { + return 0; + } else { + return array_view->layout.element_size_bits[i]; + } + } +} + static inline int8_t ArrowArrayViewIsNull(const struct ArrowArrayView* array_view, int64_t i) { const uint8_t* validity_buffer = array_view->buffer_views[0].data.as_uint8; @@ -3536,12 +3961,37 @@ static inline int8_t ArrowArrayViewIsNull(const struct ArrowArrayView* array_vie } } +static inline int64_t ArrowArrayViewComputeNullCount( + const struct ArrowArrayView* array_view) { + if (array_view->length == 0) { + return 0; + } + + switch (array_view->storage_type) { + case NANOARROW_TYPE_NA: + return array_view->length; + case NANOARROW_TYPE_DENSE_UNION: + case NANOARROW_TYPE_SPARSE_UNION: + // Unions are "never null" in Arrow land + return 0; + default: + break; + } + + const uint8_t* validity_buffer = array_view->buffer_views[0].data.as_uint8; + if (validity_buffer == NULL) { + return 0; + } + return array_view->length - + ArrowBitCountSet(validity_buffer, array_view->offset, array_view->length); +} + static inline int8_t ArrowArrayViewUnionTypeId(const struct ArrowArrayView* array_view, int64_t i) { switch (array_view->storage_type) { case NANOARROW_TYPE_DENSE_UNION: case NANOARROW_TYPE_SPARSE_UNION: - return array_view->buffer_views[0].data.as_int8[i]; + return array_view->buffer_views[0].data.as_int8[array_view->offset + i]; default: return -1; } @@ -3561,9 +4011,9 @@ static inline int64_t ArrowArrayViewUnionChildOffset( const struct ArrowArrayView* array_view, int64_t i) { switch (array_view->storage_type) { case NANOARROW_TYPE_DENSE_UNION: - return array_view->buffer_views[1].data.as_int32[i]; + return array_view->buffer_views[1].data.as_int32[array_view->offset + i]; case NANOARROW_TYPE_SPARSE_UNION: - return i; + return array_view->offset + i; default: return -1; } @@ -3581,6 +4031,20 @@ static inline int64_t ArrowArrayViewListChildOffset( } } +static struct ArrowBufferView ArrowArrayViewGetBytesFromViewArrayUnsafe( + const struct ArrowArrayView* array_view, int64_t i) { + const union ArrowBinaryView* bv = &array_view->buffer_views[1].data.as_binary_view[i]; + struct ArrowBufferView out = {{NULL}, bv->inlined.size}; + if (bv->inlined.size <= NANOARROW_BINARY_VIEW_INLINE_SIZE) { + out.data.as_uint8 = bv->inlined.data; + return out; + } + + out.data.data = array_view->variadic_buffers[bv->ref.buffer_index]; + out.data.as_uint8 += bv->ref.offset; + return out; +} + static inline int64_t ArrowArrayViewGetIntUnsafe(const struct ArrowArrayView* array_view, int64_t i) { const struct ArrowBufferView* data_view = &array_view->buffer_views[1]; @@ -3607,6 +4071,8 @@ static inline int64_t ArrowArrayViewGetIntUnsafe(const struct ArrowArrayView* ar return (int64_t)data_view->data.as_double[i]; case NANOARROW_TYPE_FLOAT: return (int64_t)data_view->data.as_float[i]; + case NANOARROW_TYPE_HALF_FLOAT: + return (int64_t)ArrowHalfFloatToFloat(data_view->data.as_uint16[i]); case NANOARROW_TYPE_BOOL: return ArrowBitGet(data_view->data.as_uint8, i); default: @@ -3640,6 +4106,8 @@ static inline uint64_t ArrowArrayViewGetUIntUnsafe( return (uint64_t)data_view->data.as_double[i]; case NANOARROW_TYPE_FLOAT: return (uint64_t)data_view->data.as_float[i]; + case NANOARROW_TYPE_HALF_FLOAT: + return (uint64_t)ArrowHalfFloatToFloat(data_view->data.as_uint16[i]); case NANOARROW_TYPE_BOOL: return ArrowBitGet(data_view->data.as_uint8, i); default: @@ -3672,6 +4140,8 @@ static inline double ArrowArrayViewGetDoubleUnsafe( return 
data_view->data.as_double[i]; case NANOARROW_TYPE_FLOAT: return data_view->data.as_float[i]; + case NANOARROW_TYPE_HALF_FLOAT: + return ArrowHalfFloatToFloat(data_view->data.as_uint16[i]); case NANOARROW_TYPE_BOOL: return ArrowBitGet(data_view->data.as_uint8, i); default: @@ -3703,6 +4173,14 @@ static inline struct ArrowStringView ArrowArrayViewGetStringUnsafe( view.size_bytes = array_view->layout.element_size_bits[1] / 8; view.data = array_view->buffer_views[1].data.as_char + (i * view.size_bytes); break; + case NANOARROW_TYPE_STRING_VIEW: + case NANOARROW_TYPE_BINARY_VIEW: { + struct ArrowBufferView buf_view = + ArrowArrayViewGetBytesFromViewArrayUnsafe(array_view, i); + view.data = buf_view.data.as_char; + view.size_bytes = buf_view.size_bytes; + break; + } default: view.data = NULL; view.size_bytes = 0; @@ -3737,6 +4215,10 @@ static inline struct ArrowBufferView ArrowArrayViewGetBytesUnsafe( view.data.as_uint8 = array_view->buffer_views[1].data.as_uint8 + (i * view.size_bytes); break; + case NANOARROW_TYPE_STRING_VIEW: + case NANOARROW_TYPE_BINARY_VIEW: + view = ArrowArrayViewGetBytesFromViewArrayUnsafe(array_view, i); + break; default: view.data.data = NULL; view.size_bytes = 0; diff --git a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.hpp b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.hpp index 49ba38f..16c2e55 100644 --- a/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.hpp +++ b/3rd_party/apache-arrow-adbc/c/vendor/nanoarrow/nanoarrow.hpp @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +#include #include #include #include -#include "nanoarrow/nanoarrow.h" +#include "nanoarrow.h" #ifndef NANOARROW_HPP_INCLUDED #define NANOARROW_HPP_INCLUDED @@ -216,10 +217,16 @@ template class Unique { public: /// \brief Construct an invalid instance of T holding no resources - Unique() { init_pointer(&data_); } + Unique() { + std::memset(&data_, 0, sizeof(data_)); + init_pointer(&data_); + } /// \brief Move and take ownership of data - Unique(T* data) { move_pointer(data, &data_); } + Unique(T* data) { + std::memset(&data_, 0, sizeof(data_)); + move_pointer(data, &data_); + } /// \brief Move and take ownership of data wrapped by rhs Unique(Unique&& rhs) : Unique(rhs.get()) {} diff --git a/3rd_party/apache-arrow-adbc/c/vendor/vendor_nanoarrow.sh b/3rd_party/apache-arrow-adbc/c/vendor/vendor_nanoarrow.sh index f74ebde..9024090 100755 --- a/3rd_party/apache-arrow-adbc/c/vendor/vendor_nanoarrow.sh +++ b/3rd_party/apache-arrow-adbc/c/vendor/vendor_nanoarrow.sh @@ -21,7 +21,7 @@ main() { local -r repo_url="https://github.com/apache/arrow-nanoarrow" # Check releases page: https://github.com/apache/arrow-nanoarrow/releases/ - local -r commit_sha=c5fb10035c17b598e6fd688ad9eb7b874c7c631b + local -r commit_sha=33d2c8b973d8f8f424e02ac92ddeaace2a92f8dd echo "Fetching $commit_sha from $repo_url" SCRATCH=$(mktemp -d) @@ -34,21 +34,13 @@ main() { mkdir -p nanoarrow tar --strip-components 1 -C "$SCRATCH" -xf "$tarball" - # Build the bundle using cmake. We could also use the dist/ files - # but this allows us to add the symbol namespace and ensures that the - # resulting bundle is perfectly synchronized with the commit we've pulled. - pushd "$SCRATCH" - mkdir build && cd build - # Do not use "adbc" in the namespace name since our scripts expose all - # such symbols - cmake .. -DNANOARROW_BUNDLE=ON -DNANOARROW_NAMESPACE=Private - cmake --build . - cmake --install . 
--prefix=../dist-adbc - popd + # Build the bundle + python "$SCRATCH/ci/scripts/bundle.py" \ + --symbol-namespace=Private \ + --include-output-dir=nanoarrow \ + --source-output-dir=nanoarrow \ + --header-namespace= - cp "$SCRATCH/dist-adbc/nanoarrow.c" nanoarrow/ - cp "$SCRATCH/dist-adbc/nanoarrow.h" nanoarrow/ - cp "$SCRATCH/dist-adbc/nanoarrow.hpp" nanoarrow/ mv CMakeLists.nanoarrow.tmp nanoarrow/CMakeLists.txt } From 0023de55a0281e6ac8193a47a555c96e8ff5a3a5 Mon Sep 17 00:00:00 2001 From: Cocoa Date: Tue, 12 Nov 2024 01:08:35 +0000 Subject: [PATCH 2/6] driver: bump duckdb to `1.1.3` and adbc drivers to `1.3.0` --- lib/adbc_driver.ex | 87 +++++++++++++++++++++++++++++----------------- update.exs | 8 ++--- 2 files changed, 59 insertions(+), 36 deletions(-) diff --git a/lib/adbc_driver.ex b/lib/adbc_driver.ex index 2d634a5..936e0ac 100644 --- a/lib/adbc_driver.ex +++ b/lib/adbc_driver.ex @@ -6,126 +6,149 @@ defmodule Adbc.Driver do # == GENERATED CONSTANTS == - # Generated by update.exs at 2024-11-03T11:30:49. Do not change manually. + # Generated by update.exs at 2024-11-12T01:02:58. Do not change manually. @generated_driver_versions %{ - duckdb: "1.1.2", - sqlite: "1.2.0", - postgresql: "1.2.0", - flightsql: "1.2.0", - snowflake: "1.2.0" + duckdb: "1.1.3", + sqlite: "1.3.0", + postgresql: "1.3.0", + flightsql: "1.3.0", + snowflake: "1.3.0", + bigquery: "1.3.0" } @generated_driver_data %{ duckdb: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/duckdb/duckdb/releases/download/v1.1.2/libduckdb-osx-universal.zip" + "https://github.com/duckdb/duckdb/releases/download/v1.1.3/libduckdb-osx-universal.zip" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/duckdb/duckdb/releases/download/v1.1.2/libduckdb-linux-aarch64.zip" + "https://github.com/duckdb/duckdb/releases/download/v1.1.3/libduckdb-linux-aarch64.zip" }, "aarch64-windows-msvc" => %{ url: - "https://github.com/duckdb/duckdb/releases/download/v1.1.2/libduckdb-windows-arm64.zip" + "https://github.com/duckdb/duckdb/releases/download/v1.1.3/libduckdb-windows-arm64.zip" }, "x86_64-apple-darwin" => %{ url: - "https://github.com/duckdb/duckdb/releases/download/v1.1.2/libduckdb-osx-universal.zip" + "https://github.com/duckdb/duckdb/releases/download/v1.1.3/libduckdb-osx-universal.zip" }, "x86_64-linux-gnu" => %{ - url: "https://github.com/duckdb/duckdb/releases/download/v1.1.2/libduckdb-linux-amd64.zip" + url: "https://github.com/duckdb/duckdb/releases/download/v1.1.3/libduckdb-linux-amd64.zip" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/duckdb/duckdb/releases/download/v1.1.2/libduckdb-windows-amd64.zip" + "https://github.com/duckdb/duckdb/releases/download/v1.1.3/libduckdb-windows-amd64.zip" } }, sqlite: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_sqlite-1.2.0-py3-none-macosx_11_0_arm64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-macosx_11_0_arm64.whl" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_sqlite-1.2.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-manylinux_2_28_aarch64.whl" }, "x86_64-apple-darwin" => %{ url: - 
"https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_sqlite-1.2.0-py3-none-macosx_10_15_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-macosx_10_15_x86_64.whl" }, "x86_64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_sqlite-1.2.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-manylinux_2_28_x86_64.whl" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_sqlite-1.2.0-py3-none-win_amd64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-win_amd64.whl" } }, postgresql: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_postgresql-1.2.0-py3-none-macosx_11_0_arm64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-macosx_11_0_arm64.whl" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_postgresql-1.2.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" }, "x86_64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_postgresql-1.2.0-py3-none-macosx_10_15_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-macosx_10_15_x86_64.whl" }, "x86_64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_postgresql-1.2.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_postgresql-1.2.0-py3-none-win_amd64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-win_amd64.whl" } }, flightsql: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_flightsql-1.2.0-py3-none-macosx_11_0_arm64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-macosx_11_0_arm64.whl" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_flightsql-1.2.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl" }, "x86_64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_flightsql-1.2.0-py3-none-macosx_10_15_x86_64.whl" + 
"https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-macosx_10_15_x86_64.whl" }, "x86_64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_flightsql-1.2.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_flightsql-1.2.0-py3-none-win_amd64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-win_amd64.whl" } }, snowflake: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_snowflake-1.2.0-py3-none-macosx_11_0_arm64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-macosx_11_0_arm64.whl" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_snowflake-1.2.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl" }, "x86_64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_snowflake-1.2.0-py3-none-macosx_10_15_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-macosx_10_15_x86_64.whl" }, "x86_64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_snowflake-1.2.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-14/adbc_driver_snowflake-1.2.0-py3-none-win_amd64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-win_amd64.whl" + } + }, + bigquery: %{ + "aarch64-apple-darwin" => %{ + url: + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-macosx_11_0_arm64.whl" + }, + "aarch64-linux-gnu" => %{ + url: + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl" + }, + "x86_64-apple-darwin" => %{ + url: + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-macosx_10_15_x86_64.whl" + }, + "x86_64-linux-gnu" => %{ + url: + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl" + }, + "x86_64-windows-msvc" => %{ + url: + 
"https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-win_amd64.whl" } } } diff --git a/update.exs b/update.exs index b9b19ac..02d39ba 100644 --- a/update.exs +++ b/update.exs @@ -12,13 +12,13 @@ Mix.install([{:req, "~> 0.4"}]) defmodule Update do # To update duckdb driver, just bump this version # https://github.com/duckdb/duckdb/releases/ - @duckdb_version "1.1.2" + @duckdb_version "1.1.3" # To update ADBC drivers, bump the tag and version accordingly # https://github.com/apache/arrow-adbc/releases - @adbc_driver_version "1.2.0" - @adbc_tag "apache-arrow-adbc-14" - @adbc_drivers ~w(sqlite postgresql flightsql snowflake)a + @adbc_driver_version "1.3.0" + @adbc_tag "apache-arrow-adbc-15-rc1" + @adbc_drivers ~w(sqlite postgresql flightsql snowflake bigquery)a def versions do Map.new(@adbc_drivers, &{&1, @adbc_driver_version}) From 6bacc6c48fbd57be7b0692cb23f6bb9ddf996240 Mon Sep 17 00:00:00 2001 From: Cocoa Date: Tue, 12 Nov 2024 01:08:43 +0000 Subject: [PATCH 3/6] fix unit tests --- test/adbc_connection_test.exs | 19 ++----------------- test/adbc_test.exs | 2 +- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/test/adbc_connection_test.exs b/test/adbc_connection_test.exs index 7c8552b..dbcb326 100644 --- a/test/adbc_connection_test.exs +++ b/test/adbc_connection_test.exs @@ -166,25 +166,10 @@ defmodule Adbc.Connection.Test do assert results = %Adbc.Result{ num_rows: nil, - data: [ - %Adbc.Column{ - data: [], - name: "catalog_name", - type: :string, - metadata: nil, - nullable: true - }, - %Adbc.Column{ - data: [], - name: "catalog_db_schemas", - type: :list, - metadata: nil, - nullable: true - } - ] + data: [] } = Adbc.Result.materialize(results) - assert %{"catalog_db_schemas" => [], "catalog_name" => []} == Adbc.Result.to_map(results) + assert %{} == Adbc.Result.to_map(results) end end diff --git a/test/adbc_test.exs b/test/adbc_test.exs index 1b94b6b..00911b2 100644 --- a/test/adbc_test.exs +++ b/test/adbc_test.exs @@ -12,7 +12,7 @@ defmodule AdbcTest do test "returns errors" do assert {:error, - "unknown driver :unknown, expected one of :duckdb, :flightsql, :postgresql, " <> _} = + "unknown driver :unknown, expected one of :bigquery, :duckdb, :flightsql, :postgresql, " <> _} = Adbc.download_driver(:unknown) end end From a69433315ce4cb3d276847e9d42fadb8f3a79196 Mon Sep 17 00:00:00 2001 From: Cocoa Date: Tue, 12 Nov 2024 01:08:58 +0000 Subject: [PATCH 4/6] start `v0.6.5-dev` --- mix.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mix.exs b/mix.exs index 6fb6919..20fb1a3 100644 --- a/mix.exs +++ b/mix.exs @@ -1,7 +1,7 @@ defmodule Adbc.MixProject do use Mix.Project - @version "0.6.4" + @version "0.6.5-dev" @github_url "https://github.com/elixir-explorer/adbc" def project do From 907f316ce094873476667a9a02096846bdc0f21c Mon Sep 17 00:00:00 2001 From: Cocoa Date: Thu, 21 Nov 2024 18:48:10 +0000 Subject: [PATCH 5/6] updated CHANGELOG.md Signed-off-by: Cocoa --- CHANGELOG.md | 8 ++++++-- mix.exs | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cdb4b31..a2f4466 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,13 +4,17 @@ #### Added -- Added an environment variable `ADBC_PREFER_PRECOMPILED`. Set to `false` to force compile locally. +* Added an environment variable `ADBC_PREFER_PRECOMPILED`. Set to `false` to force compile locally. + +#### Changed + +* Updated to ADBC library 15. 
## v0.6.4 #### Fixes -- Fixed the issue with the `clean` target in the Makefile not removing all relevant files (#109) +* Fixed the issue with the `clean` target in the Makefile not removing all relevant files (#109) ## v0.6.3 diff --git a/mix.exs b/mix.exs index 262c5f0..2050071 100644 --- a/mix.exs +++ b/mix.exs @@ -1,7 +1,7 @@ defmodule Adbc.MixProject do use Mix.Project - @version "0.6.6-dev" + @version "0.6.5-dev" @github_url "https://github.com/elixir-explorer/adbc" def project do From 4fc862a6ea017c69a46362e6540d7a050cd3d634 Mon Sep 17 00:00:00 2001 From: Cocoa Date: Thu, 21 Nov 2024 18:50:14 +0000 Subject: [PATCH 6/6] updated @adbc_tag to `apache-arrow-adbc-15` Signed-off-by: Cocoa --- lib/adbc_driver.ex | 52 +++++++++++++++++++++++----------------------- update.exs | 2 +- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/lib/adbc_driver.ex b/lib/adbc_driver.ex index 936e0ac..d7fc5cd 100644 --- a/lib/adbc_driver.ex +++ b/lib/adbc_driver.ex @@ -6,7 +6,7 @@ defmodule Adbc.Driver do # == GENERATED CONSTANTS == - # Generated by update.exs at 2024-11-12T01:02:58. Do not change manually. + # Generated by update.exs at 2024-11-21T18:49:45. Do not change manually. @generated_driver_versions %{ duckdb: "1.1.3", sqlite: "1.3.0", @@ -44,111 +44,111 @@ defmodule Adbc.Driver do sqlite: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-macosx_11_0_arm64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_sqlite-1.3.0-py3-none-macosx_11_0_arm64.whl" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-manylinux_2_28_aarch64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_sqlite-1.3.0-py3-none-manylinux_2_28_aarch64.whl" }, "x86_64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-macosx_10_15_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_sqlite-1.3.0-py3-none-macosx_10_15_x86_64.whl" }, "x86_64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-manylinux_2_28_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_sqlite-1.3.0-py3-none-manylinux_2_28_x86_64.whl" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_sqlite-1.3.0-py3-none-win_amd64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_sqlite-1.3.0-py3-none-win_amd64.whl" } }, postgresql: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-macosx_11_0_arm64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_postgresql-1.3.0-py3-none-macosx_11_0_arm64.whl" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" + 
"https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_postgresql-1.3.0-py3-none-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl" }, "x86_64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-macosx_10_15_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_postgresql-1.3.0-py3-none-macosx_10_15_x86_64.whl" }, "x86_64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_postgresql-1.3.0-py3-none-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_postgresql-1.3.0-py3-none-win_amd64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_postgresql-1.3.0-py3-none-win_amd64.whl" } }, flightsql: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-macosx_11_0_arm64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_flightsql-1.3.0-py3-none-macosx_11_0_arm64.whl" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_flightsql-1.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl" }, "x86_64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-macosx_10_15_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_flightsql-1.3.0-py3-none-macosx_10_15_x86_64.whl" }, "x86_64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_flightsql-1.3.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_flightsql-1.3.0-py3-none-win_amd64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_flightsql-1.3.0-py3-none-win_amd64.whl" } }, snowflake: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-macosx_11_0_arm64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_snowflake-1.3.0-py3-none-macosx_11_0_arm64.whl" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl" + 
"https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_snowflake-1.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl" }, "x86_64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-macosx_10_15_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_snowflake-1.3.0-py3-none-macosx_10_15_x86_64.whl" }, "x86_64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_snowflake-1.3.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_snowflake-1.3.0-py3-none-win_amd64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_snowflake-1.3.0-py3-none-win_amd64.whl" } }, bigquery: %{ "aarch64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-macosx_11_0_arm64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_bigquery-1.3.0-py3-none-macosx_11_0_arm64.whl" }, "aarch64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_bigquery-1.3.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl" }, "x86_64-apple-darwin" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-macosx_10_15_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_bigquery-1.3.0-py3-none-macosx_10_15_x86_64.whl" }, "x86_64-linux-gnu" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_bigquery-1.3.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_28_x86_64.whl" }, "x86_64-windows-msvc" => %{ url: - "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15-rc1/adbc_driver_bigquery-1.3.0-py3-none-win_amd64.whl" + "https://github.com/apache/arrow-adbc/releases/download/apache-arrow-adbc-15/adbc_driver_bigquery-1.3.0-py3-none-win_amd64.whl" } } } diff --git a/update.exs b/update.exs index 02d39ba..c030b2c 100644 --- a/update.exs +++ b/update.exs @@ -17,7 +17,7 @@ defmodule Update do # To update ADBC drivers, bump the tag and version accordingly # https://github.com/apache/arrow-adbc/releases @adbc_driver_version "1.3.0" - @adbc_tag "apache-arrow-adbc-15-rc1" + @adbc_tag "apache-arrow-adbc-15" @adbc_drivers ~w(sqlite postgresql flightsql snowflake bigquery)a def versions do