Skip to content

Commit d449832

Browse files
authored
feat(sqlfile): Add sqlfile.Split (#4146)
Create the sqlfile.Split method that splits a .sql file into multiple statements.
1 parent 31d8958 commit d449832

File tree

50 files changed

+479
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+479
-0
lines changed

‎internal/sql/sqlfile/split.go‎

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
package sqlfile
2+
3+
import (
4+
"bufio"
5+
"context"
6+
"io"
7+
"strings"
8+
)
9+
10+
// Split reads SQL queries from an io.Reader and returns them as a slice of strings.
11+
// Each SQL query is delimited by a semicolon (;).
12+
// The function handles:
13+
// - Single-line comments (-- comment)
14+
// - Multi-line comments (/* comment */)
15+
// - Single-quoted strings ('string')
16+
// - Double-quoted identifiers ("identifier")
17+
// - Dollar-quoted strings ($$string$$ or $tag$string$tag$)
18+
func Split(ctx context.Context, r io.Reader) ([]string, error) {
19+
scanner := bufio.NewScanner(r)
20+
var queries []string
21+
var currentQuery strings.Builder
22+
var inSingleQuote bool
23+
var inDoubleQuote bool
24+
var inDollarQuote bool
25+
var dollarTag string
26+
var inMultiLineComment bool
27+
28+
for scanner.Scan() {
29+
// Check context cancellation
30+
select {
31+
case <-ctx.Done():
32+
return nil, ctx.Err()
33+
default:
34+
}
35+
36+
line := scanner.Text()
37+
i := 0
38+
lineLen := len(line)
39+
40+
for i < lineLen {
41+
ch := line[i]
42+
43+
// Handle multi-line comments
44+
if inMultiLineComment {
45+
if i+1 < lineLen && ch == '*' && line[i+1] == '/' {
46+
inMultiLineComment = false
47+
currentQuery.WriteString("*/")
48+
i += 2
49+
continue
50+
}
51+
currentQuery.WriteByte(ch)
52+
i++
53+
continue
54+
}
55+
56+
// Handle dollar-quoted strings (PostgreSQL)
57+
if inDollarQuote {
58+
if ch == '$' {
59+
// Try to match the closing tag
60+
endTag := extractDollarTag(line[i:])
61+
if endTag == dollarTag {
62+
inDollarQuote = false
63+
currentQuery.WriteString(endTag)
64+
i += len(endTag)
65+
continue
66+
}
67+
}
68+
currentQuery.WriteByte(ch)
69+
i++
70+
continue
71+
}
72+
73+
// Handle single-quoted strings
74+
if inSingleQuote {
75+
currentQuery.WriteByte(ch)
76+
if ch == '\'' {
77+
// Check for escaped quote ''
78+
if i+1 < lineLen && line[i+1] == '\'' {
79+
currentQuery.WriteByte('\'')
80+
i += 2
81+
continue
82+
}
83+
inSingleQuote = false
84+
}
85+
i++
86+
continue
87+
}
88+
89+
// Handle double-quoted identifiers
90+
if inDoubleQuote {
91+
currentQuery.WriteByte(ch)
92+
if ch == '"' {
93+
// Check for escaped quote ""
94+
if i+1 < lineLen && line[i+1] == '"' {
95+
currentQuery.WriteByte('"')
96+
i += 2
97+
continue
98+
}
99+
inDoubleQuote = false
100+
}
101+
i++
102+
continue
103+
}
104+
105+
// Check for single-line comment
106+
if i+1 < lineLen && ch == '-' && line[i+1] == '-' {
107+
// Rest of line is a comment
108+
currentQuery.WriteString(line[i:])
109+
break
110+
}
111+
112+
// Check for multi-line comment start
113+
if i+1 < lineLen && ch == '/' && line[i+1] == '*' {
114+
inMultiLineComment = true
115+
currentQuery.WriteString("/*")
116+
i += 2
117+
continue
118+
}
119+
120+
// Check for dollar quote start
121+
if ch == '$' {
122+
tag := extractDollarTag(line[i:])
123+
if tag != "" {
124+
inDollarQuote = true
125+
dollarTag = tag
126+
currentQuery.WriteString(tag)
127+
i += len(tag)
128+
continue
129+
}
130+
}
131+
132+
// Check for single quote
133+
if ch == '\'' {
134+
inSingleQuote = true
135+
currentQuery.WriteByte(ch)
136+
i++
137+
continue
138+
}
139+
140+
// Check for double quote
141+
if ch == '"' {
142+
inDoubleQuote = true
143+
currentQuery.WriteByte(ch)
144+
i++
145+
continue
146+
}
147+
148+
// Check for semicolon (statement terminator)
149+
if ch == ';' {
150+
currentQuery.WriteByte(ch)
151+
// Check if there's a comment after the semicolon on the same line
152+
i++
153+
if i < lineLen {
154+
// Skip whitespace
155+
for i < lineLen && (line[i] == ' ' || line[i] == '\t') {
156+
currentQuery.WriteByte(line[i])
157+
i++
158+
}
159+
// If there's a comment, include it
160+
if i+1 < lineLen && line[i] == '-' && line[i+1] == '-' {
161+
currentQuery.WriteString(line[i:])
162+
}
163+
}
164+
query := strings.TrimSpace(currentQuery.String())
165+
if query != "" && query != ";" {
166+
queries = append(queries, query)
167+
}
168+
currentQuery.Reset()
169+
break // Move to next line
170+
}
171+
172+
// Regular character
173+
currentQuery.WriteByte(ch)
174+
i++
175+
}
176+
177+
// Add newline if we're building a query
178+
if currentQuery.Len() > 0 {
179+
currentQuery.WriteByte('\n')
180+
}
181+
}
182+
183+
if err := scanner.Err(); err != nil {
184+
return nil, err
185+
}
186+
187+
// Handle any remaining query
188+
query := strings.TrimSpace(currentQuery.String())
189+
if query != "" && query != ";" {
190+
queries = append(queries, query)
191+
}
192+
193+
return queries, nil
194+
}
195+
196+
// extractDollarTag extracts a dollar-quoted string tag from the beginning of s.
197+
// Returns empty string if no valid dollar tag is found.
198+
// Valid tags: $$ or $identifier$ where identifier contains only alphanumeric and underscore.
199+
func extractDollarTag(s string) string {
200+
if len(s) == 0 || s[0] != '$' {
201+
return ""
202+
}
203+
204+
// Find the closing $
205+
for i := 1; i < len(s); i++ {
206+
if s[i] == '$' {
207+
tag := s[:i+1]
208+
// Validate tag content (only alphanumeric and underscore allowed between $)
209+
tagContent := tag[1 : len(tag)-1]
210+
if isValidDollarTagContent(tagContent) {
211+
return tag
212+
}
213+
return ""
214+
}
215+
// If we hit a character that's not allowed in a tag, it's not a dollar quote
216+
if !isValidDollarTagChar(s[i]) {
217+
return ""
218+
}
219+
}
220+
221+
return ""
222+
}
223+
224+
// isValidDollarTagContent returns true if s contains only valid characters for a dollar tag.
225+
func isValidDollarTagContent(s string) bool {
226+
if s == "" {
227+
return true // $$ is valid
228+
}
229+
for _, ch := range s {
230+
if !isValidDollarTagChar(byte(ch)) {
231+
return false
232+
}
233+
}
234+
return true
235+
}
236+
237+
// isValidDollarTagChar reports whether ch may appear inside a dollar tag:
// an ASCII letter, an ASCII digit, or an underscore.
func isValidDollarTagChar(ch byte) bool {
	switch {
	case 'a' <= ch && ch <= 'z':
		return true
	case 'A' <= ch && ch <= 'Z':
		return true
	case '0' <= ch && ch <= '9':
		return true
	default:
		return ch == '_'
	}
}

‎internal/sql/sqlfile/split_test.go‎

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package sqlfile
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"strings"
9+
"testing"
10+
)
11+
12+
func TestSplit(t *testing.T) {
13+
testdataDir := "testdata"
14+
15+
entries, err := os.ReadDir(testdataDir)
16+
if err != nil {
17+
t.Fatalf("Failed to read testdata directory: %v", err)
18+
}
19+
20+
for _, entry := range entries {
21+
if !entry.IsDir() {
22+
continue
23+
}
24+
25+
testName := entry.Name()
26+
t.Run(testName, func(t *testing.T) {
27+
testDir := filepath.Join(testdataDir, testName)
28+
29+
// Read input file
30+
inputPath := filepath.Join(testDir, "input.sql")
31+
inputData, err := os.ReadFile(inputPath)
32+
if err != nil {
33+
t.Fatalf("Failed to read input file: %v", err)
34+
}
35+
36+
// Read expected output files
37+
var expected []string
38+
for i := 1; ; i++ {
39+
outputPath := filepath.Join(testDir, fmt.Sprintf("output_%d.sql", i))
40+
data, err := os.ReadFile(outputPath)
41+
if err != nil {
42+
if os.IsNotExist(err) {
43+
break
44+
}
45+
t.Fatalf("Failed to read output file %s: %v", outputPath, err)
46+
}
47+
expected = append(expected, string(data))
48+
}
49+
50+
// Run Split
51+
ctx := context.Background()
52+
reader := strings.NewReader(string(inputData))
53+
54+
got, err := Split(ctx, reader)
55+
if err != nil {
56+
t.Fatalf("Split() error = %v", err)
57+
}
58+
59+
// Compare results
60+
if len(got) != len(expected) {
61+
t.Errorf("Split() got %d queries, expected %d", len(got), len(expected))
62+
t.Logf("Got: %v", got)
63+
t.Logf("Expected: %v", expected)
64+
return
65+
}
66+
67+
for i := range got {
68+
if got[i] != expected[i] {
69+
t.Errorf("Query %d:\ngot: %q\nexpected: %q", i, got[i], expected[i])
70+
}
71+
}
72+
})
73+
}
74+
}
75+
76+
func TestSplitContextCancellation(t *testing.T) {
77+
ctx, cancel := context.WithCancel(context.Background())
78+
cancel() // Cancel immediately
79+
80+
reader := strings.NewReader("SELECT * FROM users;")
81+
_, err := Split(ctx, reader)
82+
83+
if err != context.Canceled {
84+
t.Errorf("Expected context.Canceled error, got %v", err)
85+
}
86+
}
87+
88+
func TestExtractDollarTag(t *testing.T) {
89+
tests := []struct {
90+
name string
91+
input string
92+
expected string
93+
}{
94+
{
95+
name: "empty dollar quote",
96+
input: "$$",
97+
expected: "$$",
98+
},
99+
{
100+
name: "simple tag",
101+
input: "$tag$",
102+
expected: "$tag$",
103+
},
104+
{
105+
name: "tag with numbers",
106+
input: "$tag123$",
107+
expected: "$tag123$",
108+
},
109+
{
110+
name: "tag with underscore",
111+
input: "$my_tag$",
112+
expected: "$my_tag$",
113+
},
114+
{
115+
name: "not a dollar quote (no closing)",
116+
input: "$tag",
117+
expected: "",
118+
},
119+
{
120+
name: "not a dollar quote (invalid char)",
121+
input: "$tag-name$",
122+
expected: "",
123+
},
124+
{
125+
name: "empty string",
126+
input: "",
127+
expected: "",
128+
},
129+
{
130+
name: "no dollar sign",
131+
input: "tag",
132+
expected: "",
133+
},
134+
{
135+
name: "tag with extra content",
136+
input: "$tag$rest of string",
137+
expected: "$tag$",
138+
},
139+
}
140+
141+
for _, tt := range tests {
142+
t.Run(tt.name, func(t *testing.T) {
143+
got := extractDollarTag(tt.input)
144+
if got != tt.expected {
145+
t.Errorf("extractDollarTag(%q) = %q, expected %q", tt.input, got, tt.expected)
146+
}
147+
})
148+
}
149+
}

0 commit comments

Comments
 (0)