Propagate insertion errors out of MemoTable merges.

* HashTable::VisitEntries now returns Status; each visit_func result is passed
  through RETURN_NOT_OK, and Status::OK() is returned after the loop.
* ScalarMemoTable::CopyValues wraps the VisitEntries call in ARROW_DCHECK_OK,
  with a visitor that always returns Status::OK().
* ScalarMemoTable::MergeTable returns the Status of GetOrInsert via
  VisitEntries instead of wrapping it in ARROW_DCHECK_OK.
* BinaryMemoTable::MergeTable keeps the first non-OK Status from GetOrInsert,
  skips the remaining values, and returns that Status.
* New tests (MergeTablePropagatesInsertError) use a CappedMemoryPool and expect
  MergeTable to raise OutOfMemory.

NOTE(review): template argument lists appear to have been stripped from this
copy of the patch (e.g. after `template`, `reinterpret_cast`, `static_cast`,
`std::vector`, `ScalarMemoTable`, `BinaryMemoTable`), and indentation within
lines was reconstructed from nesting -- confirm against the upstream commit
before applying.

diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h
index 53f92c8f23d2..cba39dffdb15 100644
--- a/cpp/src/arrow/util/hashing.h
+++ b/cpp/src/arrow/util/hashing.h
@@ -286,15 +286,16 @@ class HashTable {
   uint64_t size() const { return size_; }
 
   // Visit all non-empty entries in the table
-  // The visit_func should have signature void(const Entry*)
+  // The visit_func should have signature Status(const Entry*)
   template
-  void VisitEntries(VisitFunc&& visit_func) const {
+  Status VisitEntries(VisitFunc&& visit_func) const {
     for (uint64_t i = 0; i < capacity_; i++) {
       const auto& entry = entries_[i];
       if (entry) {
-        visit_func(&entry);
+        RETURN_NOT_OK(visit_func(&entry));
       }
     }
+    return Status::OK();
   }
 
  protected:
@@ -494,12 +495,13 @@ class ScalarMemoTable : public MemoTable {
     // So that both uint16_t and Float16 are allowed
     static_assert(sizeof(Value) == sizeof(Scalar));
     Scalar* out = reinterpret_cast(out_data);
-    hash_table_.VisitEntries([=](const HashTableEntry* entry) {
+    ARROW_DCHECK_OK(hash_table_.VisitEntries([=](const HashTableEntry* entry) {
       int32_t index = entry->payload.memo_index - start;
       if (index >= 0) {
         out[index] = entry->payload.value;
       }
-    });
+      return Status::OK();
+    }));
     // Zero-initialize the null entry
     if (null_index_ != kKeyNotFound) {
       int32_t index = null_index_ - start;
@@ -534,12 +536,10 @@ class ScalarMemoTable : public MemoTable {
   // Merge entries from `other_table` into `this->hash_table_`.
   Status MergeTable(const ScalarMemoTable& other_table) {
     const HashTableType& other_hashtable = other_table.hash_table_;
-
-    other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) {
+    RETURN_NOT_OK(other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) {
       int32_t unused;
-      ARROW_DCHECK_OK(this->GetOrInsert(other_entry->payload.value, &unused));
-    });
-    // TODO: ARROW-17074 - implement proper error handling
+      return this->GetOrInsert(other_entry->payload.value, &unused);
+    }));
     return Status::OK();
   }
 };
@@ -899,11 +899,15 @@ class BinaryMemoTable : public MemoTable {
 
  public:
   Status MergeTable(const BinaryMemoTable& other_table) {
-    other_table.VisitValues(0, [this](std::string_view other_value) {
+    Status status = Status::OK();
+    other_table.VisitValues(0, [this, &status](std::string_view other_value) {
+      if (ARROW_PREDICT_FALSE(!status.ok())) {
+        return;
+      }
       int32_t unused;
-      ARROW_DCHECK_OK(this->GetOrInsert(other_value, &unused));
+      status = this->GetOrInsert(other_value, &unused);
     });
-    return Status::OK();
+    return status;
   }
 };
 
diff --git a/cpp/src/arrow/util/hashing_test.cc b/cpp/src/arrow/util/hashing_test.cc
index f6ada0acd2d0..6e4c59a1ebd2 100644
--- a/cpp/src/arrow/util/hashing_test.cc
+++ b/cpp/src/arrow/util/hashing_test.cc
@@ -30,6 +30,7 @@
 
 #include "arrow/array/builder_primitive.h"
 #include "arrow/array/concatenate.h"
+#include "arrow/memory_pool.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/util/bit_util.h"
 #include "arrow/util/hashing.h"
@@ -376,6 +377,32 @@ TEST(ScalarMemoTable, StressInt64) {
   ASSERT_EQ(table.size(), map.size());
 }
 
+TEST(ScalarMemoTable, MergeTablePropagatesInsertError) {
+  int64_t bytes_allocated_limit = 0;
+  {
+    ProxyMemoryPool probe(default_memory_pool());
+    ScalarMemoTable target(&probe, 0);
+    for (int64_t value = 0; value < 15; ++value) {
+      AssertGetOrInsert(target, value, static_cast(value));
+    }
+    bytes_allocated_limit = probe.bytes_allocated();
+  }
+  ASSERT_GT(bytes_allocated_limit, 0);
+
+  ScalarMemoTable source(default_memory_pool(), 0);
+  AssertGetOrInsert(source, 15, 0);
+
+  ProxyMemoryPool proxy(default_memory_pool());
+  CappedMemoryPool pool(&proxy, bytes_allocated_limit);
+  ScalarMemoTable target(&pool, 0);
+  for (int64_t value = 0; value < 15; ++value) {
+    AssertGetOrInsert(target, value, static_cast(value));
+  }
+  ASSERT_EQ(proxy.bytes_allocated(), bytes_allocated_limit);
+
+  ASSERT_RAISES(OutOfMemory, target.MergeTable(source));
+}
+
 TEST(BinaryMemoTable, Basics) {
   std::string A = "", B = "a", C = "foo", D = "bar", E, F;
   E += '\0';
@@ -480,6 +507,35 @@ TEST(BinaryMemoTable, Stress) {
   ASSERT_EQ(table.size(), map.size());
 }
 
+TEST(BinaryMemoTable, MergeTablePropagatesInsertError) {
+  const std::vector initial_values = {"a", "bb", "ccc", "dddd"};
+  const std::string extra_value(4096, 'x');
+
+  int64_t bytes_allocated_limit = 0;
+  {
+    ProxyMemoryPool probe(default_memory_pool());
+    BinaryMemoTable target(&probe, 0);
+    for (size_t i = 0; i < initial_values.size(); ++i) {
+      AssertGetOrInsert(target, initial_values[i], static_cast(i));
+    }
+    bytes_allocated_limit = probe.bytes_allocated();
+  }
+  ASSERT_GT(bytes_allocated_limit, 0);
+
+  BinaryMemoTable source(default_memory_pool(), 0);
+  AssertGetOrInsert(source, extra_value, 0);
+
+  ProxyMemoryPool proxy(default_memory_pool());
+  CappedMemoryPool pool(&proxy, bytes_allocated_limit);
+  BinaryMemoTable target(&pool, 0);
+  for (size_t i = 0; i < initial_values.size(); ++i) {
+    AssertGetOrInsert(target, initial_values[i], static_cast(i));
+  }
+  ASSERT_EQ(proxy.bytes_allocated(), bytes_allocated_limit);
+
+  ASSERT_RAISES(OutOfMemory, target.MergeTable(source));
+}
+
 TEST(BinaryMemoTable, Empty) {
   BinaryMemoTable table(default_memory_pool());
   ASSERT_EQ(table.size(), 0);