diff --git a/cpp/src/arrow/array/array_binary_test.cc b/cpp/src/arrow/array/array_binary_test.cc index 04391be0ac78..37d052888d31 100644 --- a/cpp/src/arrow/array/array_binary_test.cc +++ b/cpp/src/arrow/array/array_binary_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -29,6 +30,7 @@ #include "arrow/array/builder_binary.h" #include "arrow/array/validate.h" #include "arrow/buffer.h" +#include "arrow/extension_type.h" #include "arrow/memory_pool.h" #include "arrow/status.h" #include "arrow/testing/builder.h" @@ -48,6 +50,35 @@ namespace arrow { using internal::checked_cast; +class BinaryExtensionType : public ExtensionType { + public: + BinaryExtensionType(std::shared_ptr storage_type, std::string extension_name) + : ExtensionType(std::move(storage_type)), + extension_name_(std::move(extension_name)) {} + + std::string extension_name() const override { return extension_name_; } + + bool ExtensionEquals(const ExtensionType& other) const override { + return other.extension_name() == this->extension_name() && + storage_type()->Equals(*other.storage_type()); + } + + std::shared_ptr MakeArray(std::shared_ptr data) const override { + return std::make_shared(std::move(data)); + } + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized) const override { + return Status::NotImplemented("BinaryExtensionType::Deserialize"); + } + + std::string Serialize() const override { return ""; } + + private: + std::string extension_name_; +}; + // ---------------------------------------------------------------------- // String / Binary tests @@ -802,6 +833,77 @@ TYPED_TEST(TestStringBuilder, TestZeroLength) { this->TestZeroLength(); } TYPED_TEST(TestStringBuilder, TestOverflowCheck) { this->TestOverflowCheck(); } +TEST(BinaryBuilder, PreservesDataType) { + auto type = std::make_shared(binary(), "test.binary"); + BinaryBuilder builder(type, default_memory_pool()); + + AssertTypeEqual(*type, *builder.type()); + ASSERT_OK(builder.Append("abc")); + ASSERT_OK(builder.AppendNull()); + + ASSERT_OK_AND_ASSIGN(auto array, builder.Finish()); + ASSERT_OK(array->ValidateFull()); + AssertTypeEqual(*type, *array->type()); + ASSERT_EQ(Type::EXTENSION, array->type_id()); + + auto extension_array = std::static_pointer_cast(array); + auto storage = std::static_pointer_cast(extension_array->storage()); + AssertTypeEqual(*binary(), *storage->type()); + ASSERT_EQ("abc", storage->GetString(0)); + ASSERT_TRUE(storage->IsNull(1)); +} + +TEST(BinaryBuilder, TypedFinishRejectsExtensionType) { + auto type = std::make_shared(binary(), "test.binary"); + BinaryBuilder builder(type, default_memory_pool()); + ASSERT_OK(builder.Append("abc")); + + std::shared_ptr binary_array; + ASSERT_RAISES(TypeError, builder.Finish(&binary_array)); + ASSERT_EQ(nullptr, binary_array); + + ASSERT_OK_AND_ASSIGN(auto array, builder.Finish()); + AssertTypeEqual(*type, *array->type()); + ASSERT_EQ(1, array->length()); +} + +TEST(LargeBinaryBuilder, PreservesDataType) { + auto type = + std::make_shared(large_binary(), "test.large_binary"); + LargeBinaryBuilder builder(type, default_memory_pool()); + + AssertTypeEqual(*type, *builder.type()); + ASSERT_OK(builder.Append("abc")); + ASSERT_OK(builder.AppendNull()); + + ASSERT_OK_AND_ASSIGN(auto array, builder.Finish()); + ASSERT_OK(array->ValidateFull()); + AssertTypeEqual(*type, *array->type()); + ASSERT_EQ(Type::EXTENSION, array->type_id()); + + auto extension_array = std::static_pointer_cast(array); + auto storage = + std::static_pointer_cast(extension_array->storage()); + AssertTypeEqual(*large_binary(), *storage->type()); + ASSERT_EQ("abc", storage->GetString(0)); + ASSERT_TRUE(storage->IsNull(1)); +} + +TEST(LargeBinaryBuilder, TypedFinishRejectsExtensionType) { + auto type = + std::make_shared(large_binary(), "test.large_binary"); + LargeBinaryBuilder builder(type, default_memory_pool()); + ASSERT_OK(builder.Append("abc")); + + std::shared_ptr large_binary_array; + ASSERT_RAISES(TypeError, builder.Finish(&large_binary_array)); + ASSERT_EQ(nullptr, large_binary_array); + + ASSERT_OK_AND_ASSIGN(auto array, builder.Finish()); + AssertTypeEqual(*type, *array->type()); + ASSERT_EQ(1, array->length()); +} + // ---------------------------------------------------------------------- // ChunkedBinaryBuilder tests diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index d0e761ae9684..51c994a1c02a 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -34,9 +34,13 @@ #include "arrow/array/data.h" #include "arrow/buffer.h" #include "arrow/buffer_builder.h" +#include "arrow/extension_type.h" #include "arrow/status.h" #include "arrow/type.h" +#include "arrow/type_traits.h" #include "arrow/util/binary_view_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" @@ -60,11 +64,15 @@ class BaseBinaryBuilder explicit BaseBinaryBuilder(MemoryPool* pool = default_memory_pool(), int64_t alignment = kDefaultBufferAlignment) : ArrayBuilder(pool, alignment), + type_(TypeTraits::type_singleton()), offsets_builder_(pool, alignment), value_data_builder_(pool, alignment) {} BaseBinaryBuilder(const std::shared_ptr& type, MemoryPool* pool) - : BaseBinaryBuilder(pool) {} + : ArrayBuilder(pool), + type_(ValidateType(type)), + offsets_builder_(pool), + value_data_builder_(pool) {} Status Append(const uint8_t* value, offset_type length) { ARROW_RETURN_NOT_OK(Reserve(1)); @@ -356,6 +364,8 @@ class BaseBinaryBuilder /// \return capacity of values buffer int64_t value_data_capacity() const { return value_data_builder_.capacity(); } + std::shared_ptr type() const override { return type_; } + /// \return data pointer of the value date builder const offset_type* offsets_data() const { return offsets_builder_.data(); } @@ -390,6 +400,32 @@ class BaseBinaryBuilder } protected: + template + static std::shared_ptr ValidateType(std::shared_ptr type) { + ARROW_CHECK(type != nullptr) << "Cannot construct binary builder with null type"; + const auto expected_type = TypeTraits::type_singleton(); + const DataType* storage_type = type.get(); + if (type->id() == Type::EXTENSION) { + storage_type = + internal::checked_cast(*type).storage_type().get(); + } + ARROW_CHECK(storage_type->Equals(*expected_type)) + << "Cannot construct binary builder for " << expected_type->ToString() + << " values from type " << type->ToString(); + return type; + } + + template + Status FinishTypedAs(std::shared_ptr* out) { + const auto expected_type = TypeTraits::type_singleton(); + if (ARROW_PREDICT_FALSE(!type()->Equals(*expected_type))) { + return Status::TypeError("Cannot finish builder with type ", type()->ToString(), + " as ", expected_type->ToString(), " array"); + } + return FinishTyped(out); + } + + std::shared_ptr type_; TypedBufferBuilder offsets_builder_; TypedBufferBuilder value_data_builder_; @@ -414,24 +450,35 @@ class ARROW_EXPORT BinaryBuilder : public BaseBinaryBuilder { using ArrayBuilder::Finish; /// \endcond - Status Finish(std::shared_ptr* out) { return FinishTyped(out); } - - std::shared_ptr type() const override { return binary(); } + Status Finish(std::shared_ptr* out) { + return FinishTypedAs(out); + } }; /// \class StringBuilder /// \brief Builder class for UTF8 strings class ARROW_EXPORT StringBuilder : public BinaryBuilder { public: - using BinaryBuilder::BinaryBuilder; + explicit StringBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : BinaryBuilder(pool, alignment) { + type_ = utf8(); + } + + StringBuilder(const std::shared_ptr& type, MemoryPool* pool) + : BinaryBuilder(pool) { + type_ = ValidateType(type); + } /// \cond FALSE using ArrayBuilder::Finish; /// \endcond - Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + Status Finish(std::shared_ptr* out) { + return FinishTypedAs(out); + } - std::shared_ptr type() const override { return utf8(); } + std::shared_ptr type() const override { return type_; } }; /// \class LargeBinaryBuilder @@ -444,24 +491,35 @@ class ARROW_EXPORT LargeBinaryBuilder : public BaseBinaryBuilder* out) { return FinishTyped(out); } - - std::shared_ptr type() const override { return large_binary(); } + Status Finish(std::shared_ptr* out) { + return FinishTypedAs(out); + } }; /// \class LargeStringBuilder /// \brief Builder class for large UTF8 strings class ARROW_EXPORT LargeStringBuilder : public LargeBinaryBuilder { public: - using LargeBinaryBuilder::LargeBinaryBuilder; + explicit LargeStringBuilder(MemoryPool* pool = default_memory_pool(), + int64_t alignment = kDefaultBufferAlignment) + : LargeBinaryBuilder(pool, alignment) { + type_ = large_utf8(); + } + + LargeStringBuilder(const std::shared_ptr& type, MemoryPool* pool) + : LargeBinaryBuilder(pool) { + type_ = ValidateType(type); + } /// \cond FALSE using ArrayBuilder::Finish; /// \endcond - Status Finish(std::shared_ptr* out) { return FinishTyped(out); } + Status Finish(std::shared_ptr* out) { + return FinishTypedAs(out); + } - std::shared_ptr type() const override { return large_utf8(); } + std::shared_ptr type() const override { return type_; } }; // ----------------------------------------------------------------------