Skip to content

Commit

Permalink
Merge pull request #198 from lamoda/imporve-pg-fixtures-load
Browse files Browse the repository at this point in the history
Load Postgres fixtures in a single transaction, truncate all tables at once
  • Loading branch information
fetinin authored Jan 17, 2023
2 parents 89d4ea3 + d5c6e1c commit e529fc4
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 99 deletions.
48 changes: 27 additions & 21 deletions fixtures/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,48 +210,54 @@ func (f *LoaderPostgres) loadTables(ctx *loadContext) error {
defer func() { _ = tx.Rollback() }()

// truncate first
truncatedTables := make(map[string]bool)
for _, lt := range ctx.tables {
if _, ok := truncatedTables[lt.name.getFullName()]; ok {
// already truncated
continue
}
if err := f.truncateTable(lt.name); err != nil {
return err
}
truncatedTables[lt.name.getFullName()] = true
if err := f.truncateTables(tx, ctx.tables...); err != nil {
return err
}

// then load data
for _, lt := range ctx.tables {
if len(lt.rows) == 0 {
continue
}
if err := f.loadTable(ctx, lt.name, lt.rows); err != nil {
return fmt.Errorf("failed to load table '%s' because:\n%s", lt.name, err)
if err := f.loadTable(ctx, tx, lt.name, lt.rows); err != nil {
return fmt.Errorf("failed to load table '%s' because:\n%s", lt.name.getFullName(), err)
}
}
// alter the sequences so they contain max id + 1
if err := f.fixSequences(); err != nil {
if err := f.fixSequences(tx); err != nil {
return err
}

return tx.Commit()
}

// truncateTable truncates table
func (f *LoaderPostgres) truncateTable(name tableName) error {
query := fmt.Sprintf("TRUNCATE TABLE %s CASCADE", name.getFullName())
// truncateTables truncates table
func (f *LoaderPostgres) truncateTables(tx *sql.Tx, tables ...loadedTable) error {
set := make(map[string]struct{})
tablesToTruncate := make([]string, 0, len(tables))
for _, t := range tables {
tableName := t.name.getFullName()
if _, ok := set[tableName]; ok {
// already truncated
continue
}

tablesToTruncate = append(tablesToTruncate, tableName)
set[tableName] = struct{}{}
}

query := fmt.Sprintf("TRUNCATE TABLE %s CASCADE", strings.Join(tablesToTruncate, ","))
if f.debug {
fmt.Println("Issuing SQL:", query)
}
_, err := f.db.Exec(query)
_, err := tx.Exec(query)
if err != nil {
return err
}
return nil
}

func (f *LoaderPostgres) loadTable(ctx *loadContext, t tableName, rows table) error {
func (f *LoaderPostgres) loadTable(ctx *loadContext, tx *sql.Tx, t tableName, rows table) error {
// $extend keyword allows to import values from a named row
for i, row := range rows {
if base, ok := row["$extend"]; ok {
Expand All @@ -275,7 +281,7 @@ func (f *LoaderPostgres) loadTable(ctx *loadContext, t tableName, rows table) er
fmt.Println("Issuing SQL:", query)
}
// issuing query
insertedRows, err := f.db.Query(query)
insertedRows, err := tx.Query(query)
if err != nil {
return err
}
Expand Down Expand Up @@ -326,7 +332,7 @@ func (f *LoaderPostgres) loadTable(ctx *loadContext, t tableName, rows table) er
return err
}

func (f *LoaderPostgres) fixSequences() error {
func (f *LoaderPostgres) fixSequences(tx *sql.Tx) error {
query := `
DO $$
DECLARE
Expand All @@ -353,7 +359,7 @@ END$$
if f.debug {
fmt.Println("Issuing SQL:", query)
}
_, err := f.db.Exec(query)
_, err := tx.Exec(query)
return err
}

Expand Down
101 changes: 23 additions & 78 deletions fixtures/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,17 @@ func TestBuildInsertQuery(t *testing.T) {
require.NoError(t, err)

query, err := l.buildInsertQuery(&ctx, newTableName("table"), ctx.tables[0].rows)
require.NoError(t, err)

if err != nil {
t.Error("must not produce error, error:", err.Error())
t.Fail()
}

if query != expected {
t.Error("must generate proper SQL, got result:", query)
t.Fail()
}
require.Equal(t, expected, query)
}

func TestLoadTablesShouldResolveSchema(t *testing.T) {
yml, err := ioutil.ReadFile("../testdata/sql_schema.yaml")
require.NoError(t, err)

db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
require.NoError(t, err)
defer func() { _ = db.Close() }()

ctx := loadContext{
Expand All @@ -59,21 +50,12 @@ func TestLoadTablesShouldResolveSchema(t *testing.T) {

l := New(db, "", true)

err = l.loadYml([]byte(yml), &ctx)
if err != nil {
t.Error(err)
t.Fail()
}
err = l.loadYml(yml, &ctx)
require.NoError(t, err)

mock.ExpectBegin()

mock.ExpectExec("^TRUNCATE TABLE \"schema1\".\"table1\" CASCADE$").
WillReturnResult(sqlmock.NewResult(0, 0))

mock.ExpectExec("^TRUNCATE TABLE \"schema2\".\"table2\" CASCADE$").
WillReturnResult(sqlmock.NewResult(0, 0))

mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table3\" CASCADE$").
mock.ExpectExec("^TRUNCATE TABLE \"schema1\".\"table1\",\"schema2\".\"table2\",\"public\".\"table3\" CASCADE$").
WillReturnResult(sqlmock.NewResult(0, 0))

q := `^INSERT INTO "schema1"."table1" AS row \("f1", "f2"\) VALUES ` +
Expand Down Expand Up @@ -112,25 +94,18 @@ func TestLoadTablesShouldResolveSchema(t *testing.T) {
mock.ExpectCommit()

err = l.loadTables(&ctx)
if err != nil {
t.Error(err)
t.Fail()
}
require.NoError(t, err)

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
t.Fail()
}
err = mock.ExpectationsWereMet()
require.NoError(t, err)
}

func TestLoadTablesShouldResolveRefs(t *testing.T) {
yml, err := ioutil.ReadFile("../testdata/sql_refs.yaml")
require.NoError(t, err)

db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
require.NoError(t, err)
defer func() { _ = db.Close() }()

ctx := loadContext{
Expand All @@ -140,21 +115,12 @@ func TestLoadTablesShouldResolveRefs(t *testing.T) {

l := New(db, "", true)

err = l.loadYml([]byte(yml), &ctx)
if err != nil {
t.Error(err)
t.Fail()
}
err = l.loadYml(yml, &ctx)
require.NoError(t, err)

mock.ExpectBegin()

mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table1\" CASCADE$").
WillReturnResult(sqlmock.NewResult(0, 0))

mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table2\" CASCADE$").
WillReturnResult(sqlmock.NewResult(0, 0))

mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table3\" CASCADE$").
mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table1\",\"public\".\"table2\",\"public\".\"table3\" CASCADE$").
WillReturnResult(sqlmock.NewResult(0, 0))

q := `^INSERT INTO "public"."table1" AS row \("f1", "f2"\) VALUES ` +
Expand Down Expand Up @@ -193,25 +159,18 @@ func TestLoadTablesShouldResolveRefs(t *testing.T) {
mock.ExpectCommit()

err = l.loadTables(&ctx)
if err != nil {
t.Error(err)
t.Fail()
}
require.NoError(t, err)

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
t.Fail()
}
err = mock.ExpectationsWereMet()
require.NoError(t, err)
}

func TestLoadTablesShouldExtendRows(t *testing.T) {
yml, err := ioutil.ReadFile("../testdata/sql_extend.yaml")
require.NoError(t, err)

db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
require.NoError(t, err)
defer func() { _ = db.Close() }()

ctx := loadContext{
Expand All @@ -221,21 +180,12 @@ func TestLoadTablesShouldExtendRows(t *testing.T) {

l := New(db, "", true)

err = l.loadYml([]byte(yml), &ctx)
if err != nil {
t.Error(err)
t.Fail()
}
err = l.loadYml(yml, &ctx)
require.NoError(t, err)

mock.ExpectBegin()

mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table1\" CASCADE$").
WillReturnResult(sqlmock.NewResult(0, 0))

mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table2\" CASCADE$").
WillReturnResult(sqlmock.NewResult(0, 0))

mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table3\" CASCADE$").
mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table1\",\"public\".\"table2\",\"public\".\"table3\" CASCADE$").
WillReturnResult(sqlmock.NewResult(0, 0))

q := `^INSERT INTO "public"."table1" AS row \("f1", "f2"\) VALUES ` +
Expand Down Expand Up @@ -276,13 +226,8 @@ func TestLoadTablesShouldExtendRows(t *testing.T) {
mock.ExpectCommit()

err = l.loadTables(&ctx)
if err != nil {
t.Error(err)
t.Fail()
}
require.NoError(t, err)

if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
t.Fail()
}
err = mock.ExpectationsWereMet()
require.NoError(t, err)
}

0 comments on commit e529fc4

Please sign in to comment.