Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions src/abi/ace_exports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5976,9 +5976,12 @@ UNSIGNED32 AdsAppendRecord(ADSHANDLE hTable) {
}
#endif
#if defined(OPENADS_WITH_MSSQL)
if (get_mssql_table(hTable)) {
return fail(openads::AE_FUNCTION_NOT_AVAILABLE,
"MssqlTable: write not available in v1");
if (auto* st = get_mssql_table(hTable)) {
if (st->conn == nullptr)
return fail(openads::AE_INVALID_CONNECTION_HANDLE, "");
auto r = st->conn->append_blank(st);
if (!r) return fail(r.error());
return ok();
}
#endif
Table* t = get_table(hTable);
Expand Down Expand Up @@ -6039,9 +6042,12 @@ UNSIGNED32 AdsWriteRecord(ADSHANDLE hTable) {
}
#endif
#if defined(OPENADS_WITH_MSSQL)
if (get_mssql_table(hTable)) {
return fail(openads::AE_FUNCTION_NOT_AVAILABLE,
"MssqlTable: write not available in v1");
if (auto* st = get_mssql_table(hTable)) {
if (st->conn == nullptr)
return fail(openads::AE_INVALID_CONNECTION_HANDLE, "");
auto r = st->conn->flush_record(st);
if (!r) return fail(r.error());
return ok();
}
#endif
Table* t = get_table(hTable);
Expand Down Expand Up @@ -6139,9 +6145,12 @@ UNSIGNED32 AdsDeleteRecord(ADSHANDLE hTable) {
}
#endif
#if defined(OPENADS_WITH_MSSQL)
if (get_mssql_table(hTable)) {
return fail(openads::AE_FUNCTION_NOT_AVAILABLE,
"MssqlTable: write not available in v1");
if (auto* st = get_mssql_table(hTable)) {
if (st->conn == nullptr)
return fail(openads::AE_INVALID_CONNECTION_HANDLE, "");
auto r = st->conn->delete_record(st);
if (!r) return fail(r.error());
return ok();
}
#endif
Table* t = get_table(hTable);
Expand Down Expand Up @@ -6306,6 +6315,20 @@ UNSIGNED32 AdsSetString(ADSHANDLE hTable, UNSIGNED8* pucField,
if (!r) return fail(r.error());
return ok();
}
#endif
#if defined(OPENADS_WITH_MSSQL)
if (auto* st = get_mssql_table(hTable)) {
if (pucField == nullptr) return fail(openads::AE_INTERNAL_ERROR, "");
if (st->conn == nullptr)
return fail(openads::AE_INVALID_CONNECTION_HANDLE, "");
std::string fname(reinterpret_cast<const char*>(pucField));
std::string val;
if (pucValue != nullptr && ulLen > 0)
val.assign(reinterpret_cast<const char*>(pucValue), ulLen);
auto r = st->conn->set_field(st, fname, val);
if (!r) return fail(r.error());
return ok();
}
#endif
Table* t = get_table(hTable);
if (!t) return fail(openads::AE_INTERNAL_ERROR, "unknown table");
Expand Down
215 changes: 215 additions & 0 deletions src/sql_backend/mssql_connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,68 @@
#if defined(OPENADS_WITH_MSSQL)

#include "openads/error.h"
#include "sql_backend/mssql_table.h"
#include "sql_backend/mssql_uri.h"
#include "sql_backend/tds_protocol.h"

#include <cctype>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

namespace openads::sql_backend {

namespace {

// [name] with ']' doubled — safe SQL Server identifier quoting.
std::string quote_ident(const std::string& name) {
std::string out = "[";
for (char c : name) { if (c == ']') out += ']'; out += c; }
out += ']';
return out;
}
Comment on lines +21 to +26

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The quote_ident function currently wraps the entire identifier in brackets (e.g., [dbo.CLIENTES]). In SQL Server, schema-qualified identifiers must be quoted individually (e.g., [dbo].[CLIENTES]). Wrapping the entire string with a dot inside brackets will cause SQL Server to look for a table literally named dbo.CLIENTES in the default schema, which will fail. We should split the identifier by . and quote each part individually.

std::string quote_ident(const std::string& name) {
    std::string out;
    std::string part;
    for (char c : name) {
        if (c == '.') {
            if (!part.empty()) {
                out += "[";
                for (char pc : part) { if (pc == ']') out += ']'; out += pc; }
                out += "].";
                part.clear();
            }
        } else {
            part += c;
        }
    }
    if (!part.empty()) {
        out += "[";
        for (char pc : part) { if (pc == ']') out += ']'; out += pc; }
        out += "]";
    }
    return out;
}


// N'...' literal with '\'' doubled. All staged values are bound as Unicode
// string literals; SQL Server implicit-converts to the target column type.
std::string quote_lit(const std::string& v) {
std::string out = "N'";
for (char c : v) { if (c == '\'') out += '\''; out += c; }
out += '\'';
return out;
}

std::size_t col_index_ci(const MssqlTable& t, const std::string& name) {
for (std::size_t i = 0; i < t.data.columns.size(); ++i) {
const std::string& cn = t.data.columns[i].name;
if (cn.size() != name.size()) continue;
bool eq = true;
for (std::size_t k = 0; k < cn.size(); ++k) {
if (std::tolower(static_cast<unsigned char>(cn[k])) !=
std::tolower(static_cast<unsigned char>(name[k]))) { eq = false; break; }
}
if (eq) return i;
}
return static_cast<std::size_t>(-1);
}

// Build "[pk1] = N'v1' AND [pk2] = N'v2'" from a result row's PK cells.
std::string pk_where(const MssqlTable& t,
const std::vector<tds::TdsCell>& row) {
std::string w;
bool any = false;
for (std::size_t i : t.pk_cols) {
if (any) w += " AND ";
w += quote_ident(t.data.columns[i].name);
if (i < row.size() && row[i].is_null) w += " IS NULL";
else w += " = " + quote_lit(i < row.size() ? row[i].value : std::string{});
any = true;
}
return w;
}

} // namespace

struct MssqlConnection::Impl {
TdsTlsChannel channel;
bool authenticated = false;
Expand Down Expand Up @@ -122,6 +176,167 @@ util::Result<tds::QueryResult> MssqlConnection::query(const std::string& sql) {
return qr;
}

// ---------------------------------------------------------------------------
// Navigational write
// ---------------------------------------------------------------------------

namespace {
// Re-run SELECT * and replace the table's buffered result (so record_count
// and navigation reflect the write). Resets the cursor to BOF.
util::Result<void> refetch(MssqlConnection& c, MssqlTable* tbl) {
auto qr = c.query("SELECT * FROM " + quote_ident(tbl->table_name));
if (!qr) return qr.error();
if (!qr.value().ok) {
return util::Error{static_cast<std::int32_t>(qr.value().error_number),
0, qr.value().message, ""};
}
tbl->data = std::move(qr).value();
tbl->pos = 0;
tbl->bof = true;
tbl->eof = tbl->data.rows.empty();
return util::Result<void>{};
}
Comment on lines +186 to +198

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The refetch function unconditionally resets the cursor position to BOF (tbl->pos = 0). In a navigational database, after performing an UPDATE or INSERT (via flush_record), the cursor is expected to remain on the written record. Resetting to BOF will cause subsequent edits or deletes to silently target the wrong record. We should accept the target primary key values and restore the cursor position to the matching row after refetching.

util::Result<void> refetch(MssqlConnection& c, MssqlTable* tbl, const std::vector<std::string>& target_pk = {}) {
    auto qr = c.query("SELECT * FROM " + quote_ident(tbl->table_name));
    if (!qr) return qr.error();
    if (!qr.value().ok) {
        return util::Error{static_cast<std::int32_t>(qr.value().error_number),
                           0, qr.value().message, ""};
    }
    tbl->data = std::move(qr).value();
    tbl->pos  = 0;
    tbl->bof  = true;
    tbl->eof  = tbl->data.rows.empty();

    if (!target_pk.empty() && !tbl->pk_cols.empty()) {
        for (std::size_t r = 0; r < tbl->data.rows.size(); ++r) {
            const auto& row = tbl->data.rows[r];
            bool match = true;
            for (std::size_t i = 0; i < tbl->pk_cols.size(); ++i) {
                std::size_t col_idx = tbl->pk_cols[i];
                if (col_idx >= row.size() || row[col_idx].value != target_pk[i]) {
                    match = false;
                    break;
                }
            }
            if (match) {
                tbl->pos = r;
                tbl->bof = false;
                tbl->eof = false;
                break;
            }
        }
    }
    return util::Result<void>{};
}

} // namespace

util::Result<void> MssqlConnection::append_blank(MssqlTable* tbl) {
if (!valid() || tbl == nullptr) {
return util::Error{5001, 0, "invalid mssql append", ""};
}
const std::size_t n = tbl->data.columns.size();
tbl->staging_row.assign(n, std::string{});
tbl->staging_nulls.assign(n, true);
tbl->pending_append = true;
tbl->row_dirty = true;
return util::Result<void>{};
}

util::Result<void> MssqlConnection::set_field(
MssqlTable* tbl, const std::string& field_name, const std::string& value) {
if (!valid() || tbl == nullptr) {
return util::Error{5001, 0, "invalid mssql set_field", ""};
}
const std::size_t idx = col_index_ci(*tbl, field_name);
if (idx == static_cast<std::size_t>(-1)) {
return util::Error{5063, 0, "column not found", field_name};
}
Comment on lines +213 to +221

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In set_field, if we are not pending an append and there is no current record (e.g., cursor is at EOF), we should fail early with AE_NO_CURRENT_RECORD (error code 5026) instead of staging values on an invalid cursor state and failing only during flush_record.

Suggested change
util::Result<void> MssqlConnection::set_field(
MssqlTable* tbl, const std::string& field_name, const std::string& value) {
if (!valid() || tbl == nullptr) {
return util::Error{5001, 0, "invalid mssql set_field", ""};
}
const std::size_t idx = col_index_ci(*tbl, field_name);
if (idx == static_cast<std::size_t>(-1)) {
return util::Error{5063, 0, "column not found", field_name};
}
util::Result<void> MssqlConnection::set_field(
MssqlTable* tbl, const std::string& field_name, const std::string& value) {
if (!valid() || tbl == nullptr) {
return util::Error{5001, 0, "invalid mssql set_field", ""};
}
if (!tbl->pending_append && tbl->pos >= tbl->data.rows.size()) {
return util::Error{5026, 0, "no current record", ""};
}
const std::size_t idx = col_index_ci(*tbl, field_name);

const std::size_t n = tbl->data.columns.size();
if (!tbl->row_dirty && !tbl->pending_append) {
// Seed staging from the current row so unchanged columns survive UPDATE.
tbl->staging_row.assign(n, std::string{});
tbl->staging_nulls.assign(n, true);
if (tbl->pos < tbl->data.rows.size()) {
const auto& row = tbl->data.rows[tbl->pos];
for (std::size_t i = 0; i < n && i < row.size(); ++i) {
tbl->staging_row[i] = row[i].value;
tbl->staging_nulls[i] = row[i].is_null;
}
}
}
if (tbl->staging_row.size() < n) {
tbl->staging_row.resize(n);
tbl->staging_nulls.resize(n, true);
}
tbl->staging_row[idx] = value;
tbl->staging_nulls[idx] = false;
tbl->row_dirty = true;
return util::Result<void>{};
}

util::Result<void> MssqlConnection::flush_record(MssqlTable* tbl) {
if (!valid() || tbl == nullptr) {
return util::Error{5001, 0, "invalid mssql flush", ""};
}
if (!tbl->row_dirty && !tbl->pending_append) return util::Result<void>{};
const std::size_t n = tbl->data.columns.size();

if (tbl->pending_append) {
std::string cols, vals;
bool any = false;
for (std::size_t i = 0; i < n; ++i) {
if (i < tbl->staging_nulls.size() && tbl->staging_nulls[i]) continue;
if (any) { cols += ", "; vals += ", "; }
cols += quote_ident(tbl->data.columns[i].name);
vals += quote_lit(tbl->staging_row[i]);
any = true;
}
if (!any) {
return util::Error{5001, 0, "insert has no columns", tbl->table_name};
}
const std::string sqlq = "INSERT INTO " + quote_ident(tbl->table_name) +
" (" + cols + ") VALUES (" + vals + ")";
auto r = query(sqlq);
if (!r) return r.error();
if (!r.value().ok) {
return util::Error{static_cast<std::int32_t>(r.value().error_number),
0, r.value().message, sqlq};
}
tbl->pending_append = false;
tbl->row_dirty = false;
return refetch(*this, tbl);
}

// UPDATE the current row, keyed on its primary key.
if (tbl->pk_cols.empty()) {
return util::Error{5004, 0, "mssql update requires a primary key",
tbl->table_name};
}
if (tbl->pos >= tbl->data.rows.size()) {
return util::Error{5026, 0, "no current record", ""};
}
const std::string where = pk_where(*tbl, tbl->data.rows[tbl->pos]);
std::vector<bool> is_pk(n, false);
for (std::size_t i : tbl->pk_cols) if (i < n) is_pk[i] = true;
std::string setc;
bool any = false;
for (std::size_t i = 0; i < n; ++i) {
if (is_pk[i] || i >= tbl->staging_row.size()) continue;
if (any) setc += ", ";
setc += quote_ident(tbl->data.columns[i].name) + " = " +
((i < tbl->staging_nulls.size() && tbl->staging_nulls[i])
? std::string("NULL")
: quote_lit(tbl->staging_row[i]));
any = true;
}
if (!any) { tbl->row_dirty = false; return refetch(*this, tbl); }
const std::string sqlq = "UPDATE " + quote_ident(tbl->table_name) +
" SET " + setc + " WHERE " + where;
auto r = query(sqlq);
if (!r) return r.error();
if (!r.value().ok) {
return util::Error{static_cast<std::int32_t>(r.value().error_number),
0, r.value().message, sqlq};
}
tbl->row_dirty = false;
return refetch(*this, tbl);
}
Comment on lines +245 to +311

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Update flush_record to extract the primary key values of the record being inserted or updated, and pass them to refetch so that the cursor position is correctly preserved after the write operation.

util::Result<void> MssqlConnection::flush_record(MssqlTable* tbl) {
    if (!valid() || tbl == nullptr) {
        return util::Error{5001, 0, "invalid mssql flush", ""};
    }
    if (!tbl->row_dirty && !tbl->pending_append) return util::Result<void>{};
    const std::size_t n = tbl->data.columns.size();

    std::vector<std::string> target_pk;
    if (tbl->pending_append) {
        for (std::size_t idx : tbl->pk_cols) {
            if (idx < tbl->staging_row.size()) {
                target_pk.push_back(tbl->staging_row[idx]);
            } else {
                target_pk.push_back("");
            }
        }

        std::string cols, vals;
        bool any = false;
        for (std::size_t i = 0; i < n; ++i) {
            if (i < tbl->staging_nulls.size() && tbl->staging_nulls[i]) continue;
            if (any) { cols += ", "; vals += ", "; }
            cols += quote_ident(tbl->data.columns[i].name);
            vals += quote_lit(tbl->staging_row[i]);
            any = true;
        }
        if (!any) {
            return util::Error{5001, 0, "insert has no columns", tbl->table_name};
        }
        const std::string sqlq = "INSERT INTO " + quote_ident(tbl->table_name) +
                                 " (" + cols + ") VALUES (" + vals + ")";
        auto r = query(sqlq);
        if (!r) return r.error();
        if (!r.value().ok) {
            return util::Error{static_cast<std::int32_t>(r.value().error_number),
                               0, r.value().message, sqlq};
        }
        tbl->pending_append = false;
        tbl->row_dirty      = false;
        return refetch(*this, tbl, target_pk);
    }

    // UPDATE the current row, keyed on its primary key.
    if (tbl->pk_cols.empty()) {
        return util::Error{5004, 0, "mssql update requires a primary key",
                           tbl->table_name};
    }
    if (tbl->pos >= tbl->data.rows.size()) {
        return util::Error{5026, 0, "no current record", ""};
    }
    for (std::size_t idx : tbl->pk_cols) {
        if (idx < tbl->data.rows[tbl->pos].size()) {
            target_pk.push_back(tbl->data.rows[tbl->pos][idx].value);
        } else {
            target_pk.push_back("");
        }
    }
    const std::string where = pk_where(*tbl, tbl->data.rows[tbl->pos]);
    std::vector<bool> is_pk(n, false);
    for (std::size_t i : tbl->pk_cols) if (i < n) is_pk[i] = true;
    std::string setc;
    bool any = false;
    for (std::size_t i = 0; i < n; ++i) {
        if (is_pk[i] || i >= tbl->staging_row.size()) continue;
        if (any) setc += ", ";
        setc += quote_ident(tbl->data.columns[i].name) + " = " +
                ((i < tbl->staging_nulls.size() && tbl->staging_nulls[i])
                     ? std::string("NULL")
                     : quote_lit(tbl->staging_row[i]));
        any = true;
    }
    if (!any) { tbl->row_dirty = false; return refetch(*this, tbl, target_pk); }
    const std::string sqlq = "UPDATE " + quote_ident(tbl->table_name) +
                             " SET " + setc + " WHERE " + where;
    auto r = query(sqlq);
    if (!r) return r.error();
    if (!r.value().ok) {
        return util::Error{static_cast<std::int32_t>(r.value().error_number),
                           0, r.value().message, sqlq};
    }
    tbl->row_dirty = false;
    return refetch(*this, tbl, target_pk);
}


util::Result<void> MssqlConnection::delete_record(MssqlTable* tbl) {
if (!valid() || tbl == nullptr) {
return util::Error{5001, 0, "invalid mssql delete", ""};
}
if (tbl->pending_append) {
return util::Error{5026, 0, "no current record", ""};
}
if (tbl->pk_cols.empty()) {
return util::Error{5004, 0, "mssql delete requires a primary key",
tbl->table_name};
}
if (tbl->pos >= tbl->data.rows.size()) {
return util::Error{5026, 0, "no current record", ""};
}
const std::string sqlq = "DELETE FROM " + quote_ident(tbl->table_name) +
" WHERE " + pk_where(*tbl, tbl->data.rows[tbl->pos]);
auto r = query(sqlq);
if (!r) return r.error();
if (!r.value().ok) {
return util::Error{static_cast<std::int32_t>(r.value().error_number),
0, r.value().message, sqlq};
}
tbl->row_dirty = false;
tbl->pending_append = false;
return refetch(*this, tbl);
}

} // namespace openads::sql_backend

#endif // defined(OPENADS_WITH_MSSQL)
15 changes: 14 additions & 1 deletion src/sql_backend/mssql_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

namespace openads::sql_backend {

struct MssqlUri; // sql_backend/mssql_uri.h
struct MssqlUri; // sql_backend/mssql_uri.h
struct MssqlTable; // sql_backend/mssql_table.h

class MssqlConnection {
public:
Expand All @@ -44,6 +45,18 @@ class MssqlConnection {
// SQL text is backend-generated; NEVER put secrets or credentials in sql.
util::Result<tds::QueryResult> query(const std::string& sql);

// Navigational write (mirrors the other SQL backends): append_blank stages
// a blank row, set_field stages one column, flush_record emits an INSERT
// (pending_append) or a PK-keyed UPDATE, delete_record a PK-keyed DELETE.
// The result set is re-fetched after each write so navigation/count stay
// consistent. Requires the table to have a discovered primary key.
util::Result<void> append_blank(MssqlTable* tbl);
util::Result<void> set_field(MssqlTable* tbl,
const std::string& field_name,
const std::string& value);
util::Result<void> flush_record(MssqlTable* tbl);
util::Result<void> delete_record(MssqlTable* tbl);

private:
struct Impl;
std::unique_ptr<Impl> impl_;
Expand Down
37 changes: 36 additions & 1 deletion src/sql_backend/mssql_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "sql_backend/sql_common.h"

#include <algorithm>
#include <cctype>
#include <limits>
#include <string>

Expand Down Expand Up @@ -158,7 +159,41 @@ MssqlTable::open(MssqlConnection& c, const std::string& table_name) {
result.message, sql};
}

return from_result(std::move(result));
auto t = from_result(std::move(result));
t->conn = &c;
t->table_name = table_name;

// Discover primary-key columns (best-effort; only writes need them).
// table_name passed is_safe_identifier above, so it is safe to inline.
auto ci_equal = [](const std::string& a, const std::string& b) {
if (a.size() != b.size()) return false;
for (std::size_t i = 0; i < a.size(); ++i) {
if (std::tolower(static_cast<unsigned char>(a[i])) !=
std::tolower(static_cast<unsigned char>(b[i]))) return false;
}
return true;
};
const std::string pk_sql =
"SELECT kcu.COLUMN_NAME "
"FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc "
"JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu "
" ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME "
" AND tc.TABLE_NAME = kcu.TABLE_NAME "
"WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' "
" AND tc.TABLE_NAME = '" + table_name + "' "
"ORDER BY kcu.ORDINAL_POSITION";
Comment on lines +176 to +184

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The primary key discovery query directly concatenates table_name into the SQL string. If table_name contains single quotes, this will cause a SQL syntax error or potential SQL injection. Additionally, if table_name is schema-qualified (e.g., dbo.CLIENTES), the query will fail to find the primary key because INFORMATION_SCHEMA.TABLE_CONSTRAINTS.TABLE_NAME does not include the schema. We should escape single quotes and use SQL Server's built-in PARSENAME function to correctly handle schema-qualified and bracketed table names.

    std::string escaped_table_name;
    for (char c : table_name) {
        if (c == '\'') escaped_table_name += "''";
        else escaped_table_name += c;
    }
    const std::string pk_sql =
        "SELECT kcu.COLUMN_NAME "
        "FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc "
        "JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu "
        "  ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME "
        " AND tc.TABLE_NAME = kcu.TABLE_NAME "
        "WHERE tc.CONSTRAINT_TYPE = 'PRIMARY KEY' "
        "  AND tc.TABLE_NAME = COALESCE(PARSENAME('" + escaped_table_name + "', 1), '" + escaped_table_name + "') "
        "  AND (PARSENAME('" + escaped_table_name + "', 2) IS NULL OR tc.TABLE_SCHEMA = PARSENAME('" + escaped_table_name + "', 2)) "
        "ORDER BY kcu.ORDINAL_POSITION";

if (auto pk = c.query(pk_sql); pk && pk.value().ok) {
for (const auto& row : pk.value().rows) {
if (row.empty()) continue;
for (std::size_t i = 0; i < t->data.columns.size(); ++i) {
if (ci_equal(t->data.columns[i].name, row[0].value)) {
t->pk_cols.push_back(i);
break;
}
}
}
}
return t;
}

std::unique_ptr<MssqlTable> MssqlTable::from_result(tds::QueryResult qr) {
Expand Down
Loading
Loading