Skip to content

標準ライブラリでトランザクション制御

cf. https://go.dev/doc/database/execute-transactions

実装

  • infrastructure/tx_manager.goを新規作成します:
go
package infrastructure

import (
	"context"
	"database/sql"
)

// トランザクション管理のインターフェースです。
type TxManager interface {
	WithTransaction(ctx context.Context, f func(ctx context.Context) error) error
}

type txManager struct {
	db *sql.DB
}

func NewTxManager(db *sql.DB) TxManager {
	return &txManager{db: db}
}

type txKey struct{}

// トランザクションを開始しつつ業務処理を実行します。
// エラー発生時はロールバックが実行されます。
func (t *txManager) WithTransaction(ctx context.Context, f func(ctx context.Context) error) error {

	if existingTx := GetTx(ctx); existingTx != nil {
		return f(ctx)
	}

	tx, err := t.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}

	ctx = context.WithValue(ctx, txKey{}, tx)

	if err := f(ctx); err != nil {
		tx.Rollback()
		return err
	}

	return tx.Commit()
}

// コンテキストからトランザクションを取得します。
// トランザクションがない場合はnilを返します。
func GetTx(ctx context.Context) *sql.Tx {

	tx, ok := ctx.Value(txKey{}).(*sql.Tx)
	if !ok {
		return nil
	}
	return tx
}

// コンテキストにトランザクションをセットします。
func SetTx(ctx context.Context, tx *sql.Tx) context.Context {
	return context.WithValue(ctx, txKey{}, tx)
}

// SQL発行処理のインターフェースです。
type Exec interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
  • repository配下のsqlについて、トランザクションを取得する処理を追記します:
go
package repository

import (
	"context"
	"database/sql"
	"easyapp/internal/domain"
	"easyapp/internal/infrastructure"
)

type userRepository struct {
	db *sql.DB
}

func NewUserRepository(db *sql.DB) domain.UserRepository {
	return &userRepository{db: db}
}

func (r *userRepository) Save(ctx context.Context, user domain.User) error {

	// トランザクションを取得、トランザクションがなければ*sql.DBを利用
	var exec infrastructure.Exec = r.db
	if tx := infrastructure.GetTx(ctx); tx != nil {
		exec = tx
	}

	_, err := exec.ExecContext(
		ctx,
		"INSERT INTO users (name, password, age) VALUES (?, ?, ?)",
		user.Name(),
		user.Password(),
		user.Age(),
	)

	return err
}
  • usecaseでトランザクションを管理します。txManagerを構造体に含め、WithTransactionで業務処理をラップします:
go
package usecase

import (
	"context"
	"easyapp/internal/domain"
	"easyapp/internal/infrastructure"
	"easyapp/internal/usecase/params"
)

// 認証のusecaseインターフェースです。
type UserUsecase interface {

	// ユーザ情報を登録します。
	Regist(ctx context.Context, in params.RegistIn) (params.RegistOut, error)
}

type userUsecase struct {
	userRepository domain.UserRepository
	txManager      infrastructure.TxManager // トランザクション制御
}

func NewUserUsecase(
	userRepository domain.UserRepository,
	txManager infrastructure.TxManager,
) UserUsecase {
	return &userUsecase{
		userRepository: userRepository,
		txManager:      txManager,
	}
}

func (u *userUsecase) Regist(ctx context.Context, in params.RegistIn) (params.RegistOut, error) {

	// トランザクションを開始しつつ業務処理を実行
	if err := u.txManager.WithTransaction(ctx, func(ctx context.Context) error {
		return u.userRepository.Save(ctx, domain.NewUser(in.Name(), in.Password(), in.Age()))
	}); err != nil {
		return params.NewRegistOut(false), err
	}

	return params.NewRegistOut(true), nil
}

テスト

  • user_repository_test.go
go
package repository

import (
	"context"
	"database/sql"
	"easyapp/internal/domain"
	"easyapp/internal/infrastructure"
	"easyapp/internal/infrastructure/repository/testdata"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestSave(t *testing.T) {

	tests := []struct {
		name          string           // テストケース名
		ctx           context.Context  // コンテキスト
		withTx        bool             // トランザクション有無
		users         domain.User      // 入力値
		setup         func(db *sql.DB) // 事前セットアップ関数
		expectedError error            // 期待されるエラー
	}{
		{
			name:          "success without tx",
			ctx:           context.Background(),
			withTx:        false,
			users:         domain.NewUser("nob", "passwd", 13),
			setup:         func(db *sql.DB) {},
			expectedError: nil,
		},
		{
			name:          "success with tx",
			ctx:           context.Background(),
			withTx:        true,
			users:         domain.NewUser("nob", "passwd", 13),
			setup:         func(db *sql.DB) {},
			expectedError: nil,
		},
	}

	for _, testcase := range tests {

		t.Run(testcase.name, func(t *testing.T) {
			// テストデータベースに接続
			db := testdata.ConnectTestDB(t, "users")

			// 事前セットアップ
			testcase.setup(db)

			// トランザクション開始
			if testcase.withTx {
				tx, err := db.Begin()
				assert.NoError(t, err)
				testcase.ctx = infrastructure.SetTx(context.Background(), tx)
			}

			// sqlの実行
			result := NewUserRepository(db).Save(testcase.ctx, testcase.users)

			// レスポンスの確認
			assert.Equal(t, testcase.expectedError, result)
		})
	}
}
  • user_usecase_test.go
go
package usecase

import (
	"context"
	"easyapp/internal/domain"
	"easyapp/internal/usecase/params"
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
)

// repositoryモックの定義
type mockUserRepository struct {
	mock.Mock
}

func (m *mockUserRepository) Save(ctx context.Context, u domain.User) error {
	args := m.Called(ctx, u)
	return args.Error(0)
}

// txManagerモックの定義
type mockTxManager struct{}

// 関数fの結果を返すようにモック化
func (m *mockTxManager) WithTransaction(
	ctx context.Context,
	f func(ctx context.Context) error,
) error {
	return f(ctx)
}

// RegistUserのテスト
func TestRegistUser(t *testing.T) {

	tests := []struct {
		name                string                      // テストケース名
		requestBody         params.RegistIn             // リクエストボディ
		setupRepositoryMock func(m *mockUserRepository) // repositoryモック設定
		expectedBody        params.RegistOut            // 期待されるレスポンスボディ
		expectedError       error                       // 期待されるエラー
	}{
		{
			name:        "success",
			requestBody: params.NewRegistIn("nob", "passwd", 13),
			setupRepositoryMock: func(m *mockUserRepository) {
				m.On(
					"Save",
					mock.Anything,
					domain.NewUser("nob", "passwd", 13),
				).Return(
					nil,
				)
			},
			expectedBody:  params.NewRegistOut(true),
			expectedError: nil,
		},
		{
			name:        "repository error",
			requestBody: params.NewRegistIn("nob", "passwd", 13),
			setupRepositoryMock: func(m *mockUserRepository) {
				m.On(
					"Save",
					mock.Anything,
					domain.NewUser("nob", "passwd", 13),
				).Return(
					errors.New("repository error"),
				)
			},
			expectedBody:  params.NewRegistOut(false),
			expectedError: errors.New("repository error"),
		},
	}

	for _, testcase := range tests {

		// モック初期化
		mockRepository := new(mockUserRepository)
		mockTxManager := new(mockTxManager)

		t.Run(testcase.name, func(t *testing.T) {
			// モックの期待される動作を定義
			testcase.setupRepositoryMock(mockRepository)

			// usecaseの実行
			result, err := NewUserUsecase(mockRepository, mockTxManager).Regist(
				context.Background(),
				testcase.requestBody,
			)

			// レスポンスの検証
			assert.Equal(t, testcase.expectedBody, result)
			assert.Equal(t, testcase.expectedError, err)
		})
	}
}