Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix(sqlite): handle comparison operators and sorting
  • Loading branch information
mgilbir committed Oct 6, 2025
commit 4428adf2918d338c1066fce3fd3819257de89150
188 changes: 181 additions & 7 deletions internal/engine/sqlite/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,21 +356,103 @@ func (c *cc) convertComparison(n *parser.Expr_comparisonContext) ast.Node {
return &ast.In{
Expr: lexpr,
List: rexprs,
Not: false,
Not: n.NOT_() != nil,
Sel: nil,
Location: n.GetStart().GetStart(),
}
}

operator := c.extractComparisonOperator(n)
rexprIdx := 1
rexpr := c.convert(n.Expr(rexprIdx))

// Special handling for IS NOT NULL where NOT NULL might be parsed as a unary expression
if operator == "IS" && len(n.AllExpr()) > 1 {
if rExpr, ok := n.Expr(1).(*parser.Expr_unaryContext); ok {
// Check if this is a NOT NULL expression by looking at the text content
text := rExpr.GetText()
if strings.ToUpper(text) == "NOTNULL" || strings.ToUpper(text) == "NOT NULL" {
operator = "IS NOT"
rexpr = &ast.A_Const{Val: &ast.Null{}}
}
}
}

return &ast.A_Expr{
Name: &ast.List{
Items: []ast.Node{
&ast.String{Str: "="}, // TODO: add actual comparison
&ast.String{Str: operator},
},
},
Lexpr: lexpr,
Rexpr: c.convert(n.Expr(1)),
Rexpr: rexpr,
}
}

func (c *cc) extractComparisonOperator(n *parser.Expr_comparisonContext) string {
switch {
case n.LT2() != nil:
return "<<"
case n.GT2() != nil:
return ">>"
case n.AMP() != nil:
return "&"
case n.PIPE() != nil:
return "|"
case n.LT_EQ() != nil:
return "<="
case n.GT_EQ() != nil:
return ">="
case n.LT() != nil:
return "<"
case n.GT() != nil:
return ">"
case n.NOT_EQ1() != nil:
return "!="
case n.NOT_EQ2() != nil:
return "<>"
case n.ASSIGN() != nil || n.EQ() != nil:
return "="
case n.IS_() != nil:
if n.NOT_() != nil {
return "IS NOT"
}
return "IS"
case n.LIKE_() != nil:
if n.NOT_() != nil {
return "NOT LIKE"
}
return "LIKE"
case n.GLOB_() != nil:
if n.NOT_() != nil {
return "NOT GLOB"
}
return "GLOB"
case n.MATCH_() != nil:
if n.NOT_() != nil {
return "NOT MATCH"
}
return "MATCH"
case n.REGEXP_() != nil:
if n.NOT_() != nil {
return "NOT REGEXP"
}
return "REGEXP"
}

var parts []string
for _, child := range n.GetChildren() {
if term, ok := child.(antlr.TerminalNode); ok {
text := strings.TrimSpace(term.GetText())
if text != "" {
parts = append(parts, text)
}
}
}
if len(parts) > 0 {
return strings.Join(parts, " ")
}
return "="
}

func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.Node {
Expand Down Expand Up @@ -514,6 +596,11 @@ func (c *cc) convertMultiSelect_stmtContext(n *parser.Select_stmtContext) ast.No
limitCount, limitOffset := c.convertLimit_stmtContext(n.Limit_stmt())
selectStmt.LimitCount = limitCount
selectStmt.LimitOffset = limitOffset
if orderBy := n.Order_by_stmt(); orderBy != nil {
if sortClause, ok := c.convertOrderby_stmtContext(orderBy).(*ast.List); ok {
selectStmt.SortClause = sortClause
}
}
selectStmt.WithClause = &ast.WithClause{Ctes: &ctes}
return selectStmt
}
Expand Down Expand Up @@ -626,10 +713,34 @@ func (c *cc) convertOrderby_stmtContext(n parser.IOrder_by_stmtContext) ast.Node
if !ok {
continue
}
list.Items = append(list.Items, &ast.CaseExpr{
Xpr: c.convert(term.Expr()),
Location: term.Expr().GetStart().GetStart(),
})

expr := c.convert(term.Expr())
sortBy := &ast.SortBy{
Node: expr,
SortbyDir: ast.SortByDirDefault,
SortbyNulls: ast.SortByNullsDefault,
UseOp: &ast.List{},
}

if ascDescCtx := term.Asc_desc(); ascDescCtx != nil {
if ascDesc, ok := ascDescCtx.(*parser.Asc_descContext); ok {
if ascDesc.DESC_() != nil {
sortBy.SortbyDir = ast.SortByDirDesc
} else if ascDesc.ASC_() != nil {
sortBy.SortbyDir = ast.SortByDirAsc
}
}
}

if term.NULLS_() != nil {
if term.FIRST_() != nil {
sortBy.SortbyNulls = ast.SortByNullsFirst
} else if term.LAST_() != nil {
sortBy.SortbyNulls = ast.SortByNullsLast
}
}

list.Items = append(list.Items, sortBy)
}
return list
}
Expand Down Expand Up @@ -1135,6 +1246,63 @@ func (c *cc) convertCase(n *parser.Expr_caseContext) ast.Node {
return e
}

func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node {
// Handle unary expressions like NOT NULL
children := n.GetChildren()
if len(children) >= 2 {
for i, child := range children {
if term, ok := child.(antlr.TerminalNode); ok {
if term.GetSymbol().GetTokenType() == parser.SQLiteParserNOT_ {
if i+1 < len(children) {
if nextTerm, ok := children[i+1].(antlr.TerminalNode); ok {
if nextTerm.GetSymbol().GetTokenType() == parser.SQLiteParserNULL_ {
return &ast.A_Const{Val: &ast.Null{}}
}
}
}
}
}
}
}

// For other unary expressions, try to convert the inner expression
if n.Expr() != nil {
return c.convert(n.Expr())
}

return todo("convertUnaryExpr", n)
}

func (c *cc) convertNullComparison(n *parser.Expr_null_compContext) ast.Node {
expr := c.convert(n.Expr())

var operator string
switch {
case n.ISNULL_() != nil:
operator = "IS NULL"
case n.NOTNULL_() != nil:
operator = "IS NOT NULL"
case n.NOT_() != nil && n.NULL_() != nil:
operator = "IS NOT NULL"
case n.NULL_() != nil:
operator = "IS NULL"
default:
operator = "IS NULL" // fallback
}

return &ast.A_Expr{
Name: &ast.List{
Items: []ast.Node{
&ast.String{Str: operator},
},
},
Lexpr: expr,
Rexpr: &ast.A_Const{
Val: &ast.Null{},
},
}
}

func (c *cc) convert(node node) ast.Node {
switch n := node.(type) {

Expand Down Expand Up @@ -1226,6 +1394,12 @@ func (c *cc) convert(node node) ast.Node {
case *parser.Expr_caseContext:
return c.convertCase(n)

case *parser.Expr_null_compContext:
return c.convertNullComparison(n)

case *parser.Expr_unaryContext:
return c.convertUnaryExpr(n)

default:
return todo("convert(case=default)", n)
}
Expand Down
Loading