Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions internal/compiler/compile.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package compiler

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -39,6 +40,13 @@ func (c *Compiler) parseCatalog(schemas []string) error {
}
contents := migrations.RemoveRollbackStatements(string(blob))
c.schema = append(c.schema, contents)

// In accurate mode, we only need to collect schema files for migrations
// but don't build the internal catalog from them
if c.accurateMode {
continue
}

stmts, err := c.parser.Parse(strings.NewReader(contents))
if err != nil {
merr.Add(filename, contents, 0, err)
Expand All @@ -58,6 +66,15 @@ func (c *Compiler) parseCatalog(schemas []string) error {
}

func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
ctx := context.Background()

// In accurate mode, initialize the database connection pool before parsing queries
if c.accurateMode && c.pgAnalyzer != nil {
if err := c.pgAnalyzer.EnsurePool(ctx, c.schema); err != nil {
return nil, fmt.Errorf("failed to initialize database connection: %w", err)
}
}

var q []*Query
merr := multierr.New()
set := map[string]struct{}{}
Expand Down Expand Up @@ -113,6 +130,18 @@ func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
if len(q) == 0 {
return nil, fmt.Errorf("no queries contained in paths %s", strings.Join(c.conf.Queries, ","))
}

// In accurate mode, build the catalog from the database after parsing all queries
if c.accurateMode && c.pgAnalyzer != nil {
// Default to "public" schema if no specific schemas are specified
schemas := []string{"public"}
cat, err := c.pgAnalyzer.IntrospectSchema(ctx, schemas)
if err != nil {
return nil, fmt.Errorf("failed to introspect database schema: %w", err)
}
c.catalog = cat
}

return &Result{
Catalog: c.catalog,
Queries: q,
Expand Down
39 changes: 37 additions & 2 deletions internal/compiler/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
sqliteanalyze "github.com/sqlc-dev/sqlc/internal/engine/sqlite/analyzer"
"github.com/sqlc-dev/sqlc/internal/opts"
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
"github.com/sqlc-dev/sqlc/internal/x/expander"
)

type Compiler struct {
Expand All @@ -27,6 +28,15 @@ type Compiler struct {
selector selector

schema []string

// accurateMode indicates that the compiler should use database-only analysis
// and skip building the internal catalog from schema files
accurateMode bool
// pgAnalyzer is the PostgreSQL-specific analyzer used in accurate mode
// for schema introspection
pgAnalyzer *pganalyze.Analyzer
// expander is used to expand SELECT * and RETURNING * in accurate mode
expander *expander.Expander
}

func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, error) {
Expand All @@ -37,6 +47,9 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err
c.client = client
}

// Check for accurate mode
accurateMode := conf.Analyzer.Accurate != nil && *conf.Analyzer.Accurate

switch conf.Engine {
case config.EngineSQLite:
c.parser = sqlite.NewParser()
Expand All @@ -56,10 +69,32 @@ func NewCompiler(conf config.SQL, combo config.CombinedSettings) (*Compiler, err
c.catalog = dolphin.NewCatalog()
c.selector = newDefaultSelector()
case config.EnginePostgreSQL:
c.parser = postgresql.NewParser()
parser := postgresql.NewParser()
c.parser = parser
c.catalog = postgresql.NewCatalog()
c.selector = newDefaultSelector()
if conf.Database != nil {

if accurateMode {
// Accurate mode requires a database connection
if conf.Database == nil {
return nil, fmt.Errorf("accurate mode requires database configuration")
}
if conf.Database.URI == "" && !conf.Database.Managed {
return nil, fmt.Errorf("accurate mode requires database.uri or database.managed")
}
c.accurateMode = true
// Create the PostgreSQL analyzer for schema introspection
c.pgAnalyzer = pganalyze.New(c.client, *conf.Database)
// Use the analyzer wrapped with cache for query analysis
c.analyzer = analyzer.Cached(
c.pgAnalyzer,
combo.Global,
*conf.Database,
)
// Create the expander using the pgAnalyzer as the column getter
// The parser implements both Parser and format.Dialect interfaces
c.expander = expander.New(c.pgAnalyzer, parser, parser)
} else if conf.Database != nil {
if conf.Analyzer.Database == nil || *conf.Analyzer.Database {
c.analyzer = analyzer.Cached(
pganalyze.New(c.client, *conf.Database),
Expand Down
51 changes: 50 additions & 1 deletion internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,56 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
}

var anlys *analysis
if c.analyzer != nil {
if c.accurateMode && c.expander != nil {
// In accurate mode, use the expander for star expansion
// and rely entirely on the database analyzer for type resolution
expandedQuery, err := c.expander.Expand(ctx, rawSQL)
if err != nil {
return nil, fmt.Errorf("star expansion failed: %w", err)
}

// Parse named parameters from the expanded query
expandedStmts, err := c.parser.Parse(strings.NewReader(expandedQuery))
if err != nil {
return nil, fmt.Errorf("parsing expanded query failed: %w", err)
}
if len(expandedStmts) == 0 {
return nil, errors.New("no statements in expanded query")
}
expandedRaw := expandedStmts[0].Raw

// Use the analyzer to get type information from the database
result, err := c.analyzer.Analyze(ctx, expandedRaw, expandedQuery, c.schema, nil)
if err != nil {
return nil, err
}

// Convert the analyzer result to the internal analysis format
var cols []*Column
for _, col := range result.Columns {
cols = append(cols, convertColumn(col))
}
var params []Parameter
for _, p := range result.Params {
params = append(params, Parameter{
Number: int(p.Number),
Column: convertColumn(p.Column),
})
}

// Determine the insert table if applicable
var table *ast.TableName
if insert, ok := expandedRaw.Stmt.(*ast.InsertStmt); ok {
table, _ = ParseTableName(insert.Relation)
}

anlys = &analysis{
Table: table,
Columns: cols,
Parameters: params,
Query: expandedQuery,
}
} else if c.analyzer != nil {
inference, _ := c.inferQuery(raw, rawSQL)
if inference == nil {
inference = &analysis{}
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ type SQL struct {

type Analyzer struct {
Database *bool `json:"database" yaml:"database"`
Accurate *bool `json:"accurate" yaml:"accurate"`
}

// TODO: Figure out a better name for this
Expand Down
3 changes: 3 additions & 0 deletions internal/config/v_one.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
"properties": {
"database": {
"type": "boolean"
},
"accurate": {
"type": "boolean"
}
}
},
Expand Down
3 changes: 3 additions & 0 deletions internal/config/v_two.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@
"properties": {
"database": {
"type": "boolean"
},
"accurate": {
"type": "boolean"
}
}
},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"contexts": ["managed-db"]
}
31 changes: 31 additions & 0 deletions internal/endtoend/testdata/accurate_cte/postgresql/stdlib/go/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- name: ListExpensiveProducts :many
WITH expensive AS (
SELECT * FROM products WHERE price > 100
)
SELECT * FROM expensive;

-- name: GetProductStats :one
WITH product_stats AS (
SELECT COUNT(*) as total, AVG(price) as avg_price FROM products
)
SELECT * FROM product_stats;
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TABLE products (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
price NUMERIC(10,2) NOT NULL
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
version: "2"
sql:
- engine: postgresql
schema: "schema.sql"
queries: "query.sql"
analyzer:
accurate: true
gen:
go:
package: "querytest"
out: "go"
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"contexts": ["managed-db"]
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading