Skip to content

Commit 12fbf7c

Browse files
committed
Enable nested beforeTemplateIsBaked calls
1 parent 4f8443a commit 12fbf7c

File tree

6 files changed

+153
-48
lines changed

6 files changed

+153
-48
lines changed

src/index.ts

Lines changed: 74 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ import type {
1616
import { Pool } from "pg"
1717
import type { Jsonifiable } from "type-fest"
1818
import type { ExecutionContext } from "ava"
19-
import { once } from "node:events"
20-
import { createBirpc } from "birpc"
19+
import { BirpcReturn, createBirpc } from "birpc"
2120
import { ExecResult } from "testcontainers"
2221
import isPlainObject from "lodash/isPlainObject"
2322

@@ -136,57 +135,86 @@ export const getTestPostgresDatabaseFactory = <
136135
}
137136

138137
let rpcCallback: (data: any) => void
139-
const rpc = createBirpc<SharedWorkerFunctions, TestWorkerFunctions>(
140-
{
141-
runBeforeTemplateIsBakedHook: async (connection, params) => {
142-
if (options?.beforeTemplateIsBaked) {
143-
const connectionDetails =
144-
mapWorkerConnectionDetailsToConnectionDetails(connection)
145-
146-
// Ignore if the pool is terminated by the shared worker
147-
// (This happens in CI for some reason even though we drain the pool first.)
148-
connectionDetails.pool.on("error", (error) => {
149-
if (
150-
error.message.includes(
151-
"terminating connection due to administrator command"
152-
)
153-
) {
154-
return
155-
}
138+
const rpc: BirpcReturn<SharedWorkerFunctions, TestWorkerFunctions> =
139+
createBirpc<SharedWorkerFunctions, TestWorkerFunctions>(
140+
{
141+
runBeforeTemplateIsBakedHook: async (connection, params) => {
142+
if (options?.beforeTemplateIsBaked) {
143+
const connectionDetails =
144+
mapWorkerConnectionDetailsToConnectionDetails(connection)
145+
146+
// Ignore if the pool is terminated by the shared worker
147+
// (This happens in CI for some reason even though we drain the pool first.)
148+
connectionDetails.pool.on("error", (error) => {
149+
if (
150+
error.message.includes(
151+
"terminating connection due to administrator command"
152+
)
153+
) {
154+
return
155+
}
156+
157+
throw error
158+
})
159+
160+
const createdNestedConnections: ConnectionDetails[] = []
161+
const hookResult = await options.beforeTemplateIsBaked({
162+
params: params as any,
163+
connection: connectionDetails,
164+
containerExec: async (command): Promise<ExecResult> =>
165+
rpc.execCommandInContainer(command),
166+
// This is what allows a consumer to get a "nested" database from within their beforeTemplateIsBaked hook
167+
beforeTemplateIsBaked: async (options) => {
168+
const { connectionDetails, beforeTemplateIsBakedResult } =
169+
await rpc.getTestDatabase({
170+
params: options.params,
171+
databaseDedupeKey: options.databaseDedupeKey,
172+
})
156173

157-
throw error
158-
})
174+
const mappedConnection =
175+
mapWorkerConnectionDetailsToConnectionDetails(
176+
connectionDetails
177+
)
159178

160-
const hookResult = await options.beforeTemplateIsBaked({
161-
params: params as any,
162-
connection: connectionDetails,
163-
containerExec: async (command): Promise<ExecResult> =>
164-
rpc.execCommandInContainer(command),
165-
})
179+
createdNestedConnections.push(mappedConnection)
166180

167-
await teardownConnection(connectionDetails)
181+
return {
182+
...mappedConnection,
183+
beforeTemplateIsBakedResult,
184+
}
185+
},
186+
})
168187

169-
if (hookResult && !isSerializable(hookResult)) {
170-
throw new TypeError(
171-
"Return value of beforeTemplateIsBaked() hook could not be serialized. Make sure it returns only JSON-serializable values."
188+
await Promise.all(
189+
createdNestedConnections.map(async (connection) => {
190+
await teardownConnection(connection)
191+
await rpc.dropDatabase(connection.database)
192+
})
172193
)
173-
}
174194

175-
return hookResult
176-
}
177-
},
178-
},
179-
{
180-
post: async (data) => {
181-
const worker = await workerPromise
182-
await worker.available
183-
worker.publish(data)
184-
},
185-
on: (data) => {
186-
rpcCallback = data
195+
await teardownConnection(connectionDetails)
196+
197+
if (hookResult && !isSerializable(hookResult)) {
198+
throw new TypeError(
199+
"Return value of beforeTemplateIsBaked() hook could not be serialized. Make sure it returns only JSON-serializable values."
200+
)
201+
}
202+
203+
return hookResult
204+
}
205+
},
187206
},
188-
}
189-
)
207+
{
208+
post: async (data) => {
209+
const worker = await workerPromise
210+
await worker.available
211+
worker.publish(data)
212+
},
213+
on: (data) => {
214+
rpcCallback = data
215+
},
216+
}
217+
)
190218

191219
// Automatically cleaned up by AVA since each test file runs in a separate worker
192220
const _messageHandlerPromise = (async () => {

src/internal-types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ export interface SharedWorkerFunctions {
3030
beforeTemplateIsBakedResult: unknown
3131
}>
3232
execCommandInContainer: (command: string[]) => Promise<ExecResult>
33+
dropDatabase: (databaseName: string) => Promise<void>
3334
}

src/public-types.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ export interface GetTestPostgresDatabaseFactoryOptions<
5555
connection: ConnectionDetails
5656
params: Params
5757
containerExec: (command: string[]) => Promise<ExecResult>
58+
beforeTemplateIsBaked: (
59+
options: {
60+
params: Params
61+
} & Pick<GetTestPostgresDatabaseOptions, "databaseDedupeKey">
62+
) => Promise<GetTestPostgresDatabaseResult>
5863
}) => Promise<any>
5964
}
6065

src/tests/hooks.test.ts

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import test from "ava"
22
import { getTestPostgresDatabaseFactory } from "~/index"
33
import { countDatabaseTemplates } from "./utils/count-database-templates"
4+
import { doesDatabaseExist } from "./utils/does-database-exist"
45

56
test("beforeTemplateIsBaked", async (t) => {
67
let wasHookCalled = false
@@ -145,3 +146,48 @@ test("beforeTemplateIsBaked (result isn't serializable)", async (t) => {
145146
}
146147
)
147148
})
149+
150+
test("beforeTemplateIsBaked, get nested database", async (t) => {
151+
type DatabaseParams = {
152+
type: "foo" | "bar"
153+
}
154+
155+
let nestedDatabaseName: string | undefined = undefined
156+
157+
const getTestServer = getTestPostgresDatabaseFactory<DatabaseParams>({
158+
postgresVersion: process.env.POSTGRES_VERSION,
159+
workerDedupeKey: "beforeTemplateIsBakedHookNestedDatabase",
160+
beforeTemplateIsBaked: async ({
161+
params,
162+
connection: { pool },
163+
beforeTemplateIsBaked,
164+
}) => {
165+
if (params.type === "foo") {
166+
await pool.query(`CREATE TABLE "foo" ("id" SERIAL PRIMARY KEY)`)
167+
return { createdFoo: true }
168+
}
169+
170+
await pool.query(`CREATE TABLE "bar" ("id" SERIAL PRIMARY KEY)`)
171+
const fooDatabase = await beforeTemplateIsBaked({
172+
params: { type: "foo" },
173+
})
174+
t.deepEqual(fooDatabase.beforeTemplateIsBakedResult, { createdFoo: true })
175+
176+
nestedDatabaseName = fooDatabase.database
177+
178+
await t.notThrowsAsync(async () => {
179+
await fooDatabase.pool.query(`INSERT INTO "foo" DEFAULT VALUES`)
180+
})
181+
182+
return { createdBar: true }
183+
},
184+
})
185+
186+
const database = await getTestServer(t, { type: "bar" })
187+
t.deepEqual(database.beforeTemplateIsBakedResult, { createdBar: true })
188+
189+
t.false(
190+
await doesDatabaseExist(database.pool, nestedDatabaseName!),
191+
"Nested database should have been cleaned up after the parent hook completed"
192+
)
193+
})
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import { Pool } from "pg"
2+
3+
export const doesDatabaseExist = async (pool: Pool, databaseName: string) => {
4+
const {
5+
rows: [{ count }],
6+
} = await pool.query(
7+
'SELECT COUNT(*) FROM "pg_database" WHERE "datname" = $1',
8+
[databaseName]
9+
)
10+
11+
return count > 0
12+
}

src/worker.ts

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ export class Worker {
7373
const container = (await this.startContainerPromise).container
7474
return container.exec(command)
7575
},
76+
dropDatabase: async (databaseName) => {
77+
const { postgresClient } = await this.startContainerPromise
78+
await postgresClient.query(`DROP DATABASE ${databaseName}`)
79+
},
7680
},
7781
rpcChannel
7882
)
@@ -148,8 +152,17 @@ export class Worker {
148152
return
149153
}
150154

151-
await this.forceDisconnectClientsFrom(databaseName!)
152-
await postgresClient.query(`DROP DATABASE ${databaseName}`)
155+
try {
156+
await this.forceDisconnectClientsFrom(databaseName!)
157+
await postgresClient.query(`DROP DATABASE ${databaseName}`)
158+
} catch (error) {
159+
if ((error as Error)?.message?.includes("does not exist")) {
160+
// Database was likely a nested database and manually dropped by the test worker, ignore
161+
return
162+
}
163+
164+
throw error
165+
}
153166
})
154167

155168
return {

0 commit comments

Comments
 (0)