Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@query-doctor/sqlcommenter-drizzle",
"version": "0.5.0",
"version": "0.6.0",
"description": "SQLCommenter patch for drizzle-orm",
"main": "dist/cjs/index.js",
"type": "module",
Expand Down
128 changes: 86 additions & 42 deletions nodejs/sqlcommenter-nodejs/packages/sqlcommenter-drizzle/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,91 @@ function patchImmediateMethod(target: Function, thisArg: unknown, args: any[]) {

const DRIZZLE_ORM_MODE_METHODS = ["findFirst", "findMany"] as const;

const CRUD_METHODS = [
"select",
"selectDistinct",
"selectDistinctOn",
"insert",
"update",
"delete",
] as const;

// Marks a db/transaction object whose query methods we've already wrapped, so re-patching
// (e.g. a double `patchDrizzle` call) is a no-op.
const PATCHED_METHODS = Symbol("sqlcommenter-drizzle.patched-methods");

type QueryMethodHost = {
execute?: unknown;
transaction?: unknown;
query?: Record<string, Record<string, unknown>>;
[key: string]: unknown;
};

/**
* Wraps the query-building methods on a drizzle db — or a transaction handle — so the caller is
* captured for every query built through it.
*
* `db.transaction(cb)` hands `cb` a fresh `tx` object whose methods are NOT the ones patched on
* the top-level db, so queries built inside a transaction would otherwise lose their
* `file`/`func_name` tags (only `db_driver`, added in `prepareQuery`, would survive). We wrap the
* transaction callback and recursively patch the `tx` — including any nested savepoint `tx` — the
* same way.
*/
function patchQueryMethods(target: QueryMethodHost) {
const guard = target as unknown as Record<symbol, boolean>;
if (!target || typeof target !== "object" || guard[PATCHED_METHODS]) {
return;
}
guard[PATCHED_METHODS] = true;

if (typeof target.execute === "function") {
target.execute = new Proxy(target.execute, {
apply: (fn, thisArg, args) => patchImmediateMethod(fn, thisArg, args),
});
}
if (target.query) {
for (const key in target.query) {
const schema = target.query[key];
for (const func of DRIZZLE_ORM_MODE_METHODS) {
if (!schema || typeof schema[func] !== "function") {
continue;
}
schema[func] = new Proxy(schema[func] as Function, {
apply: (fn, thisArg, args) => patchBuilderMethod(fn, thisArg, args),
});
}
}
}
for (const method of CRUD_METHODS) {
// not all drivers have all these calls so better be safe
if (typeof target[method] !== "function") {
continue;
}
// Patching the CRUD entrypoints. The caller is captured here, when the query is built,
// because the build-time stack is the only place the user's call site is still visible —
// by the time the query executes (a microtask later) it's gone. `patchBuilderMethod` tags
// the built query so the caller is reattached for its own synchronous `prepareQuery` window.
target[method] = new Proxy(target[method] as Function, {
apply: (fn, thisArg, args) => patchBuilderMethod(fn, thisArg, args),
});
}
if (typeof target.transaction === "function") {
target.transaction = new Proxy(target.transaction, {
apply(fn, thisArg, args) {
const [callback, ...rest] = args as [unknown, ...unknown[]];
if (typeof callback !== "function") {
return Reflect.apply(fn, thisArg, args);
}
const wrapped = function (this: unknown, tx: QueryMethodHost, ...cbArgs: unknown[]) {
patchQueryMethods(tx);
return (callback as Function).apply(this, [tx, ...cbArgs]);
};
return Reflect.apply(fn, thisArg, [wrapped, ...rest]);
},
});
}
}

export function patchDrizzle<T>(
drizzle: T & {
// is this nullable?
Expand All @@ -213,48 +298,7 @@ export function patchDrizzle<T>(
} catch (e) {
console.error("Error patching driver", e);
}
const methods = [
"select",
"selectDistinct",
"selectDistinctOn",
"insert",
"update",
"delete",
] as const;
if (typeof drizzle.execute === "function") {
drizzle.execute = new Proxy(drizzle.execute, {
apply: (target, thisArg, args) =>
patchImmediateMethod(target, thisArg, args),
});
}
if (drizzle && "query" in drizzle && drizzle.query) {
for (const key in drizzle.query) {
for (const func of DRIZZLE_ORM_MODE_METHODS) {
const schema = drizzle.query[key as keyof typeof drizzle.query];
if (!schema[func] || typeof schema[func] !== "function") {
continue;
}
schema[func] = new Proxy(schema[func], {
apply: (target, thisArg, args) =>
patchBuilderMethod(target, thisArg, args),
});
}
}
}
for (const method of methods) {
// not all drivers have all these calls so better be safe
if (!drizzle[method] || typeof drizzle[method] !== "function") {
continue;
}
// Patching the CRUD entrypoints. The caller is captured here, when the query is built,
// because the build-time stack is the only place the user's call site is still visible —
// by the time the query executes (a microtask later) it's gone. `patchBuilderMethod` tags
// the built query so the caller is reattached for its own synchronous `prepareQuery` window.
drizzle[method] = new Proxy(drizzle[method], {
apply: (target, thisArg, args) =>
patchBuilderMethod(target, thisArg, args),
});
}
patchQueryMethods(drizzle as QueryMethodHost);
return drizzle;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import { test } from "node:test";
import assert from "node:assert";
import { pgTable, serial, text } from "drizzle-orm/pg-core";
import { eq } from "drizzle-orm";
import { drizzle } from "drizzle-orm/pglite";
import { patchDrizzle } from "../src/index.js";

const t = pgTable("t", {
id: serial("id").primaryKey(),
name: text("name"),
});
const u = pgTable("u", {
id: serial("id").primaryKey(),
name: text("name"),
});

function tag(sql: string, key: string): string | undefined {
const match = sql.match(new RegExp(`${key}='([^']*)'`));
return match ? decodeURIComponent(match[1]) : undefined;
}

async function setupLoggedDb() {
const logged: string[] = [];
const db = patchDrizzle(
drizzle({
schema: { t, u },
logger: { logQuery: (query) => logged.push(query) },
}),
);
await db.$client.exec(
"CREATE TABLE t (id serial primary key, name text); CREATE TABLE u (id serial primary key, name text);",
);
return { db, logged };
}

// `patchDrizzle` patches the top-level db, but `db.transaction(cb)` hands `cb` a fresh `tx` whose
// methods are unpatched — so without wrapping the transaction, queries built inside it lose their
// `file` tag (only `db_driver`, added in prepareQuery, would survive).
test("queries inside a transaction still get a file tag", async () => {
const { db, logged } = await setupLoggedDb();
await db.transaction(async (tx) => {
await tx.insert(t).values({ name: "a" });
});

const sql = logged.find((q) => q.includes('into "t"'))!;
assert.match(
tag(sql, "file") ?? "",
/:\d+:\d+$/,
"file must be captured inside a transaction",
);
// The direct caller here is an anonymous transaction arrow, so there is no symbol.
assert.strictEqual(tag(sql, "func_name"), undefined);
});

test("a named transaction callback carries its func_name", async () => {
const { db, logged } = await setupLoggedDb();
async function persistThing(tx: Parameters<Parameters<typeof db.transaction>[0]>[0]) {
await tx.insert(t).values({ name: "a" });
}
await db.transaction(persistThing);

const sql = logged.find((q) => q.includes('into "t"'))!;
assert.ok(tag(sql, "file"), "file is always captured");
assert.strictEqual(tag(sql, "func_name"), "persistThing");
});

test("nested (savepoint) transactions are tagged too", async () => {
const { db, logged } = await setupLoggedDb();
await db.transaction(async (tx) => {
await tx.transaction(async (tx2) => {
await tx2.insert(u).values({ name: "n" });
});
});

const sql = logged.find((q) => q.includes('into "u"'))!;
assert.ok(
/:\d+:\d+$/.test(tag(sql, "file") ?? ""),
"file must be captured inside a nested transaction",
);
});

test("wrapping the transaction preserves commit semantics", async () => {
const { db } = await setupLoggedDb();
await db.transaction(async (tx) => {
await tx.insert(t).values({ name: "committed" });
});
const rows = await db.select().from(t).where(eq(t.name, "committed"));
assert.strictEqual(rows.length, 1);
});

test("wrapping the transaction preserves rollback semantics", async () => {
const { db } = await setupLoggedDb();
await assert.rejects(
db.transaction(async (tx) => {
await tx.insert(t).values({ name: "rolledback" });
throw new Error("boom");
}),
/boom/,
);
const rows = await db.select().from(t).where(eq(t.name, "rolledback"));
assert.strictEqual(rows.length, 0, "the errored transaction must roll back");
});

// Named callbacks passed straight to `transaction` give each concurrent tx a distinct symbol,
// so this asserts the per-query caller isn't clobbered across concurrent transactions.
test("concurrent transactions each keep their own caller", async () => {
const { db, logged } = await setupLoggedDb();
async function txIntoT(tx: Parameters<Parameters<typeof db.transaction>[0]>[0]) {
await tx.insert(t).values({ name: "A" });
}
async function txIntoU(tx: Parameters<Parameters<typeof db.transaction>[0]>[0]) {
await tx.insert(u).values({ name: "B" });
}
await Promise.all([db.transaction(txIntoT), db.transaction(txIntoU)]);

const tSql = logged.find((q) => q.includes('into "t"'))!;
const uSql = logged.find((q) => q.includes('into "u"'))!;
assert.strictEqual(tag(tSql, "func_name"), "txIntoT");
assert.strictEqual(tag(uSql, "func_name"), "txIntoU");
});
Loading