diff --git a/include/executor/operator.hpp b/include/executor/operator.hpp index 7e34592..e63738b 100644 --- a/include/executor/operator.hpp +++ b/include/executor/operator.hpp @@ -275,28 +275,42 @@ class AggregateOperator : public Operator { * @brief Hash join operator */ class HashJoinOperator : public Operator { + public: + using JoinType = cloudsql::executor::JoinType; + private: + struct BuildTuple { + Tuple tuple; + bool matched = false; + }; + std::unique_ptr left_; std::unique_ptr right_; std::unique_ptr left_key_; std::unique_ptr right_key_; + JoinType join_type_; Schema schema_; /* In-memory hash table for the right side */ - std::unordered_multimap hash_table_; + std::unordered_multimap hash_table_; /* Probe phase state */ std::optional left_tuple_; + bool left_had_match_ = false; struct MatchIterator { - std::unordered_multimap::iterator current; - std::unordered_multimap::iterator end; + std::unordered_multimap::iterator current; + std::unordered_multimap::iterator end; }; std::optional match_iter_; + /* Final phase for RIGHT/FULL joins */ + std::optional::iterator> right_idx_iter_; + public: HashJoinOperator(std::unique_ptr left, std::unique_ptr right, std::unique_ptr left_key, - std::unique_ptr right_key); + std::unique_ptr right_key, + JoinType join_type = JoinType::Inner); bool init() override; bool open() override; diff --git a/include/executor/types.hpp b/include/executor/types.hpp index eb03e46..3c5e384 100644 --- a/include/executor/types.hpp +++ b/include/executor/types.hpp @@ -25,6 +25,11 @@ namespace cloudsql::executor { */ enum class ExecState : uint8_t { Init, Open, Executing, Done, Error }; +/** + * @brief Supported join types for relation merging. + */ +enum class JoinType : uint8_t { Inner, Left, Right, Full }; + /** * @brief Supported aggregation functions for analytical queries. */ diff --git a/src/executor/operator.cpp b/src/executor/operator.cpp index b16efd6..a09ed66 100644 --- a/src/executor/operator.cpp +++ b/src/executor/operator.cpp @@ -565,19 +565,29 @@ Schema& AggregateOperator::output_schema() { HashJoinOperator::HashJoinOperator(std::unique_ptr left, std::unique_ptr right, std::unique_ptr left_key, - std::unique_ptr right_key) + std::unique_ptr right_key, + executor::JoinType join_type) : Operator(OperatorType::HashJoin, left->get_txn(), left->get_lock_manager()), left_(std::move(left)), right_(std::move(right)), left_key_(std::move(left_key)), - right_key_(std::move(right_key)) { + right_key_(std::move(right_key)), + join_type_(join_type) { /* Build resulting schema */ if (left_ && right_) { for (const auto& col : left_->output_schema().columns()) { - schema_.add_column(col); + auto col_meta = col; + if (join_type_ == executor::JoinType::Right || join_type_ == executor::JoinType::Full) { + col_meta.set_nullable(true); + } + schema_.add_column(col_meta); } for (const auto& col : right_->output_schema().columns()) { - schema_.add_column(col); + auto col_meta = col; + if (join_type_ == executor::JoinType::Left || join_type_ == executor::JoinType::Full) { + col_meta.set_nullable(true); + } + schema_.add_column(col_meta); } } } @@ -597,62 +607,107 @@ bool HashJoinOperator::open() { auto right_schema = right_->output_schema(); while (right_->next(right_tuple)) { const common::Value key = right_key_->evaluate(&right_tuple, &right_schema); - hash_table_.emplace(key.to_string(), std::move(right_tuple)); + hash_table_.emplace(key.to_string(), BuildTuple{std::move(right_tuple), false}); } left_tuple_ = std::nullopt; match_iter_ = std::nullopt; + left_had_match_ = false; + right_idx_iter_ = std::nullopt; set_state(ExecState::Open); return true; } bool HashJoinOperator::next(Tuple& out_tuple) { auto left_schema = left_->output_schema(); + auto right_schema = right_->output_schema(); while (true) { if (match_iter_.has_value()) { - /* We are currently iterating through matches for a left tuple */ auto& iter_state = match_iter_.value(); if (iter_state.current != iter_state.end) { - const auto& right_tuple = iter_state.current->second; - - /* Concatenate left and right tuples */ - if (left_tuple_.has_value()) { - std::vector joined_values = left_tuple_->values(); - joined_values.insert(joined_values.end(), right_tuple.values().begin(), - right_tuple.values().end()); + auto& build_tuple = iter_state.current->second; + const auto& right_tuple = build_tuple.tuple; + std::vector joined_values = left_tuple_->values(); + joined_values.insert(joined_values.end(), right_tuple.values().begin(), + right_tuple.values().end()); + + out_tuple = Tuple(std::move(joined_values)); + iter_state.current++; + left_had_match_ = true; + build_tuple.matched = true; + return true; + } - out_tuple = Tuple(std::move(joined_values)); - iter_state.current++; - return true; + /* No more matches for this left tuple. If (LEFT or FULL join) and no matches found, + * emit NULLs */ + match_iter_ = std::nullopt; + if ((join_type_ == JoinType::Left || join_type_ == JoinType::Full) && + !left_had_match_) { + std::vector joined_values = left_tuple_->values(); + for (size_t i = 0; i < right_schema.column_count(); ++i) { + joined_values.push_back(common::Value::make_null()); } + out_tuple = Tuple(std::move(joined_values)); + left_tuple_ = std::nullopt; + return true; } - /* No more matches for this left tuple */ - match_iter_ = std::nullopt; left_tuple_ = std::nullopt; } /* Pull next tuple from left side */ Tuple next_left; - if (!left_->next(next_left)) { - set_state(ExecState::Done); - return false; - } - - left_tuple_ = std::move(next_left); - if (left_tuple_.has_value()) { + if (left_->next(next_left)) { + left_tuple_ = std::move(next_left); + left_had_match_ = false; const common::Value key = left_key_->evaluate(&(left_tuple_.value()), &left_schema); /* Look up in hash table */ auto range = hash_table_.equal_range(key.to_string()); if (range.first != range.second) { match_iter_ = {range.first, range.second}; - /* Continue loop to return the first match */ + } else if (join_type_ == JoinType::Left || join_type_ == JoinType::Full) { + /* No match found immediately, emit NULLs if Left/Full join */ + std::vector joined_values = left_tuple_->values(); + for (size_t i = 0; i < right_schema.column_count(); ++i) { + joined_values.push_back(common::Value::make_null()); + } + out_tuple = Tuple(std::move(joined_values)); + left_tuple_ = std::nullopt; + return true; } else { - /* No match for this left tuple, pull next */ + /* Inner/Right join and no match, skip to next left tuple */ left_tuple_ = std::nullopt; } + continue; + } + + /* Probe phase done. For RIGHT or FULL joins, scan hash table for unmatched right tuples */ + if (join_type_ == JoinType::Right || join_type_ == JoinType::Full) { + if (!right_idx_iter_.has_value()) { + right_idx_iter_ = hash_table_.begin(); + } + + auto& it = right_idx_iter_.value(); + while (it != hash_table_.end()) { + if (!it->second.matched) { + std::vector joined_values; + for (size_t i = 0; i < left_schema.column_count(); ++i) { + joined_values.push_back(common::Value::make_null()); + } + joined_values.insert(joined_values.end(), it->second.tuple.values().begin(), + it->second.tuple.values().end()); + out_tuple = Tuple(std::move(joined_values)); + it->second.matched = true; /* Mark as emitted */ + it++; + return true; + } + it++; + } } + + set_state(ExecState::Done); + return false; } } diff --git a/src/executor/query_executor.cpp b/src/executor/query_executor.cpp index 25685b4..ebd31b0 100644 --- a/src/executor/query_executor.cpp +++ b/src/executor/query_executor.cpp @@ -649,9 +649,18 @@ std::unique_ptr QueryExecutor::build_plan(const parser::SelectStatemen } if (use_hash_join) { - current_root = - std::make_unique(std::move(current_root), std::move(join_scan), - std::move(left_key), std::move(right_key)); + executor::JoinType exec_join_type = executor::JoinType::Inner; + if (join.type == parser::SelectStatement::JoinType::Left) { + exec_join_type = executor::JoinType::Left; + } else if (join.type == parser::SelectStatement::JoinType::Right) { + exec_join_type = executor::JoinType::Right; + } else if (join.type == parser::SelectStatement::JoinType::Full) { + exec_join_type = executor::JoinType::Full; + } + + current_root = std::make_unique( + std::move(current_root), std::move(join_scan), std::move(left_key), + std::move(right_key), exec_join_type); } else { /* TODO: Implement NestedLoopJoin for non-equality or missing conditions */ return nullptr; diff --git a/src/network/server.cpp b/src/network/server.cpp index 3a422fb..3cc8f3b 100644 --- a/src/network/server.cpp +++ b/src/network/server.cpp @@ -414,12 +414,24 @@ void Server::handle_connection(int client_fd) { for (const auto& row : res.rows()) { const char d_type = 'D'; uint32_t d_len = 4 + 2; // len + num_cols - std::vector str_vals; + + struct ColValue { + bool is_null; + std::string val; + }; + std::vector col_vals; + col_vals.reserve(num_cols); + for (uint32_t i = 0; i < num_cols; ++i) { - const std::string s_val = row.get(i).to_string(); - str_vals.push_back(s_val); - d_len += - 4 + static_cast(s_val.size()); // len + value + const auto& v = row.get(i); + if (v.is_null()) { + col_vals.push_back({true, ""}); + d_len += 4; + } else { + std::string s_val = v.to_string(); + d_len += 4 + static_cast(s_val.size()); + col_vals.push_back({false, std::move(s_val)}); + } } const uint32_t net_d_len = htonl(d_len); @@ -427,12 +439,17 @@ void Server::handle_connection(int client_fd) { static_cast(send(client_fd, &net_d_len, 4, 0)); static_cast(send(client_fd, &net_num_cols, 2, 0)); - for (const auto& s_val : str_vals) { - const uint32_t val_len = - htonl(static_cast(s_val.size())); - static_cast(send(client_fd, &val_len, 4, 0)); - static_cast( - send(client_fd, s_val.c_str(), s_val.size(), 0)); + for (const auto& cv : col_vals) { + if (cv.is_null) { + const uint32_t null_len = 0xFFFFFFFF; + static_cast(send(client_fd, &null_len, 4, 0)); + } else { + const uint32_t val_len = + htonl(static_cast(cv.val.size())); + static_cast(send(client_fd, &val_len, 4, 0)); + static_cast( + send(client_fd, cv.val.c_str(), cv.val.size(), 0)); + } } } } diff --git a/src/parser/expression.cpp b/src/parser/expression.cpp index d796910..c347f0f 100644 --- a/src/parser/expression.cpp +++ b/src/parser/expression.cpp @@ -234,7 +234,16 @@ common::Value ColumnExpr::evaluate(const executor::Tuple* tuple, return common::Value::make_null(); } - const size_t index = schema->find_column(name_); + size_t index = static_cast(-1); + + /* 1. Try exact match (either fully qualified or just name) */ + index = schema->find_column(this->to_string()); + + /* 2. If not found and it's qualified, try just the column name */ + if (index == static_cast(-1) && has_table()) { + index = schema->find_column(name_); + } + if (index == static_cast(-1)) { return common::Value::make_null(); } @@ -245,7 +254,16 @@ common::Value ColumnExpr::evaluate(const executor::Tuple* tuple, void ColumnExpr::evaluate_vectorized(const executor::VectorBatch& batch, const executor::Schema& schema, executor::ColumnVector& result) const { - const size_t index = schema.find_column(name_); + size_t index = static_cast(-1); + + /* 1. Try exact match (either fully qualified or just name) */ + index = schema.find_column(this->to_string()); + + /* 2. If not found and it's qualified, try just the column name */ + if (index == static_cast(-1) && has_table()) { + index = schema.find_column(name_); + } + result.clear(); if (index == static_cast(-1)) { for (size_t i = 0; i < batch.row_count(); ++i) { diff --git a/tests/cloudSQL_tests.cpp b/tests/cloudSQL_tests.cpp index a1a02c9..82a5b9e 100644 --- a/tests/cloudSQL_tests.cpp +++ b/tests/cloudSQL_tests.cpp @@ -784,8 +784,9 @@ TEST(ParserAdvanced, JoinAndComplexSelect) { /* 1. Left Join and multiple joins */ { auto lexer = std::make_unique( - "SELECT a.id, b.val FROM t1 LEFT JOIN t2 ON a.id = b.id JOIN t3 ON b.x = t3.x WHERE " - "a.id > 10"); + "SELECT t1.id, t2.val FROM t1 LEFT JOIN t2 ON t1.id = t2.id JOIN t3 ON t2.x = t3.x " + "WHERE " + "t1.id > 10"); Parser parser(std::move(lexer)); auto stmt = parser.parse_statement(); ASSERT_NE(stmt, nullptr); diff --git a/tests/logic/aggregates.slt b/tests/logic/aggregates.slt new file mode 100644 index 0000000..fe4758e --- /dev/null +++ b/tests/logic/aggregates.slt @@ -0,0 +1,45 @@ +# Aggregate and Group By Tests + +statement ok +CREATE TABLE agg_test (grp TEXT, val INT); + +statement ok +INSERT INTO agg_test VALUES ('A', 10), ('A', 20), ('B', 5), ('B', 15), ('C', 100); + +# Basic Aggregates +query IIII +SELECT SUM(val), COUNT(val), MIN(val), MAX(val) FROM agg_test; +---- +150 5 5 100 + +# Group By +query TI +SELECT grp, SUM(val) FROM agg_test GROUP BY grp ORDER BY grp; +---- +A 30 +B 20 +C 100 + +# Group By with Filter +query TI +SELECT grp, COUNT(val) FROM agg_test WHERE val > 10 GROUP BY grp ORDER BY grp; +---- +A 1 +B 1 +C 1 + +# Having Clause +query TI +SELECT grp, SUM(val) FROM agg_test GROUP BY grp HAVING SUM(val) > 25 ORDER BY grp; +---- +A 30 +C 100 + +# Average (Real) +query R +SELECT AVG(val) FROM agg_test WHERE grp = 'A'; +---- +15.0 + +statement ok +DROP TABLE agg_test; diff --git a/tests/logic/basic.slt b/tests/logic/basic.slt new file mode 100644 index 0000000..40f9d7c --- /dev/null +++ b/tests/logic/basic.slt @@ -0,0 +1,57 @@ +# Basic SLT Test + +statement ok +CREATE TABLE test_slt (id INT, name TEXT, val DOUBLE); + +statement ok +INSERT INTO test_slt VALUES (1, 'Alice', 10.5); + +statement ok +INSERT INTO test_slt VALUES (2, 'Bob', 20.0); + +statement ok +INSERT INTO test_slt VALUES (3, 'Charlie', 30.75); + +# Basic Select +query ITR +SELECT id, name, val FROM test_slt ORDER BY id; +---- +1 Alice 10.5 +2 Bob 20.0 +3 Charlie 30.75 + +# Filtered Select +query T +SELECT name FROM test_slt WHERE id = 2; +---- +Bob + +# Null handling (if supported by server responses) +statement ok +INSERT INTO test_slt VALUES (4, 'NULL_TEST', NULL); + +query I +SELECT id FROM test_slt WHERE val IS NULL; +---- +4 + +# Update +statement ok +UPDATE test_slt SET val = 15.0 WHERE id = 1; + +query R +SELECT val FROM test_slt WHERE id = 1; +---- +15.0 + +# Delete +statement ok +DELETE FROM test_slt WHERE id = 3; + +query I +SELECT COUNT(id) FROM test_slt; +---- +3 + +statement ok +DROP TABLE test_slt; diff --git a/tests/logic/expressions.slt b/tests/logic/expressions.slt new file mode 100644 index 0000000..359c524 --- /dev/null +++ b/tests/logic/expressions.slt @@ -0,0 +1,49 @@ +# Expression Logic Tests + +statement ok +CREATE TABLE expr_test (a INT, b INT, c DOUBLE); + +statement ok +INSERT INTO expr_test VALUES (10, 20, 30.5), (5, 5, 5.0), (0, 100, -1.5); + +# Arithmetic +query R +SELECT a + b + c FROM expr_test ORDER BY a; +---- +98.5 +15.0 +60.5 + +# Comparison & Logic +query I +SELECT a FROM expr_test WHERE a < b AND b > 50; +---- +0 + +query I +SELECT COUNT(a) FROM expr_test WHERE (a + b) = 10; +---- +1 + +# NULL Logic +statement ok +INSERT INTO expr_test VALUES (NULL, 1, 1.0); + +query I +SELECT b FROM expr_test WHERE a IS NULL; +---- +1 + +query I +SELECT COUNT(b) FROM expr_test WHERE a IS NOT NULL; +---- +3 + +# Complex Binary +query I +SELECT a FROM expr_test WHERE a * 2 = b; +---- +10 + +statement ok +DROP TABLE expr_test; diff --git a/tests/logic/joins.slt b/tests/logic/joins.slt new file mode 100644 index 0000000..eee732d --- /dev/null +++ b/tests/logic/joins.slt @@ -0,0 +1,42 @@ +# Join Operations Tests + +statement ok +CREATE TABLE users_j (id INT, name TEXT); + +statement ok +CREATE TABLE orders_j (id INT, user_id INT, amount DOUBLE); + +statement ok +INSERT INTO users_j VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'); + +statement ok +INSERT INTO orders_j VALUES (101, 1, 50.0), (102, 1, 25.0), (103, 2, 100.0); + +# Inner Join +query TR +SELECT users_j.name, orders_j.amount FROM users_j JOIN orders_j ON users_j.id = orders_j.user_id ORDER BY orders_j.amount; +---- +Alice 25.0 +Alice 50.0 +Bob 100.0 + +# Join with where +query TI +SELECT users_j.name, orders_j.id FROM users_j JOIN orders_j ON users_j.id = orders_j.user_id WHERE orders_j.amount > 60; +---- +Bob 103 + +# Left Join (Charlie has no orders) +query TR +SELECT users_j.name, orders_j.amount FROM users_j LEFT JOIN orders_j ON users_j.id = orders_j.user_id ORDER BY users_j.name, orders_j.amount; +---- +Alice 25.0 +Alice 50.0 +Bob 100.0 +Charlie NULL + +statement ok +DROP TABLE users_j; + +statement ok +DROP TABLE orders_j; diff --git a/tests/logic/slt_runner.py b/tests/logic/slt_runner.py new file mode 100644 index 0000000..7c19f61 --- /dev/null +++ b/tests/logic/slt_runner.py @@ -0,0 +1,226 @@ +import socket +import struct +import sys +import time +import math + +PROTOCOL_VERSION_3 = 196608 + +class CloudSQLClient: + def __init__(self, host='127.0.0.1', port=5432): + self.host = host + self.port = port + self.sock = None + + def connect(self): + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.settimeout(5.0) + self.sock.connect((self.host, self.port)) + + length = 8 + packet = struct.pack('!II', length, PROTOCOL_VERSION_3) + self.sock.sendall(packet) + + try: + r_type = self.recv_exactly(1) + if r_type != b'R': + raise Exception(f"Expected AuthOK 'R', got {r_type}") + self.recv_exactly(8) + + z_type = self.recv_exactly(1) + if z_type != b'Z': + raise Exception(f"Expected ReadyForQuery 'Z', got {z_type}") + self.recv_exactly(5) + except Exception as e: + raise Exception(f"Handshake failed: {e}") + + def recv_exactly(self, n): + data = b'' + while len(data) < n: + packet = self.sock.recv(n - len(data)) + if not packet: + return None + data += packet + return data + + def query(self, sql): + sql_bytes = sql.encode('utf-8') + b'\0' + length = 4 + len(sql_bytes) + packet = b'Q' + struct.pack('!I', length) + sql_bytes + self.sock.sendall(packet) + + rows = [] + status = "OK" + + while True: + type_byte = self.recv_exactly(1) + if not type_byte: + break + + type_char = type_byte.decode() + len_bytes = self.recv_exactly(4) + if not len_bytes: + break + length = struct.unpack('!I', len_bytes)[0] + body = self.recv_exactly(length - 4) + + if type_char == 'D': + num_cols = struct.unpack('!h', body[:2])[0] + idx = 2 + row_data = [] + for _ in range(num_cols): + col_len = struct.unpack('!I', body[idx:idx+4])[0] + idx += 4 + if col_len == 0xFFFFFFFF: + row_data.append(None) + else: + val = body[idx:idx+col_len].decode('utf-8') + row_data.append(val) + idx += col_len + rows.append(row_data) + elif type_char == 'C': + pass # CommandComplete + elif type_char == 'E': + status = "ERROR" + elif type_char == 'Z': + break + + return rows, status + +def run_slt(file_path, port): + client = CloudSQLClient(port=port) + client.connect() + + with open(file_path, 'r') as f: + lines = f.readlines() + + line_idx = 0 + total_tests = 0 + failed_tests = 0 + + while line_idx < len(lines): + line = lines[line_idx].strip() + if not line or line.startswith('#'): + line_idx += 1 + continue + + if line.startswith('statement'): + expected_status = line.split()[1] # ok or error + sql_lines = [] + line_idx += 1 + while line_idx < len(lines) and lines[line_idx].strip(): + sql_lines.append(lines[line_idx].strip()) + line_idx += 1 + + sql = " ".join(sql_lines) + total_tests += 1 + _, actual_status = client.query(sql) + + if actual_status.lower() != expected_status.lower(): + print(f"FAILURE at {file_path}:{line_idx}") + print(f" SQL: {sql}") + print(f" Expected status: {expected_status}, got: {actual_status}") + failed_tests += 1 + + elif line.startswith('query'): + # query [sort] + parts = line.split() + types = parts[1] + sort_mode = parts[2] if len(parts) > 2 else None + + sql_lines = [] + line_idx += 1 + while line_idx < len(lines) and lines[line_idx].strip() != '----': + sql_lines.append(lines[line_idx].strip()) + line_idx += 1 + + sql = " ".join(sql_lines) + line_idx += 1 # skip '----' + + expected_rows = [] + while line_idx < len(lines) and lines[line_idx].strip(): + expected_rows.append(lines[line_idx].strip().split()) + line_idx += 1 + + total_tests += 1 + actual_rows, status = client.query(sql) + + if status == "ERROR": + print(f"FAILURE at {file_path}:{line_idx}") + print(f" SQL: {sql}") + print(f" Query failed with ERROR status") + failed_tests += 1 + continue + + # Apply sort mode + if sort_mode == 'rowsort': + actual_rows.sort() + expected_rows.sort() + elif sort_mode == 'valuesort': + actual_values = sorted([str(val) if val is not None else "NULL" for row in actual_rows for val in row]) + expected_values = sorted([val for row in expected_rows for val in row]) + actual_rows = [[v] for v in actual_values] + expected_rows = [[v] for v in expected_values] + elif sort_mode: + print(f"ERROR: Unsupported sort mode: {sort_mode}") + sys.exit(1) + + # Compare results + if len(actual_rows) != len(expected_rows): + print(f"FAILURE at {file_path}:{line_idx}") + print(f" SQL: {sql}") + print(f" Expected {len(expected_rows)} rows, got {len(actual_rows)}") + failed_tests += 1 + continue + + for i in range(len(actual_rows)): + if len(actual_rows[i]) != len(expected_rows[i]): + print(f"FAILURE at {file_path}:{line_idx}, row {i}") + print(f" Expected {len(expected_rows[i])} columns, got {len(actual_rows[i])}") + failed_tests += 1 + break + + match = True + for j in range(len(actual_rows[i])): + act = actual_rows[i][j] + exp = expected_rows[i][j] + + if exp == "NULL" and act is None: + continue + + # Basic numeric normalization for float comparison + if types[j] == 'R' and sort_mode != 'valuesort': + try: + if not math.isclose(float(act), float(exp), rel_tol=1e-6): + match = False + except: + match = False + else: + if str(act) != str(exp): + match = False + + if not match: + print(f"FAILURE at {file_path}:{line_idx}, row {i} col {j}") + print(f" Expected '{exp}', got '{act}'") + failed_tests += 1 + break + if not match: break + + else: + line_idx += 1 + + print(f"SLT Summary: {total_tests} tests, {failed_tests} failed.") + return failed_tests == 0 + +if __name__ == "__main__": + if len(sys.argv) < 3: + print("Usage: python3 slt_runner.py ") + sys.exit(1) + + port = int(sys.argv[1]) + file_path = sys.argv[2] + + if run_slt(file_path, port): + sys.exit(0) + else: + sys.exit(1) diff --git a/tests/run_test.sh b/tests/run_test.sh index 0bb200e..12d7879 100755 --- a/tests/run_test.sh +++ b/tests/run_test.sh @@ -1,23 +1,60 @@ -#!/usr/bin/env bash -# cleanup function to ensure background cloudSQL process is terminated +#!/bin/bash + +# cloudSQL E2E and Logic Test Runner +# This script builds the engine and runs both Python E2E tests and SLT logic tests. + cleanup() { - if [ -n "$SQL_PID" ]; then - kill $SQL_PID 2>/dev/null || true - wait $SQL_PID 2>/dev/null || true + echo "Shutting down..." + if [ ! -z "$SQL_PID" ]; then + kill $SQL_PID 2>/dev/null fi } # Trap exit, interrupt and error signals trap cleanup EXIT INT ERR -rm -rf ../test_data || true -mkdir -p ../test_data -cd ../build -make -j4 -./cloudSQL -p 5438 -d ../test_data & +# Resolve absolute paths +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +BUILD_DIR="$ROOT_DIR/build" +TEST_DATA_DIR="$ROOT_DIR/test_data" + +rm -rf "$TEST_DATA_DIR" || true +mkdir -p "$TEST_DATA_DIR" + +# Detect CPU count for parallel make +if command -v nproc >/dev/null 2>&1; then + CPU_COUNT=$(nproc) +elif command -v sysctl >/dev/null 2>&1; then + CPU_COUNT=$(sysctl -n hw.ncpu) +elif command -v getconf >/dev/null 2>&1; then + CPU_COUNT=$(getconf _NPROCESSORS_ONLN) +else + CPU_COUNT=1 +fi + +echo "Detected $CPU_COUNT CPUs, building with -j$CPU_COUNT" + +cd "$BUILD_DIR" || exit 1 +make -j"$CPU_COUNT" +./cloudSQL -p 5438 -d "$TEST_DATA_DIR" & SQL_PID=$! sleep 2 -echo "Running E2E" -python3 ../tests/e2e/e2e_test.py 5438 + +echo "--- Running E2E Tests ---" +python3 "$ROOT_DIR/tests/e2e/e2e_test.py" 5438 RET=$? + +if [ $RET -eq 0 ]; then + echo "--- Running SLT Logic Tests ---" + for slt_file in "$ROOT_DIR"/tests/logic/*.slt; do + echo "Running $slt_file..." + python3 "$ROOT_DIR/tests/logic/slt_runner.py" 5438 "$slt_file" + SLT_RET=$? + if [ $SLT_RET -ne 0 ]; then + RET=$SLT_RET + break + fi + done +fi + exit $RET