Skip to content

標準ライブラリでトランザクション制御

cf. https://go.dev/doc/database/execute-transactions

実装

  • infrastructure/tx_manager.goを新規作成します:
package infrastructure

import (
    "context"
    "database/sql"
)

// トランザクション管理のインターフェースです。
type TxManager interface {
    WithTransaction(ctx context.Context, f func(ctx context.Context) error) error
}

type txManager struct {
    db *sql.DB
}

func NewTxManager(db *sql.DB) TxManager {
    return &txManager{db: db}
}

type txKey struct{}

// トランザクションを開始しつつ業務処理を実行します。
// エラー発生時はロールバックが実行されます。
func (t *txManager) WithTransaction(ctx context.Context, f func(ctx context.Context) error) error {

    if existingTx := GetTx(ctx); existingTx != nil {
        return f(ctx)
    }

    tx, err := t.db.BeginTx(ctx, nil)
    if err != nil {
        return err
    }

    ctx = context.WithValue(ctx, txKey{}, tx)

    if err := f(ctx); err != nil {
        tx.Rollback()
        return err
    }

    return tx.Commit()
}

// コンテキストからトランザクションを取得します。
// トランザクションがない場合はnilを返します。
func GetTx(ctx context.Context) *sql.Tx {

    tx, ok := ctx.Value(txKey{}).(*sql.Tx)
    if !ok {
        return nil
    }
    return tx
}

// コンテキストにトランザクションをセットします。
func SetTx(ctx context.Context, tx *sql.Tx) context.Context {
    return context.WithValue(ctx, txKey{}, tx)
}

// SQL発行処理のインターフェースです。
type Exec interface {
    ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
  • persistence配下のsqlについて、トランザクションを取得する処理を追記します:
package persistence

import (
    "context"
    "database/sql"
    "easyapp/internal/infrastructure"
    "easyapp/internal/infrastructure/persistence/table"
)

type UsersSql interface {

    // ユーザ情報を保存します。
    Save(ctx context.Context, u table.Users) error
}

type usersSql struct {
    db *sql.DB
}

func NewUsersSql(db *sql.DB) UsersSql {
    return &usersSql{db: db}
}

func (s *usersSql) Save(ctx context.Context, u table.Users) error {

    // トランザクションを取得
    // トランザクションがなければ*sql.DBを利用
    var exec infrastructure.Exec = s.db
    if tx := infrastructure.GetTx(ctx); tx != nil {
        exec = tx
    }

    _, err := exec.ExecContext(
        ctx,
        "INSERT INTO users (name, password, age) VALUES (?, ?, ?)",
        u.Name,
        u.Password,
        u.Age,
    )

    return err
}
  • usecaseでトランザクションを管理します。txManagerを構造体に含め、WithTransactionで業務処理をラップします:
package usecase

import (
    "context"
    "easyapp/internal/domain"
    "easyapp/internal/infrastructure"
    "easyapp/internal/usecase/params"
)

type UserUsecase interface {

    // ユーザ情報を保存します。
    RegistUser(ctx context.Context, in params.RegistUserIn) (params.RegistUserOut, error)
}

type userUsecase struct {
    userRepository domain.UserRepository
    txManager      infrastructure.TxManager // トランザクション制御
}

func NewUserUsecase(
    userRepository domain.UserRepository,
    txManager infrastructure.TxManager,
) UserUsecase {
    return &userUsecase{
        userRepository: userRepository,
        txManager:      txManager,
    }
}

func (u *userUsecase) RegistUser(
    ctx context.Context,
    in params.RegistUserIn,
) (
    params.RegistUserOut,
    error,
) {

    // トランザクションを開始しつつ業務処理を実行
    if err := u.txManager.WithTransaction(ctx, func(ctx context.Context) error {
        return u.userRepository.Save(ctx, domain.NewUser(in.Name(), in.Password(), in.Age()))
    }); err != nil {
        return params.NewRegistUserOut(false), err
    }

    return params.NewRegistUserOut(true), nil
}

テスト

  • users_sql_test.go
package persistence

import (
    "context"
    "database/sql"
    "easyapp/internal/infrastructure"
    "easyapp/internal/infrastructure/persistence/table"
    "easyapp/internal/infrastructure/persistence/test"
    "testing"

    _ "github.com/mattn/go-sqlite3"
    "github.com/stretchr/testify/assert"
)

func Test_UsersSql_Save(t *testing.T) {

    tests := []struct {
        name          string           // テストケース名
        ctx           context.Context  // コンテキスト
        withTx        bool             // トランザクション有無
        users         table.Users      // 入力値
        setup         func(db *sql.DB) // 事前セットアップ関数
        expectedError error            // 期待されるエラー
    }{
        {
            name:          "success without tx",
            ctx:           context.Background(),
            withTx:        false,
            users:         table.Users{Name: "nob", Password: "passwd", Age: 13},
            setup:         func(db *sql.DB) {},
            expectedError: nil,
        },
        {
            name:          "success with tx",
            ctx:           context.Background(),
            withTx:        true,
            users:         table.Users{Name: "nob", Password: "passwd", Age: 13},
            setup:         func(db *sql.DB) {},
            expectedError: nil,
        },
    }

    for _, testcase := range tests {

        t.Run(testcase.name, func(t *testing.T) {

            // テストデータベースに接続
            db := test.ConnectTestDB(t, "users")

            // 事前セットアップ
            testcase.setup(db)

            // トランザクション開始
            if testcase.withTx {
                tx, err := db.Begin()
                assert.NoError(t, err)
                testcase.ctx = infrastructure.SetTx(context.Background(), tx)
            }

            // sqlの実行
            result := NewUsersSql(db).Save(testcase.ctx, testcase.users)

            // レスポンスの確認
            assert.Equal(t, testcase.expectedError, result)
        })
    }
}
  • user_usecase_test.go
package usecase

import (
    "context"
    "easyapp/internal/domain"
    "easyapp/internal/usecase/params"
    "errors"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/mock"
)

// repositoryモックの定義
type mockUserRepository struct {
    mock.Mock
}

func (m *mockUserRepository) Save(ctx context.Context, u domain.User) error {
    args := m.Called(ctx, u)
    return args.Error(0)
}

// txManagerモックの定義
type mockTxManager struct{}

// 関数fの結果を返すようにモック化
func (m *mockTxManager) WithTransaction(
    ctx context.Context,
    f func(ctx context.Context) error,
) error {
    return f(ctx)
}

// UserUsecase_RegistUserのテスト
func Test_UserUsecase_RegistUser(t *testing.T) {

    tests := []struct {
        name                string                      // テストケース名
        requestBody         params.RegistUserIn         // リクエストボディ
        setupRepositoryMock func(m *mockUserRepository) // repositoryモック設定
        expectedBody        params.RegistUserOut        // 期待されるレスポンスボディ
        expectedError       error                       // 期待されるエラー
    }{
        {
            name:        "success",
            requestBody: params.NewRegistUserIn("nob", "passwd", 13),
            setupRepositoryMock: func(m *mockUserRepository) {
                m.On(
                    "Save",
                    mock.Anything,
                    domain.NewUser("nob", "passwd", 13),
                ).Return(
                    nil,
                )
            },
            expectedBody:  params.NewRegistUserOut(true),
            expectedError: nil,
        },
        {
            name:        "repository error",
            requestBody: params.NewRegistUserIn("nob", "passwd", 13),
            setupRepositoryMock: func(m *mockUserRepository) {
                m.On(
                    "Save",
                    mock.Anything,
                    domain.NewUser("nob", "passwd", 13),
                ).Return(
                    errors.New("repository error"),
                )
            },
            expectedBody:  params.NewRegistUserOut(false),
            expectedError: errors.New("repository error"),
        },
    }

    for _, testcase := range tests {

        // モック初期化
        mockRepository := new(mockUserRepository)
        mockTxManager := new(mockTxManager)

        t.Run(testcase.name, func(t *testing.T) {
            // モックの期待される動作を定義
            testcase.setupRepositoryMock(mockRepository)

            // usecaseの実行
            result, err := NewUserUsecase(mockRepository, mockTxManager).RegistUser(
                context.Background(),
                testcase.requestBody,
            )

            // レスポンスの検証
            assert.Equal(t, testcase.expectedBody, result)
            assert.Equal(t, testcase.expectedError, err)
        })
    }
}