Skip to content

Commit

Permalink
Add support for creating namespaced databases
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas committed Nov 27, 2024
1 parent f05ea4c commit f1897ed
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 16 deletions.
139 changes: 127 additions & 12 deletions pkg/db/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,34 @@ package db
import (
"context"
"database/sql"
"errors"
"fmt"
"regexp"
"time"

"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/xmtp/xmtpd/pkg/migrations"
)

func newPGXDB(
ctx context.Context,
dsn string,
waitForDB, statementTimeout time.Duration,
) (*sql.DB, error) {
const MAX_NAMESPACE_LENGTH = 32

var allowedNamespaceRe = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

func waitUntilDBReady(ctx context.Context, db *pgxpool.Pool, waitTime time.Duration) error {
waitUntil := time.Now().Add(waitTime)

err := db.Ping(ctx)

for err != nil && time.Now().Before(waitUntil) {
time.Sleep(3 * time.Second)
err = db.Ping(ctx)
}
return err
}

func parseConfig(dsn string, statementTimeout time.Duration) (*pgxpool.Config, error) {
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, err
Expand All @@ -24,31 +39,131 @@ func newPGXDB(
config.ConnConfig.RuntimeParams["statement_timeout"] = fmt.Sprint(
statementTimeout.Milliseconds(),
)
return config, nil
}

func newPGXDB(
ctx context.Context,
config *pgxpool.Config,
waitForDB time.Duration,
) (*sql.DB, error) {
dbPool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, err
}

if err = waitUntilDBReady(ctx, dbPool, waitForDB); err != nil {
return nil, err
}

db := stdlib.OpenDBFromPool(dbPool)

waitUntil := time.Now().Add(waitForDB)
return db, nil
}

err = db.Ping()
for err != nil && time.Now().Before(waitUntil) {
time.Sleep(3 * time.Second)
err = db.Ping()
func isValidNamespace(namespace string) error {
if len(namespace) == 0 || len(namespace) > MAX_NAMESPACE_LENGTH {
return fmt.Errorf(
"namespace length must be between 1 and %d characters",
MAX_NAMESPACE_LENGTH,
)
}
// PostgreSQL identifiers must start with a letter or underscore
if !allowedNamespaceRe.MatchString(namespace) {
return fmt.Errorf(
"namespace must start with a letter or underscore and contain only letters, numbers, and underscores",
)
}
return nil
}

// Creates a new database with the given namespace if it doesn't exist
func createNamespace(
ctx context.Context,
config *pgxpool.Config,
namespace string,
waitForDB time.Duration,
) error {
if err := isValidNamespace(namespace); err != nil {
return err
}

// Make a copy of the config so we don't dirty it
config = config.Copy()
// Change the database to postgres so we are able to create new DBs
config.ConnConfig.Database = "postgres"

// Create a temporary connection to the postgres DB
adminConn, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return fmt.Errorf("failed to connect to postgres: %w", err)
}
defer adminConn.Close()

return db, err
if err = waitUntilDBReady(ctx, adminConn, waitForDB); err != nil {
return err
}

// Create database if it doesn't exist
_, err = adminConn.Exec(ctx, fmt.Sprintf(`CREATE DATABASE "%s"`, namespace))
if err != nil {
// Ignore error if database already exists
var pgErr *pgconn.PgError
// Error code 42P04 is for "duplicate database"
// https://www.postgresql.org/docs/current/errcodes-appendix.html
if errors.As(err, &pgErr) && pgErr.Code == "42P04" {
return nil
}

return fmt.Errorf("failed to create database: %w", err)
}

return nil
}

// Creates a new database with the given namespace if it doesn't exist and returns the full DSN for the new database.
func NewNamespacedDB(
ctx context.Context,
dsn string,
namespace string,
waitForDB, statementTimeout time.Duration,
) (*sql.DB, error) {
// Parse the DSN to get the config
config, err := parseConfig(dsn, statementTimeout)
if err != nil {
return nil, fmt.Errorf("failed to parse DSN: %w", err)
}

if err = createNamespace(ctx, config, namespace, waitForDB); err != nil {
return nil, err
}

config.ConnConfig.Database = namespace

db, err := newPGXDB(ctx, config, waitForDB)
if err != nil {
return nil, err
}

err = migrations.Migrate(ctx, db)
if err != nil {
return nil, err
}

return db, nil
}

func NewDB(
ctx context.Context,
dsn string,
waitForDB, statementTimeout time.Duration,
) (*sql.DB, error) {
db, err := newPGXDB(ctx, dsn, waitForDB, statementTimeout)
config, err := parseConfig(dsn, statementTimeout)
if err != nil {
return nil, err
}

db, err := newPGXDB(ctx, config, waitForDB)
if err != nil {
return nil, err
}
Expand Down
85 changes: 85 additions & 0 deletions pkg/db/pgx_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package db

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/xmtp/xmtpd/pkg/testutils"
)

func TestNamespacedDB(t *testing.T) {
startingDsn := testutils.LocalTestDBDSNPrefix + "/foo" + testutils.LocalTestDBDSNSuffix
newDBName := "xmtp_" + testutils.RandomString(24)
// Create namespaced DB
namespacedDB, err := NewNamespacedDB(
context.Background(),
startingDsn,
newDBName,
time.Second,
time.Second,
)
t.Cleanup(func() { namespacedDB.Close() })
require.NoError(t, err)

result, err := namespacedDB.Query("SELECT current_database();")
require.NoError(t, err)
defer result.Close()

require.True(t, result.Next())
var dbName string
err = result.Scan(&dbName)
require.NoError(t, err)
require.Equal(t, newDBName, dbName)
}

func TestNamespaceRepeat(t *testing.T) {
startingDsn := testutils.LocalTestDBDSNPrefix + "/foo" + testutils.LocalTestDBDSNSuffix
newDBName := "xmtp_" + testutils.RandomString(24)
// Create namespaced DB
db1, err := NewNamespacedDB(
context.Background(),
startingDsn,
newDBName,
time.Second,
time.Second,
)
require.NoError(t, err)
require.NotNil(t, db1)
t.Cleanup(func() { db1.Close() })

// Create again with the same name
db2, err := NewNamespacedDB(
context.Background(),
startingDsn,
newDBName,
time.Second,
time.Second,
)
require.NoError(t, err)
require.NotNil(t, db2)
t.Cleanup(func() { db2.Close() })
}

func TestNamespacedDBInvalidName(t *testing.T) {
_, err := NewNamespacedDB(
context.Background(),
testutils.LocalTestDBDSNPrefix+"/foo"+testutils.LocalTestDBDSNSuffix,
"invalid/name",
time.Second,
time.Second,
)
require.Error(t, err)
}

func TestNamespacedDBInvalidDSN(t *testing.T) {
_, err := NewNamespacedDB(
context.Background(),
"invalid-dsn",
"dbname",
time.Second,
time.Second,
)
require.Error(t, err)
}
8 changes: 4 additions & 4 deletions pkg/testutils/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
)

const (
localTestDBDSNPrefix = "postgres://postgres:xmtp@localhost:8765"
localTestDBDSNSuffix = "?sslmode=disable"
LocalTestDBDSNPrefix = "postgres://postgres:xmtp@localhost:8765"
LocalTestDBDSNSuffix = "?sslmode=disable"
)

func openDB(t testing.TB, dsn string) (*sql.DB, string, func()) {
Expand All @@ -28,15 +28,15 @@ func openDB(t testing.TB, dsn string) (*sql.DB, string, func()) {
}

func newCtlDB(t testing.TB) (*sql.DB, string, func()) {
return openDB(t, localTestDBDSNPrefix+localTestDBDSNSuffix)
return openDB(t, LocalTestDBDSNPrefix+LocalTestDBDSNSuffix)
}

func newInstanceDB(t testing.TB, ctx context.Context, ctlDB *sql.DB) (*sql.DB, string, func()) {
dbName := "test_" + RandomStringLower(12)
_, err := ctlDB.Exec("CREATE DATABASE " + dbName)
require.NoError(t, err)

db, dsn, cleanup := openDB(t, localTestDBDSNPrefix+"/"+dbName+localTestDBDSNSuffix)
db, dsn, cleanup := openDB(t, LocalTestDBDSNPrefix+"/"+dbName+LocalTestDBDSNSuffix)
require.NoError(t, migrations.Migrate(ctx, db))

return db, dsn, func() {
Expand Down

0 comments on commit f1897ed

Please sign in to comment.