Skip to content
Prev Previous commit
Next Next commit
Be paranoid about concurrency
  • Loading branch information
kyleconroy committed Aug 25, 2025
commit 9bdaf614464726f45253111ed2ef8aefb684f32b
4 changes: 4 additions & 0 deletions internal/sqltest/docker/enabled.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ package docker
import (
"fmt"
"os/exec"

"golang.org/x/sync/singleflight"
)

var flight singleflight.Group

func Installed() error {
if _, err := exec.LookPath("docker"); err != nil {
return fmt.Errorf("docker not found: %w", err)
Expand Down
104 changes: 60 additions & 44 deletions internal/sqltest/docker/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,84 +5,100 @@ import (
"database/sql"
"fmt"
"os/exec"
"sync"
"strings"
"time"

_ "github.com/go-sql-driver/mysql"
)

var mysqlSync sync.Once
var mysqlHost string

func StartMySQLServer(c context.Context) (string, error) {
if err := Installed(); err != nil {
return "", err
}
if mysqlHost != "" {
return mysqlHost, nil
}
value, err, _ := flight.Do("mysql", func() (interface{}, error) {
host, err := startMySQLServer(c)
if err != nil {
return "", err
}
mysqlHost = host
return host, nil
})
if err != nil {
return "", err
}
data, ok := value.(string)
if !ok {
return "", fmt.Errorf("returned value was not a string")
}
return data, nil
}

func startMySQLServer(c context.Context) (string, error) {
{
_, err := exec.Command("docker", "pull", "mysql:8").CombinedOutput()
_, err := exec.Command("docker", "pull", "mysql:9").CombinedOutput()
if err != nil {
return "", fmt.Errorf("docker pull: mysql:8 %w", err)
return "", fmt.Errorf("docker pull: mysql:9 %w", err)
}
}

var syncErr error
mysqlSync.Do(func() {
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()
var exists bool
{
cmd := exec.Command("docker", "container", "inspect", "sqlc_sqltest_docker_mysql")
// This means we've already started the container
exists = cmd.Run() == nil
}

if !exists {
cmd := exec.Command("docker", "run",
"--name", "sqlc_sqltest_docker_mysql",
"-e", "MYSQL_ROOT_PASSWORD=mysecretpassword",
"-e", "MYSQL_DATABASE=dinotest",
"-p", "3306:3306",
"-d",
"mysql:8",
"mysql:9",
)

output, err := cmd.CombinedOutput()
fmt.Println(string(output))
if err != nil {
syncErr = err
return
}

// Create a ticker that fires every 10ms
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
msg := `Conflict. The container name "/sqlc_sqltest_docker_mysql" is already in use by container`
if !strings.Contains(string(output), msg) && err != nil {
return "", err
}
}

uri := "root:mysecretpassword@/dinotest"
ctx, cancel := context.WithTimeout(c, 10*time.Second)
defer cancel()

db, err := sql.Open("mysql", uri)
if err != nil {
syncErr = fmt.Errorf("sql.Open: %w", err)
return
}
// Create a ticker that fires every 10ms
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
syncErr = fmt.Errorf("timeout reached: %w", ctx.Err())
return

case <-ticker.C:
// Run your function here
if err := db.PingContext(ctx); err != nil {
continue
}
mysqlHost = uri
return
}
}
})
uri := "root:mysecretpassword@/dinotest?multiStatements=true&parseTime=true"

if syncErr != nil {
return "", syncErr
db, err := sql.Open("mysql", uri)
if err != nil {
return "", fmt.Errorf("sql.Open: %w", err)
}

if mysqlHost == "" {
return "", fmt.Errorf("mysql server setup failed")
}
defer db.Close()

for {
select {
case <-ctx.Done():
return "", fmt.Errorf("timeout reached: %w", ctx.Err())

return mysqlHost, nil
case <-ticker.C:
// Run your function here
if err := db.PingContext(ctx); err != nil {
continue
}
return uri, nil
}
}
}
98 changes: 57 additions & 41 deletions internal/sqltest/docker/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,57 @@ import (
"fmt"
"log/slog"
"os/exec"
"sync"
"strings"
"time"

"github.com/jackc/pgx/v5"
)

var postgresSync sync.Once
var postgresHost string

func StartPostgreSQLServer(c context.Context) (string, error) {
if err := Installed(); err != nil {
return "", err
}
if postgresHost != "" {
return postgresHost, nil
}
value, err, _ := flight.Do("postgresql", func() (interface{}, error) {
host, err := startPostgreSQLServer(c)
if err != nil {
return "", err
}
postgresHost = host
return host, err
})
if err != nil {
return "", err
}
data, ok := value.(string)
if !ok {
return "", fmt.Errorf("returned value was not a string")
}
return data, nil
}

func startPostgreSQLServer(c context.Context) (string, error) {
{
_, err := exec.Command("docker", "pull", "postgres:16").CombinedOutput()
if err != nil {
return "", fmt.Errorf("docker pull: postgres:16 %w", err)
}
}

var syncErr error
postgresSync.Do(func() {
ctx, cancel := context.WithTimeout(c, 5*time.Second)
defer cancel()
uri := "postgres://postgres:mysecretpassword@localhost:5432/postgres?sslmode=disable"

var exists bool
{
cmd := exec.Command("docker", "container", "inspect", "sqlc_sqltest_docker_postgres")
// This means we've already started the container
exists = cmd.Run() == nil
}

if !exists {
cmd := exec.Command("docker", "run",
"--name", "sqlc_sqltest_docker_postgres",
"-e", "POSTGRES_PASSWORD=mysecretpassword",
Expand All @@ -43,47 +68,38 @@ func StartPostgreSQLServer(c context.Context) (string, error) {

output, err := cmd.CombinedOutput()
fmt.Println(string(output))
if err != nil {
syncErr = err
return

msg := `Conflict. The container name "/sqlc_sqltest_docker_postgres" is already in use by container`
if !strings.Contains(string(output), msg) && err != nil {
return "", err
}
}

// Create a ticker that fires every 10ms
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
ctx, cancel := context.WithTimeout(c, 5*time.Second)
defer cancel()

uri := "postgres://postgres:mysecretpassword@localhost:5432/postgres?sslmode=disable"
// Create a ticker that fires every 10ms
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
syncErr = fmt.Errorf("timeout reached: %w", ctx.Err())
return
for {
select {
case <-ctx.Done():
return "", fmt.Errorf("timeout reached: %w", ctx.Err())

case <-ticker.C:
// Run your function here
conn, err := pgx.Connect(ctx, uri)
if err != nil {
slog.Debug("sqltest", "connect", err)
continue
}
if err := conn.Ping(ctx); err != nil {
slog.Error("sqltest", "ping", err)
continue
}
postgresHost = uri
return
case <-ticker.C:
// Run your function here
conn, err := pgx.Connect(ctx, uri)
if err != nil {
slog.Debug("sqltest", "connect", err)
continue
}
defer conn.Close(ctx)
if err := conn.Ping(ctx); err != nil {
slog.Error("sqltest", "ping", err)
continue
}
return uri, nil
}
})

if syncErr != nil {
return "", syncErr
}

if postgresHost == "" {
return "", fmt.Errorf("postgres server setup failed")
}

return postgresHost, nil
}
Loading