Skip to content

Commit 0e307b9

Browse files
committed
feat: initial db pull implementation
1 parent 9bf6d7f commit 0e307b9

File tree

11 files changed

+927
-7
lines changed

11 files changed

+927
-7
lines changed

packages/cli/package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
},
4747
"devDependencies": {
4848
"@types/better-sqlite3": "^7.6.13",
49+
"@types/pg": "^8.11.11",
4950
"@types/semver": "^7.7.0",
5051
"@types/tmp": "catalog:",
5152
"@zenstackhq/eslint-config": "workspace:*",
@@ -54,6 +55,7 @@
5455
"@zenstackhq/typescript-config": "workspace:*",
5556
"@zenstackhq/vitest-config": "workspace:*",
5657
"better-sqlite3": "^12.2.0",
58+
"pg": "^8.16.3",
5759
"tmp": "catalog:"
5860
}
5961
}

packages/cli/src/actions/action-utils.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@ export async function loadSchemaDocument(schemaFile: string) {
4747
return loadResult.model;
4848
}
4949

50+
export async function loadSchemaDocumentWithServices(schemaFile: string) {
51+
const loadResult = await loadDocument(schemaFile);
52+
if (!loadResult.success) {
53+
loadResult.errors.forEach((err) => {
54+
console.error(colors.red(err));
55+
});
56+
throw new CliError('Schema contains errors. See above for details.');
57+
}
58+
loadResult.warnings.forEach((warn) => {
59+
console.warn(colors.yellow(warn));
60+
});
61+
return { services: loadResult.services, model: loadResult.model };
62+
}
63+
5064
export function handleSubProcessError(err: unknown) {
5165
if (err instanceof Error && 'status' in err && typeof err.status === 'number') {
5266
process.exit(err.status);

packages/cli/src/actions/db.ts

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,36 @@
1+
import { ZModelCodeGenerator } from '@zenstackhq/sdk';
12
import fs from 'node:fs';
23
import { execPackage } from '../utils/exec-utils';
3-
import { generateTempPrismaSchema, getSchemaFile, handleSubProcessError } from './action-utils';
4+
import { generateTempPrismaSchema, getSchemaFile, handleSubProcessError, loadSchemaDocumentWithServices } from './action-utils';
5+
import { syncEnums, syncRelation, syncTable, type Relation } from './pull';
6+
import { providers } from './pull/provider';
7+
import { getDatasource, getDbName } from './pull/utils';
48

5-
type Options = {
9+
type PushOptions = {
610
schema?: string;
711
acceptDataLoss?: boolean;
812
forceReset?: boolean;
913
};
1014

15+
type PullOptions = {
16+
schema?: string;
17+
};
18+
1119
/**
1220
* CLI action for db related commands
1321
*/
14-
export async function run(command: string, options: Options) {
22+
export async function run(command: string, options: PushOptions) {
1523
switch (command) {
1624
case 'push':
1725
await runPush(options);
1826
break;
27+
case 'pull':
28+
await runPull(options);
29+
break;
1930
}
2031
}
2132

22-
async function runPush(options: Options) {
33+
async function runPush(options: PushOptions) {
2334
// generate a temp prisma schema file
2435
const schemaFile = getSchemaFile(options.schema);
2536
const prismaSchemaFile = await generateTempPrismaSchema(schemaFile);
@@ -45,3 +56,54 @@ async function runPush(options: Options) {
4556
}
4657
}
4758
}
59+
60+
async function runPull(options: PullOptions) {
61+
const schemaFile = getSchemaFile(options.schema);
62+
const { model, services } = await loadSchemaDocumentWithServices(schemaFile);
63+
64+
const SUPPORTED_PROVIDERS = ['sqlite', 'postgresql']
65+
const datasource = getDatasource(model)
66+
67+
if (!datasource) {
68+
throw new Error('No datasource found in the schema.')
69+
}
70+
71+
if (!SUPPORTED_PROVIDERS.includes(datasource.provider)) {
72+
throw new Error(`Unsupported datasource provider: ${datasource.provider}`)
73+
}
74+
75+
const provider = providers[datasource.provider];
76+
77+
if (!provider) {
78+
throw new Error(
79+
`No introspection provider found for: ${datasource.provider}`
80+
)
81+
}
82+
83+
const { enums, tables } = await provider.introspect(datasource.url)
84+
85+
syncEnums(enums, model)
86+
87+
const resolveRelations: Relation[] = []
88+
for (const table of tables) {
89+
const relations = syncTable({ table, model, provider })
90+
resolveRelations.push(...relations)
91+
}
92+
93+
for (const rel of resolveRelations) {
94+
syncRelation(model, rel, services);
95+
}
96+
97+
for (const d of model.declarations) {
98+
if (d.$type !== 'DataModel') continue
99+
const found = tables.find((t) => getDbName(d) === t.name)
100+
if (!found) {
101+
delete (d.$container as any)[d.$containerProperty!][d.$containerIndex!]
102+
}
103+
}
104+
105+
model.declarations = model.declarations.filter((d) => d !== undefined)
106+
107+
const zmpdelSchema = await new ZModelCodeGenerator().generate(model)
108+
fs.writeFileSync(schemaFile, zmpdelSchema)
109+
}
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
import type { ZModelServices } from '@zenstackhq/language'
2+
import type {
3+
Attribute,
4+
AttributeArg,
5+
DataField,
6+
DataFieldAttribute,
7+
DataFieldType,
8+
DataModel,
9+
Enum,
10+
EnumField,
11+
Model,
12+
UnsupportedFieldType
13+
} from '@zenstackhq/language/ast'
14+
import type { IntrospectedEnum, IntrospectedTable, IntrospectionProvider } from './provider'
15+
import { getAttributeRef, getDbName } from './utils'
16+
17+
export function syncEnums(dbEnums: IntrospectedEnum[], model: Model) {
18+
for (const dbEnum of dbEnums) {
19+
let schemaEnum = model.declarations.find(
20+
(d) => d.$type === 'Enum' && getDbName(d) === dbEnum.enum_type
21+
) as Enum | undefined
22+
23+
if (!schemaEnum) {
24+
schemaEnum = {
25+
$type: 'Enum' as const,
26+
$container: model,
27+
name: dbEnum.enum_type,
28+
attributes: [],
29+
comments: [],
30+
fields: [],
31+
}
32+
model.declarations.push(schemaEnum)
33+
}
34+
schemaEnum.fields = dbEnum.values.map((v) => {
35+
const existingValue = schemaEnum.fields.find((f) => getDbName(f) === v)
36+
if (!existingValue) {
37+
const enumField: EnumField = {
38+
$type: 'EnumField' as const,
39+
$container: schemaEnum,
40+
name: v,
41+
attributes: [],
42+
comments: [],
43+
}
44+
return enumField
45+
}
46+
return existingValue
47+
})
48+
}
49+
}
50+
51+
export type Relation = {
52+
schema: string
53+
table: string
54+
column: string
55+
type: 'one' | 'many'
56+
fk_name: string
57+
nullable: boolean
58+
references: {
59+
schema: string | null
60+
table: string | null
61+
column: string | null
62+
}
63+
}
64+
65+
export function syncTable({
66+
model,
67+
provider,
68+
table,
69+
}: {
70+
table: IntrospectedTable
71+
model: Model
72+
provider: IntrospectionProvider
73+
}) {
74+
const relations: Relation[] = []
75+
let modelTable = model.declarations.find(
76+
(d) => d.$type === 'DataModel' && getDbName(d) === table.name
77+
) as DataModel | undefined
78+
79+
if (!modelTable) {
80+
modelTable = {
81+
$type: 'DataModel' as const,
82+
$container: model,
83+
name: table.name,
84+
fields: [],
85+
attributes: [],
86+
comments: [],
87+
isView: false,
88+
mixins: [],
89+
}
90+
model.declarations.push(modelTable)
91+
}
92+
93+
modelTable.fields = table.columns.map((col) => {
94+
if (col.foreign_key_table) {
95+
relations.push({
96+
schema: table.schema,
97+
table: table.name,
98+
column: col.name,
99+
type: col.unique ? 'one' : 'many',
100+
fk_name: col.foreign_key_name!,
101+
nullable: col.nullable,
102+
references: {
103+
schema: col.foreign_key_schema,
104+
table: col.foreign_key_table,
105+
column: col.foreign_key_column,
106+
},
107+
})
108+
}
109+
110+
const fieldPrefix = /[0-9]/g.test(col.name.charAt(0)) ? '_' : ''
111+
const fieldName = `${fieldPrefix}${col.name}`
112+
113+
const existingField = modelTable!.fields.find(
114+
(f) => getDbName(f) === fieldName
115+
)
116+
if (!existingField) {
117+
const builtinType = provider.getBuiltinType(col.datatype)
118+
const unsupported: UnsupportedFieldType = {
119+
get $container() {
120+
return type
121+
},
122+
$type: 'UnsupportedFieldType' as const,
123+
value: {
124+
get $container() {
125+
return unsupported
126+
},
127+
$type: 'StringLiteral',
128+
value: col.datatype,
129+
},
130+
}
131+
132+
const type: DataFieldType = {
133+
get $container() {
134+
return field
135+
},
136+
$type: 'DataFieldType' as const,
137+
type: builtinType.type === 'Unsupported' ? undefined : builtinType.type,
138+
array: builtinType.isArray,
139+
unsupported:
140+
builtinType.type === 'Unsupported' ? unsupported : undefined,
141+
optional: col.nullable,
142+
reference: col.options.length
143+
? {
144+
$refText: col.datatype,
145+
ref: model.declarations.find(
146+
(d) => d.$type === 'Enum' && getDbName(d) === col.datatype
147+
) as Enum | undefined,
148+
}
149+
: undefined,
150+
}
151+
152+
const field: DataField = {
153+
$type: 'DataField' as const,
154+
type,
155+
$container: modelTable!,
156+
name: fieldName,
157+
get attributes() {
158+
if (fieldPrefix !== '') return []
159+
160+
const attr: DataFieldAttribute = {
161+
$type: 'DataFieldAttribute' as const,
162+
get $container() {
163+
return field
164+
},
165+
decl: {
166+
$refText: '@map',
167+
ref: model.$document?.references.find(
168+
(r) =>
169+
//@ts-ignore
170+
r.ref.$type === 'Attribute' && r.ref.name === '@map'
171+
)?.ref as Attribute,
172+
},
173+
get args() {
174+
const arg: AttributeArg = {
175+
$type: 'AttributeArg' as const,
176+
get $container() {
177+
return attr
178+
},
179+
name: 'name',
180+
$resolvedParam: {
181+
name: 'name',
182+
},
183+
get value() {
184+
return {
185+
$type: 'StringLiteral' as const,
186+
$container: arg,
187+
value: col.name,
188+
}
189+
},
190+
}
191+
192+
return [arg]
193+
},
194+
}
195+
196+
return [attr]
197+
},
198+
comments: [],
199+
}
200+
return field
201+
}
202+
return existingField
203+
})
204+
205+
return relations
206+
}
207+
208+
export function syncRelation(model: Model, relation: Relation, services: ZModelServices) {
209+
const idAttribute = getAttributeRef('@id', services)
210+
const uniqueAttribute = getAttributeRef('@unique', services)
211+
const relationAttribute = getAttributeRef('@relation', services)
212+
213+
if (!idAttribute || !uniqueAttribute || !relationAttribute) {
214+
throw new Error('Cannot find required attributes in the model.')
215+
}
216+
217+
const sourceModel = model.declarations.find(
218+
(d) => d.$type === 'DataModel' && getDbName(d) === relation.table
219+
) as DataModel | undefined
220+
if (!sourceModel) return
221+
222+
const sourceField = sourceModel.fields.find(
223+
(f) => getDbName(f) === relation.column
224+
) as DataField | undefined
225+
if (!sourceField) return
226+
227+
const targetModel = model.declarations.find(
228+
(d) => d.$type === 'DataModel' && getDbName(d) === relation.references.table
229+
) as DataModel | undefined
230+
if (!targetModel) return
231+
232+
const targetField = targetModel.fields.find(
233+
(f) => getDbName(f) === relation.references.column
234+
)
235+
if (!targetField) return
236+
237+
//TODO: Finish relation sync
238+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
export * from './provider'
2+
3+
import { postgresql } from "./postgresql";
4+
import { sqlite } from "./sqlite";
5+
6+
export const providers = {
7+
postgresql,
8+
sqlite
9+
};

0 commit comments

Comments
 (0)