Skip to content

gormでトランザクション制御

cf. https://gorm.io/docs/transactions.html

実装

  • infrastructure/tx_manager.goを新規作成します:
package infrastructure

import (
    "context"

    "gorm.io/gorm"
)

// トランザクション管理のインターフェースです。
type TxManager interface {
    WithTransaction(ctx context.Context, f func(ctx context.Context) error) error
}

type txManager struct {
    db *gorm.DB
}

func NewTxManager(db *gorm.DB) TxManager {
    return &txManager{db: db}
}

// トランザクションを開始しつつ業務処理を実行します。
// エラー発生時はロールバックが実行されます。
func (t *txManager) WithTransaction(ctx context.Context, f func(ctx context.Context) error) error {

    if existingTx := GetTx(ctx); existingTx != nil {
        return f(ctx)
    }

    return t.db.Transaction(func(tx *gorm.DB) error {
        ctxWithTx := context.WithValue(ctx, txKey{}, tx)
        return f(ctxWithTx)
    })
}

type txKey struct{}

// コンテキストからトランザクションを取得します。
// トランザクションがない場合はnilを返します。
func GetTx(ctx context.Context) *gorm.DB {

    tx, ok := ctx.Value(txKey{}).(*gorm.DB)
    if !ok {
        return nil
    }
    return tx
}

// コンテキストにトランザクションをセットします。
func SetTx(ctx context.Context, tx *gorm.DB) context.Context {
    return context.WithValue(ctx, txKey{}, tx)
}
  • persistence配下のsqlについて、トランザクションを取得する処理を追記します:
package persistence

import (
    "context"
    "easyapp/internal/infrastructure"
    "easyapp/internal/infrastructure/persistence/table"

    "gorm.io/gorm"
)

type UsersSql interface {

    // ユーザ情報を保存します。
    Save(ctx context.Context, u table.Users) error
}

type usersSql struct {
    db *gorm.DB
}

func NewUsersSql(db *gorm.DB) UsersSql {
    return &usersSql{db: db}
}

func (s *usersSql) Save(ctx context.Context, u table.Users) error {

    // トランザクションがあれば取得
    db := s.db
    if tx := infrastructure.GetTx(ctx); tx != nil {
        db = tx
    }

    if err := gorm.G[table.Users](db).Create(ctx, &u); err != nil {
        return err
    }

    return nil
}
  • usecaseでトランザクションを管理します。txManagerを構造体に含め、WithTransactionで業務処理をラップします:
package usecase

import (
    "context"
    "easyapp/internal/domain"
    "easyapp/internal/infrastructure"
    "easyapp/internal/usecase/params"
)

type UserUsecase interface {

    // ユーザ情報を保存します。
    RegistUser(ctx context.Context, in params.RegistUserIn) (params.RegistUserOut, error)
}

type userUsecase struct {
    userRepository domain.UserRepository
    txManager      infrastructure.TxManager
}

func NewUserUsecase(
    userRepository domain.UserRepository,
    txManager infrastructure.TxManager,
) UserUsecase {
    return &userUsecase{userRepository: userRepository, txManager: txManager}
}

func (u *userUsecase) RegistUser(
    ctx context.Context,
    in params.RegistUserIn,
) (
    params.RegistUserOut,
    error,
) {

    // トランザクションを開始しつつ業務処理を実行
    if err := u.txManager.WithTransaction(ctx, func(ctx context.Context) error {
        return u.userRepository.Save(ctx, domain.NewUser(in.Name(), in.Password(), in.Age()))
    }); err != nil {
        return params.NewRegistUserOut(false), err
    }

    return params.NewRegistUserOut(true), nil
}

テスト

  • users_sql_test.go
package persistence

import (
    "context"
    "easyapp/internal/infrastructure"
    "easyapp/internal/infrastructure/persistence/table"
    "easyapp/internal/infrastructure/persistence/test"
    "errors"
    "testing"

    "github.com/mattn/go-sqlite3"
    "github.com/stretchr/testify/assert"
    "gorm.io/gorm"
)

func Test_UsersSql_Save(t *testing.T) {

    tests := []struct {
        name          string            // テストケース名
        ctx           context.Context   // コンテキスト
        withTx        bool              // トランザクション有無
        users         table.Users       // 入力値
        setup         func(db *gorm.DB) // 事前セットアップ関数
        expectedError error             // 期待されるエラー
    }{
        {
            name:          "success without tx",
            ctx:           context.Background(),
            withTx:        false,
            users:         table.Users{Name: "nob", Password: "passwd", Age: 13},
            setup:         func(db *gorm.DB) {},
            expectedError: nil,
        },
        {
            name:          "success with tx",
            ctx:           context.Background(),
            withTx:        true,
            users:         table.Users{Name: "nob", Password: "passwd", Age: 13},
            setup:         func(db *gorm.DB) {},
            expectedError: nil,
        },
        {
            name:   "failed to query",
            ctx:    context.Background(),
            withTx: false,
            users:  table.Users{Name: "nob", Password: "passwd", Age: 13},
            setup: func(db *gorm.DB) {
                db.Exec("DROP TABLE users") // クエリエラーのためにテーブル破棄
            },
            expectedError: sqlite3.Error{},
        },
    }

    for _, testcase := range tests {

        t.Run(testcase.name, func(t *testing.T) {

            // テストデータベースに接続
            db := test.ConnectTestDB(t, "users")

            // 事前セットアップ
            testcase.setup(db)

            // トランザクション開始
            if testcase.withTx {
                tx := db.Begin()
                testcase.ctx = infrastructure.SetTx(context.Background(), tx)
            }

            // sqlの実行
            result := NewUsersSql(db).Save(testcase.ctx, testcase.users)

            // レスポンスの確認
            if result != nil {
                var sqliteError sqlite3.Error
                errors.As(result, &sqliteError)
            } else {
                assert.Equal(t, testcase.expectedError, result)
            }

        })
    }
}
  • user_usecase_test.go
package usecase

import (
    "context"
    "easyapp/internal/domain"
    "easyapp/internal/usecase/params"
    "errors"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/mock"
)

// repositoryモックの定義
type mockUserRepository struct {
    mock.Mock
}

func (m *mockUserRepository) Save(ctx context.Context, u domain.User) error {
    args := m.Called(ctx, u)
    return args.Error(0)
}

// txManagerモックの定義
type mockTxManager struct{}

// 関数fの結果を返すようにモック化
func (m *mockTxManager) WithTransaction(
    ctx context.Context,
    f func(ctx context.Context) error,
) error {
    return f(ctx)
}

// UserUsecase_RegistUserのテスト
func Test_UserUsecase_RegistUser(t *testing.T) {

    tests := []struct {
        name                string                      // テストケース名
        requestBody         params.RegistUserIn         // リクエストボディ
        setupRepositoryMock func(m *mockUserRepository) // repositoryモック設定
        expectedBody        params.RegistUserOut        // 期待されるレスポンスボディ
        expectedError       error                       // 期待されるエラー
    }{
        {
            name:        "success",
            requestBody: params.NewRegistUserIn("nob", "passwd", 13),
            setupRepositoryMock: func(m *mockUserRepository) {
                m.On(
                    "Save",
                    mock.Anything,
                    domain.NewUser("nob", "passwd", 13),
                ).Return(
                    nil,
                )
            },
            expectedBody:  params.NewRegistUserOut(true),
            expectedError: nil,
        },
        {
            name:        "repository error",
            requestBody: params.NewRegistUserIn("nob", "passwd", 13),
            setupRepositoryMock: func(m *mockUserRepository) {
                m.On(
                    "Save",
                    mock.Anything,
                    domain.NewUser("nob", "passwd", 13),
                ).Return(
                    errors.New("repository error"),
                )
            },
            expectedBody:  params.NewRegistUserOut(false),
            expectedError: errors.New("repository error"),
        },
    }

    for _, testcase := range tests {

        // モック初期化
        mockRepository := new(mockUserRepository)
        mockTxManager := new(mockTxManager)

        t.Run(testcase.name, func(t *testing.T) {
            // モックの期待される動作を定義
            testcase.setupRepositoryMock(mockRepository)

            // usecaseの実行
            result, err := NewUserUsecase(mockRepository, mockTxManager).RegistUser(
                context.Background(),
                testcase.requestBody,
            )

            // レスポンスの検証
            assert.Equal(t, testcase.expectedBody, result)
            assert.Equal(t, testcase.expectedError, err)
        })
    }
}