Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions cpp/src/arrow/util/hashing.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,15 +286,16 @@ class HashTable {
// Returns the current value of the size_ member.
uint64_t size() const { return size_; }

// Visit all non-empty entries in the table
// The visit_func should have signature void(const Entry*)
// The visit_func should have signature Status(const Entry*)
template <typename VisitFunc>
void VisitEntries(VisitFunc&& visit_func) const {
Status VisitEntries(VisitFunc&& visit_func) const {
for (uint64_t i = 0; i < capacity_; i++) {
const auto& entry = entries_[i];
if (entry) {
visit_func(&entry);
RETURN_NOT_OK(visit_func(&entry));
}
}
return Status::OK();
}

protected:
Expand Down Expand Up @@ -494,12 +495,13 @@ class ScalarMemoTable : public MemoTable {
// So that both uint16_t and Float16 are allowed
static_assert(sizeof(Value) == sizeof(Scalar));
Scalar* out = reinterpret_cast<Scalar*>(out_data);
hash_table_.VisitEntries([=](const HashTableEntry* entry) {
ARROW_DCHECK_OK(hash_table_.VisitEntries([=](const HashTableEntry* entry) {
int32_t index = entry->payload.memo_index - start;
if (index >= 0) {
out[index] = entry->payload.value;
}
});
return Status::OK();
}));
// Zero-initialize the null entry
if (null_index_ != kKeyNotFound) {
int32_t index = null_index_ - start;
Expand Down Expand Up @@ -534,12 +536,10 @@ class ScalarMemoTable : public MemoTable {
// Merge entries from `other_table` into `this->hash_table_`.
//
// Every value stored in `other_table`'s hash table is passed to GetOrInsert()
// on this table. The Status of a failing insertion is returned from the
// visitor and propagated to the caller instead of being checked only in debug
// builds; Status::OK() is returned when every insertion succeeds.
Status MergeTable(const ScalarMemoTable& other_table) {
  const HashTableType& other_hashtable = other_table.hash_table_;
  RETURN_NOT_OK(other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) {
    int32_t unused;  // the resulting memo index is not needed here
    return this->GetOrInsert(other_entry->payload.value, &unused);
  }));
  return Status::OK();
}
};
Expand Down Expand Up @@ -899,11 +899,15 @@ class BinaryMemoTable : public MemoTable {

public:
// Merge the values of `other_table` into this table.
//
// Each value produced by other_table.VisitValues() (called with 0 as its first
// argument -- presumably the start index; confirm against VisitValues) is
// inserted with GetOrInsert(). The visitor lambda returns void, so the first
// failing Status is latched in `status`; later invocations return immediately
// without inserting, and the latched Status is what the caller receives.
Status MergeTable(const BinaryMemoTable& other_table) {
  Status status = Status::OK();
  other_table.VisitValues(0, [this, &status](std::string_view other_value) {
    if (ARROW_PREDICT_FALSE(!status.ok())) {
      // An earlier insertion already failed: skip the remaining values.
      return;
    }
    int32_t unused;  // the resulting memo index is not needed here
    status = this->GetOrInsert(other_value, &unused);
  });
  return status;
}
};

Expand Down
56 changes: 56 additions & 0 deletions cpp/src/arrow/util/hashing_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

#include "arrow/array/builder_primitive.h"
#include "arrow/array/concatenate.h"
#include "arrow/memory_pool.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/hashing.h"
Expand Down Expand Up @@ -376,6 +377,32 @@ TEST(ScalarMemoTable, StressInt64) {
ASSERT_EQ(table.size(), map.size());
}

// Checks that ScalarMemoTable::MergeTable returns an insertion failure to its
// caller: the merge is expected to raise OutOfMemory.
TEST(ScalarMemoTable, MergeTablePropagatesInsertError) {
  // Step 1: record how many bytes a table holding the values 0..14 has
  // allocated through a ProxyMemoryPool.
  int64_t bytes_allocated_limit = 0;
  {
    ProxyMemoryPool probe(default_memory_pool());
    ScalarMemoTable<int64_t> target(&probe, 0);
    for (int64_t value = 0; value < 15; ++value) {
      AssertGetOrInsert(target, value, static_cast<int32_t>(value));
    }
    bytes_allocated_limit = probe.bytes_allocated();
  }
  ASSERT_GT(bytes_allocated_limit, 0);

  // Step 2: a source table with a single value (15) that the target lacks.
  ScalarMemoTable<int64_t> source(default_memory_pool(), 0);
  AssertGetOrInsert(source, 15, 0);

  // Step 3: rebuild the same target on a CappedMemoryPool constructed with the
  // recorded byte count, and confirm it has allocated exactly that much.
  ProxyMemoryPool proxy(default_memory_pool());
  CappedMemoryPool pool(&proxy, bytes_allocated_limit);
  ScalarMemoTable<int64_t> target(&pool, 0);
  for (int64_t value = 0; value < 15; ++value) {
    AssertGetOrInsert(target, value, static_cast<int32_t>(value));
  }
  ASSERT_EQ(proxy.bytes_allocated(), bytes_allocated_limit);

  // Step 4: merging the extra value must surface OutOfMemory from MergeTable.
  ASSERT_RAISES(OutOfMemory, target.MergeTable(source));
}

TEST(BinaryMemoTable, Basics) {
std::string A = "", B = "a", C = "foo", D = "bar", E, F;
E += '\0';
Expand Down Expand Up @@ -480,6 +507,35 @@ TEST(BinaryMemoTable, Stress) {
ASSERT_EQ(table.size(), map.size());
}

// Verifies that BinaryMemoTable::MergeTable hands an insertion failure back to
// its caller: the merge is expected to raise OutOfMemory.
TEST(BinaryMemoTable, MergeTablePropagatesInsertError) {
  const std::vector<std::string> seed_values = {"a", "bb", "ccc", "dddd"};
  const std::string oversized_value(4096, 'x');

  // Inserts every seed value in order, passing its position as the third
  // argument of AssertGetOrInsert (presumably the expected memo index).
  auto insert_seed_values = [&seed_values](BinaryMemoTable<BinaryBuilder>& table) {
    for (size_t pos = 0; pos < seed_values.size(); ++pos) {
      AssertGetOrInsert(table, seed_values[pos], static_cast<int32_t>(pos));
    }
  };

  // Record the bytes allocated by a table that holds only the seed values.
  // The probe pool and table are scoped so they are gone before the real run.
  int64_t byte_budget = 0;
  {
    ProxyMemoryPool measuring_pool(default_memory_pool());
    BinaryMemoTable<BinaryBuilder> measured_table(&measuring_pool, 0);
    insert_seed_values(measured_table);
    byte_budget = measuring_pool.bytes_allocated();
  }
  ASSERT_GT(byte_budget, 0);

  // Source table holding the single 4096-byte value to merge in.
  BinaryMemoTable<BinaryBuilder> merge_source(default_memory_pool(), 0);
  AssertGetOrInsert(merge_source, oversized_value, 0);

  // Same seed contents again, this time on a CappedMemoryPool constructed with
  // the recorded byte count.
  ProxyMemoryPool tracking_pool(default_memory_pool());
  CappedMemoryPool capped_pool(&tracking_pool, byte_budget);
  BinaryMemoTable<BinaryBuilder> merge_target(&capped_pool, 0);
  insert_seed_values(merge_target);
  ASSERT_EQ(tracking_pool.bytes_allocated(), byte_budget);

  // The merge must report OutOfMemory.
  ASSERT_RAISES(OutOfMemory, merge_target.MergeTable(merge_source));
}

TEST(BinaryMemoTable, Empty) {
BinaryMemoTable<BinaryBuilder> table(default_memory_pool());
ASSERT_EQ(table.size(), 0);
Expand Down