Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Better implementation
  • Loading branch information
andrewmbenton committed Dec 21, 2023
commit c5c6a7a0acb4c9db1eec6886b44db4e6b96ff86f
13 changes: 6 additions & 7 deletions internal/compiler/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (

analyzer "github.com/sqlc-dev/sqlc/internal/analysis"
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/metadata"
"github.com/sqlc-dev/sqlc/internal/source"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/named"
Expand Down Expand Up @@ -107,15 +106,15 @@ func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis {
return prev
}

func (c *Compiler) analyzeQuery(raw *ast.RawStmt, query string, paramAnnotations map[string]metadata.ParamMetadata) (*analysis, error) {
return c._analyzeQuery(raw, query, paramAnnotations, true)
func (c *Compiler) analyzeQuery(raw *ast.RawStmt, query string) (*analysis, error) {
return c._analyzeQuery(raw, query, true)
}

func (c *Compiler) inferQuery(raw *ast.RawStmt, query string, paramAnnotations map[string]metadata.ParamMetadata) (*analysis, error) {
return c._analyzeQuery(raw, query, paramAnnotations, false)
func (c *Compiler) inferQuery(raw *ast.RawStmt, query string) (*analysis, error) {
return c._analyzeQuery(raw, query, false)
}

func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, paramAnnotations map[string]metadata.ParamMetadata, failfast bool) (*analysis, error) {
func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, failfast bool) (*analysis, error) {
errors := make([]error, 0)
check := func(err error) error {
if failfast {
Expand Down Expand Up @@ -174,7 +173,7 @@ func (c *Compiler) _analyzeQuery(raw *ast.RawStmt, query string, paramAnnotation
return nil, err
}

params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds, paramAnnotations)
params, err := c.resolveCatalogRefs(qc, rvs, refs, namedParams, embeds)
if err := check(err); err != nil {
return nil, err
}
Expand Down
20 changes: 18 additions & 2 deletions internal/compiler/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,

var anlys *analysis
if c.analyzer != nil {
inference, _ := c.inferQuery(raw, rawSQL, md.Params)
inference, _ := c.inferQuery(raw, rawSQL)
if inference == nil {
inference = &analysis{}
}
Expand Down Expand Up @@ -100,12 +100,28 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
// FOOTGUN: combineAnalysis mutates inference
anlys = combineAnalysis(inference, result)
} else {
anlys, err = c.analyzeQuery(raw, rawSQL, md.Params)
anlys, err = c.analyzeQuery(raw, rawSQL)
if err != nil {
return nil, err
}
}

// Override the inferrerd type and nullability of annotated named params
for i, param := range anlys.Parameters {
if !param.Column.IsNamedParam {
continue
}
if paramMetadata, ok := md.Params[param.Column.Name]; ok {
anlys.Parameters[i].Column.DataType = paramMetadata.DatabaseType
switch paramMetadata.Nullability {
case metadata.ParamNullabilityForceNotNull:
anlys.Parameters[i].Column.NotNull = true
case metadata.ParamNullabilityForceNullable:
anlys.Parameters[i].Column.NotNull = false
}
}
}

expanded := anlys.Query

// If the query string was edited, make sure the syntax is valid
Expand Down
16 changes: 1 addition & 15 deletions internal/compiler/resolve.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"log/slog"
"strconv"

"github.com/sqlc-dev/sqlc/internal/metadata"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
Expand All @@ -22,7 +21,7 @@ func dataType(n *ast.TypeName) string {
}
}

func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet, paramAnnotations map[string]metadata.ParamMetadata) ([]Parameter, error) {
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
c := comp.catalog

aliasMap := map[string]*ast.TableName{}
Expand Down Expand Up @@ -620,18 +619,5 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
})
}
}

// Override the inferrerd type and nullability of annotated named params
for i, param := range a {
if param.Column.IsNamedParam {
if md, ok := paramAnnotations[param.Column.Name]; ok {
a[i].Column.DataType = md.DatabaseType
if md.ForceNotNull != nil {
a[i].Column.NotNull = *md.ForceNotNull
}
}
}
}

return a, nil
}
31 changes: 16 additions & 15 deletions internal/metadata/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,17 @@ const (

type ParamMetadata struct {
DatabaseType string
// unset => nil, "!" => true, "?" => false
ForceNotNull *bool
Nullability ParamNullability
}

type ParamNullability int

const (
ParamNullabilityUnspecified ParamNullability = iota // ""
ParamNullabilityForceNotNull // "!"
ParamNullabilityForceNullable // "?"
)

// A query name must be a valid Go identifier
//
// https://golang.org/ref/spec#Identifiers
Expand Down Expand Up @@ -143,24 +150,18 @@ func ParseParamsAndFlags(comments []string) (map[string]ParamMetadata, map[strin
paramToken := s.Text()
rest = append(rest, paramToken)
}
var hasSuffix, suffixValue bool
switch name[len(name)-1] {
case '!':
var nullability ParamNullability
switch {
case strings.HasSuffix(name, "!"):
name = name[:len(name)-1]
hasSuffix = true
suffixValue = true
case '?':
nullability = ParamNullabilityForceNotNull
case strings.HasSuffix(name, "?"):
name = name[:len(name)-1]
hasSuffix = true
suffixValue = false
}
var forceNotNull *bool
if hasSuffix {
forceNotNull = &suffixValue
nullability = ParamNullabilityForceNullable
}
params[name] = ParamMetadata{
DatabaseType: strings.Join(rest, " "),
ForceNotNull: forceNotNull,
Nullability: nullability,
}
default:
flags[token] = true
Expand Down