Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions include/executor/operator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,28 +275,42 @@ class AggregateOperator : public Operator {
* @brief Hash join operator
*/
class HashJoinOperator : public Operator {
public:
using JoinType = cloudsql::executor::JoinType;

private:
struct BuildTuple {
Tuple tuple;
bool matched = false;
};

std::unique_ptr<Operator> left_;
std::unique_ptr<Operator> right_;
std::unique_ptr<parser::Expression> left_key_;
std::unique_ptr<parser::Expression> right_key_;
JoinType join_type_;
Schema schema_;

/* In-memory hash table for the right side */
std::unordered_multimap<std::string, Tuple> hash_table_;
std::unordered_multimap<std::string, BuildTuple> hash_table_;

/* Probe phase state */
std::optional<Tuple> left_tuple_;
bool left_had_match_ = false;
struct MatchIterator {
std::unordered_multimap<std::string, Tuple>::iterator current;
std::unordered_multimap<std::string, Tuple>::iterator end;
std::unordered_multimap<std::string, BuildTuple>::iterator current;
std::unordered_multimap<std::string, BuildTuple>::iterator end;
};
std::optional<MatchIterator> match_iter_;

/* Final phase for RIGHT/FULL joins */
std::optional<std::unordered_multimap<std::string, BuildTuple>::iterator> right_idx_iter_;

public:
HashJoinOperator(std::unique_ptr<Operator> left, std::unique_ptr<Operator> right,
std::unique_ptr<parser::Expression> left_key,
std::unique_ptr<parser::Expression> right_key);
std::unique_ptr<parser::Expression> right_key,
JoinType join_type = JoinType::Inner);

bool init() override;
bool open() override;
Expand Down
5 changes: 5 additions & 0 deletions include/executor/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ namespace cloudsql::executor {
*/
enum class ExecState : uint8_t { Init, Open, Executing, Done, Error };

/**
 * @brief Kinds of join a join operator can be asked to perform.
 *
 * The enumerator order (and therefore the underlying values 0..3) is part of
 * the type's observable layout; append new kinds at the end.
 */
enum class JoinType : uint8_t {
    Inner, /**< Emit only pairs of rows whose join keys match. */
    Left,  /**< Inner results plus unmatched left rows padded with NULLs. */
    Right, /**< Inner results plus unmatched right rows padded with NULLs. */
    Full,  /**< Inner results plus unmatched rows from both sides. */
};

/**
* @brief Supported aggregation functions for analytical queries.
*/
Expand Down
109 changes: 82 additions & 27 deletions src/executor/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,19 +565,29 @@ Schema& AggregateOperator::output_schema() {

HashJoinOperator::HashJoinOperator(std::unique_ptr<Operator> left, std::unique_ptr<Operator> right,
std::unique_ptr<parser::Expression> left_key,
std::unique_ptr<parser::Expression> right_key)
std::unique_ptr<parser::Expression> right_key,
executor::JoinType join_type)
: Operator(OperatorType::HashJoin, left->get_txn(), left->get_lock_manager()),
left_(std::move(left)),
right_(std::move(right)),
left_key_(std::move(left_key)),
right_key_(std::move(right_key)) {
right_key_(std::move(right_key)),
join_type_(join_type) {
/* Build resulting schema */
if (left_ && right_) {
for (const auto& col : left_->output_schema().columns()) {
schema_.add_column(col);
auto col_meta = col;
if (join_type_ == executor::JoinType::Right || join_type_ == executor::JoinType::Full) {
col_meta.set_nullable(true);
}
schema_.add_column(col_meta);
}
for (const auto& col : right_->output_schema().columns()) {
schema_.add_column(col);
auto col_meta = col;
if (join_type_ == executor::JoinType::Left || join_type_ == executor::JoinType::Full) {
col_meta.set_nullable(true);
}
schema_.add_column(col_meta);
}
}
}
Expand All @@ -597,62 +607,107 @@ bool HashJoinOperator::open() {
auto right_schema = right_->output_schema();
while (right_->next(right_tuple)) {
const common::Value key = right_key_->evaluate(&right_tuple, &right_schema);
hash_table_.emplace(key.to_string(), std::move(right_tuple));
hash_table_.emplace(key.to_string(), BuildTuple{std::move(right_tuple), false});
Comment on lines 609 to +610
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Do not hash SQL NULL join keys as ordinary values.

Lines 609-610 and 661-665 route join keys through to_string(), so a NULL key becomes just another hash entry. That makes NULL = NULL match here and can even collide with a real text key like "NULL", which breaks join semantics and prevents the unmatched-row paths from firing correctly in LEFT/RIGHT/FULL joins.

Also applies to: 661-675

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/executor/operator.cpp` around lines 609 - 610, The code currently calls
right_key_->evaluate(...) and then uses key.to_string() to build the hash_table_
entry, which treats SQL NULL like an ordinary string and causes NULL=NULL
matches; change both the build side (where hash_table_.emplace(...) is called
with key.to_string()) and the probe/lookup side (the symmetric block around the
661-675 region) to detect SQL NULL via the common::Value null-check (e.g.,
key.is_null() or equivalent) and handle it specially: do not convert NULL to a
string or insert it into the normal string-keyed hash (skip inserting NULL keys
or store them using a distinct null marker/optional type), and ensure lookups
treat NULL keys as non-matching so NULLs never match each other or collide with
the literal "NULL" string. Ensure you update both BuildTuple insertion sites and
the corresponding lookup logic to use the same null-handling convention.

}

left_tuple_ = std::nullopt;
match_iter_ = std::nullopt;
left_had_match_ = false;
right_idx_iter_ = std::nullopt;
set_state(ExecState::Open);
return true;
}

bool HashJoinOperator::next(Tuple& out_tuple) {
auto left_schema = left_->output_schema();
auto right_schema = right_->output_schema();

while (true) {
if (match_iter_.has_value()) {
/* We are currently iterating through matches for a left tuple */
auto& iter_state = match_iter_.value();
if (iter_state.current != iter_state.end) {
const auto& right_tuple = iter_state.current->second;

/* Concatenate left and right tuples */
if (left_tuple_.has_value()) {
std::vector<common::Value> joined_values = left_tuple_->values();
joined_values.insert(joined_values.end(), right_tuple.values().begin(),
right_tuple.values().end());
auto& build_tuple = iter_state.current->second;
const auto& right_tuple = build_tuple.tuple;
std::vector<common::Value> joined_values = left_tuple_->values();
joined_values.insert(joined_values.end(), right_tuple.values().begin(),
right_tuple.values().end());

out_tuple = Tuple(std::move(joined_values));
iter_state.current++;
left_had_match_ = true;
build_tuple.matched = true;
return true;
}

out_tuple = Tuple(std::move(joined_values));
iter_state.current++;
return true;
/* No more matches for this left tuple. If (LEFT or FULL join) and no matches found,
* emit NULLs */
match_iter_ = std::nullopt;
if ((join_type_ == JoinType::Left || join_type_ == JoinType::Full) &&
!left_had_match_) {
std::vector<common::Value> joined_values = left_tuple_->values();
for (size_t i = 0; i < right_schema.column_count(); ++i) {
joined_values.push_back(common::Value::make_null());
}
out_tuple = Tuple(std::move(joined_values));
left_tuple_ = std::nullopt;
return true;
}
/* No more matches for this left tuple */
match_iter_ = std::nullopt;
left_tuple_ = std::nullopt;
}

/* Pull next tuple from left side */
Tuple next_left;
if (!left_->next(next_left)) {
set_state(ExecState::Done);
return false;
}

left_tuple_ = std::move(next_left);
if (left_tuple_.has_value()) {
if (left_->next(next_left)) {
left_tuple_ = std::move(next_left);
left_had_match_ = false;
const common::Value key = left_key_->evaluate(&(left_tuple_.value()), &left_schema);

/* Look up in hash table */
auto range = hash_table_.equal_range(key.to_string());
if (range.first != range.second) {
match_iter_ = {range.first, range.second};
/* Continue loop to return the first match */
} else if (join_type_ == JoinType::Left || join_type_ == JoinType::Full) {
/* No match found immediately, emit NULLs if Left/Full join */
std::vector<common::Value> joined_values = left_tuple_->values();
for (size_t i = 0; i < right_schema.column_count(); ++i) {
joined_values.push_back(common::Value::make_null());
}
out_tuple = Tuple(std::move(joined_values));
left_tuple_ = std::nullopt;
return true;
} else {
/* No match for this left tuple, pull next */
/* Inner/Right join and no match, skip to next left tuple */
left_tuple_ = std::nullopt;
}
continue;
}

/* Probe phase done. For RIGHT or FULL joins, scan hash table for unmatched right tuples */
if (join_type_ == JoinType::Right || join_type_ == JoinType::Full) {
if (!right_idx_iter_.has_value()) {
right_idx_iter_ = hash_table_.begin();
}

auto& it = right_idx_iter_.value();
while (it != hash_table_.end()) {
if (!it->second.matched) {
std::vector<common::Value> joined_values;
for (size_t i = 0; i < left_schema.column_count(); ++i) {
joined_values.push_back(common::Value::make_null());
}
joined_values.insert(joined_values.end(), it->second.tuple.values().begin(),
it->second.tuple.values().end());
out_tuple = Tuple(std::move(joined_values));
it->second.matched = true; /* Mark as emitted */
it++;
return true;
}
it++;
}
}

set_state(ExecState::Done);
return false;
}
}

Expand Down
15 changes: 12 additions & 3 deletions src/executor/query_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,9 +649,18 @@ std::unique_ptr<Operator> QueryExecutor::build_plan(const parser::SelectStatemen
}

if (use_hash_join) {
current_root =
std::make_unique<HashJoinOperator>(std::move(current_root), std::move(join_scan),
std::move(left_key), std::move(right_key));
executor::JoinType exec_join_type = executor::JoinType::Inner;
if (join.type == parser::SelectStatement::JoinType::Left) {
exec_join_type = executor::JoinType::Left;
} else if (join.type == parser::SelectStatement::JoinType::Right) {
exec_join_type = executor::JoinType::Right;
} else if (join.type == parser::SelectStatement::JoinType::Full) {
exec_join_type = executor::JoinType::Full;
}

current_root = std::make_unique<HashJoinOperator>(
std::move(current_root), std::move(join_scan), std::move(left_key),
std::move(right_key), exec_join_type);
} else {
/* TODO: Implement NestedLoopJoin for non-equality or missing conditions */
return nullptr;
Expand Down
39 changes: 28 additions & 11 deletions src/network/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,25 +414,42 @@ void Server::handle_connection(int client_fd) {
for (const auto& row : res.rows()) {
const char d_type = 'D';
uint32_t d_len = 4 + 2; // len + num_cols
std::vector<std::string> str_vals;

struct ColValue {
bool is_null;
std::string val;
};
std::vector<ColValue> col_vals;
col_vals.reserve(num_cols);

for (uint32_t i = 0; i < num_cols; ++i) {
const std::string s_val = row.get(i).to_string();
str_vals.push_back(s_val);
d_len +=
4 + static_cast<uint32_t>(s_val.size()); // len + value
const auto& v = row.get(i);
if (v.is_null()) {
col_vals.push_back({true, ""});
d_len += 4;
} else {
std::string s_val = v.to_string();
d_len += 4 + static_cast<uint32_t>(s_val.size());
col_vals.push_back({false, std::move(s_val)});
}
}

const uint32_t net_d_len = htonl(d_len);
static_cast<void>(send(client_fd, &d_type, 1, 0));
static_cast<void>(send(client_fd, &net_d_len, 4, 0));
static_cast<void>(send(client_fd, &net_num_cols, 2, 0));

for (const auto& s_val : str_vals) {
const uint32_t val_len =
htonl(static_cast<uint32_t>(s_val.size()));
static_cast<void>(send(client_fd, &val_len, 4, 0));
static_cast<void>(
send(client_fd, s_val.c_str(), s_val.size(), 0));
for (const auto& cv : col_vals) {
if (cv.is_null) {
const uint32_t null_len = 0xFFFFFFFF;
static_cast<void>(send(client_fd, &null_len, 4, 0));
} else {
const uint32_t val_len =
htonl(static_cast<uint32_t>(cv.val.size()));
static_cast<void>(send(client_fd, &val_len, 4, 0));
static_cast<void>(
send(client_fd, cv.val.c_str(), cv.val.size(), 0));
}
}
}
}
Expand Down
22 changes: 20 additions & 2 deletions src/parser/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,16 @@ common::Value ColumnExpr::evaluate(const executor::Tuple* tuple,
return common::Value::make_null();
}

const size_t index = schema->find_column(name_);
size_t index = static_cast<size_t>(-1);

/* 1. Try exact match (either fully qualified or just name) */
index = schema->find_column(this->to_string());

/* 2. If not found and it's qualified, try just the column name */
if (index == static_cast<size_t>(-1) && has_table()) {
index = schema->find_column(name_);
}

if (index == static_cast<size_t>(-1)) {
return common::Value::make_null();
}
Expand All @@ -245,7 +254,16 @@ common::Value ColumnExpr::evaluate(const executor::Tuple* tuple,
void ColumnExpr::evaluate_vectorized(const executor::VectorBatch& batch,
const executor::Schema& schema,
executor::ColumnVector& result) const {
const size_t index = schema.find_column(name_);
size_t index = static_cast<size_t>(-1);

/* 1. Try exact match (either fully qualified or just name) */
index = schema.find_column(this->to_string());

/* 2. If not found and it's qualified, try just the column name */
if (index == static_cast<size_t>(-1) && has_table()) {
index = schema.find_column(name_);
}

result.clear();
if (index == static_cast<size_t>(-1)) {
for (size_t i = 0; i < batch.row_count(); ++i) {
Expand Down
5 changes: 3 additions & 2 deletions tests/cloudSQL_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,8 +784,9 @@ TEST(ParserAdvanced, JoinAndComplexSelect) {
/* 1. Left Join and multiple joins */
{
auto lexer = std::make_unique<Lexer>(
"SELECT a.id, b.val FROM t1 LEFT JOIN t2 ON a.id = b.id JOIN t3 ON b.x = t3.x WHERE "
"a.id > 10");
"SELECT t1.id, t2.val FROM t1 LEFT JOIN t2 ON t1.id = t2.id JOIN t3 ON t2.x = t3.x "
"WHERE "
"t1.id > 10");
Parser parser(std::move(lexer));
auto stmt = parser.parse_statement();
ASSERT_NE(stmt, nullptr);
Expand Down
45 changes: 45 additions & 0 deletions tests/logic/aggregates.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Aggregate and Group By Tests
#
# Exercises SUM/COUNT/MIN/MAX/AVG, GROUP BY, WHERE combined with GROUP BY, and
# HAVING against a single five-row fixture table that is created at the top of
# this file and dropped at the bottom.
#
# NOTE(review): the strings after `query` (IIII, TI, R) look like sqllogictest
# column-type codes (I = integer, T = text, R = real) - confirm against this
# project's test runner.

# Fixture table: a text grouping column and an integer value column.
statement ok
CREATE TABLE agg_test (grp TEXT, val INT);

# Five rows in three groups: A = {10, 20}, B = {5, 15}, C = {100}.
statement ok
INSERT INTO agg_test VALUES ('A', 10), ('A', 20), ('B', 5), ('B', 15), ('C', 100);

# Basic Aggregates
# No GROUP BY, so a single result row covers all five values:
# sum = 150, count = 5, min = 5, max = 100.
query IIII
SELECT SUM(val), COUNT(val), MIN(val), MAX(val) FROM agg_test;
----
150 5 5 100

# Group By
# One row per group; ORDER BY grp keeps the expected output order deterministic.
query TI
SELECT grp, SUM(val) FROM agg_test GROUP BY grp ORDER BY grp;
----
A 30
B 20
C 100

# Group By with Filter
# WHERE val > 10 leaves exactly one row in each group (A:20, B:15, C:100),
# so every group is expected to report a count of 1.
query TI
SELECT grp, COUNT(val) FROM agg_test WHERE val > 10 GROUP BY grp ORDER BY grp;
----
A 1
B 1
C 1

# Having Clause
# Group B (sum 20) does not satisfy SUM(val) > 25 and is expected to be dropped.
query TI
SELECT grp, SUM(val) FROM agg_test GROUP BY grp HAVING SUM(val) > 25 ORDER BY grp;
----
A 30
C 100

# Average (Real)
# Group A holds 10 and 20, so the expected mean is 15.
# NOTE(review): the exact rendering "15.0" depends on how the runner formats
# real values - confirm if this case fails on formatting alone.
query R
SELECT AVG(val) FROM agg_test WHERE grp = 'A';
----
15.0

# Teardown: remove the fixture table created at the top of the file.
statement ok
DROP TABLE agg_test;
Loading