fix: wrap AssignRoles in transaction and eliminate N+1 queries

- AssignRoles: wrap DeleteByUserID + BatchCreate in DB transaction (P1)
- GetUserRoles: use GetByIDs batch query instead of per-role GetByID loop (N+1 fix)
- ListAdmins: use GetByIDs batch query instead of per-user GetByID loop (N+1 fix)
- Add WithTx/DB methods to UserRoleRepository for transaction support
- Add GetByIDs to UserRepository (batch user lookup)
- Add .gitattributes to normalize line endings to LF (P2)
This commit is contained in:
2026-04-11 10:32:33 +08:00
parent 8c1cf54213
commit c2096ff008
4 changed files with 77 additions and 24 deletions

32
.gitattributes vendored Normal file
View File

@@ -0,0 +1,32 @@
# Normalize line endings to LF for all text files
* text=auto eol=lf
# Enforce LF for source files
*.go text eol=lf
*.ts text eol=lf
*.tsx text eol=lf
*.js text eol=lf
*.jsx text eol=lf
*.css text eol=lf
*.scss text eol=lf
*.html text eol=lf
*.htm text eol=lf
*.json text eol=lf
*.yaml text eol=lf
*.yml text eol=lf
*.md text eol=lf
*.sh text eol=lf
*.ps1 text eol=lf
*.mjs text eol=lf
*.cjs text eol=lf
# Binary files
*.png binary
*.jpg binary
*.jpeg binary
*.gif binary
*.ico binary
*.pdf binary
*.zip binary
*.gz binary
*.tar binary

View File

@@ -31,6 +31,11 @@ func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
}
// DB returns the underlying GORM DB for transaction support
func (r *UserRepository) DB() *gorm.DB {
return r.db
}
// Create 创建用户
func (r *UserRepository) Create(ctx context.Context, user *domain.User) error {
return r.db.WithContext(ctx).Create(user).Error
@@ -56,6 +61,19 @@ func (r *UserRepository) GetByID(ctx context.Context, id int64) (*domain.User, e
return &user, nil
}
// GetByIDs 批量获取用户(消除 N+1 查询)
func (r *UserRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.User, error) {
if len(ids) == 0 {
return []*domain.User{}, nil
}
var users []*domain.User
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error
if err != nil {
return nil, err
}
return users, nil
}
// GetByUsername 根据用户名获取用户
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
var user domain.User

View File

@@ -18,6 +18,16 @@ func NewUserRoleRepository(db *gorm.DB) *UserRoleRepository {
return &UserRoleRepository{db: db}
}
// DB returns the underlying GORM DB for transaction support
func (r *UserRoleRepository) DB() *gorm.DB {
return r.db
}
// WithTx returns a new repository instance that uses the given transaction
func (r *UserRoleRepository) WithTx(tx *gorm.DB) *UserRoleRepository {
return &UserRoleRepository{db: tx}
}
// Create 创建用户角色关联
func (r *UserRoleRepository) Create(ctx context.Context, userRole *domain.UserRole) error {
return r.db.WithContext(ctx).Create(userRole).Error

View File

@@ -235,14 +235,10 @@ func (s *UserService) GetUserRoles(ctx context.Context, userID int64) ([]*domain
roleIDs[i] = ur.RoleID
}
// 批量获取角色详情
var roles []*domain.Role
for _, roleID := range roleIDs {
role, err := s.roleRepo.GetByID(ctx, roleID)
// 批量获取角色详情(消除 N+1 查询)
roles, err := s.roleRepo.GetByIDs(ctx, roleIDs)
if err != nil {
continue // 跳过不存在的角色
}
roles = append(roles, role)
return nil, fmt.Errorf("failed to fetch roles: %w", err)
}
return roles, nil
@@ -255,19 +251,14 @@ func (s *UserService) AssignRoles(ctx context.Context, userID int64, roleIDs []i
return err
}
// 验证所有角色存在
// 验证所有角色存在(预先验证,避免在事务内做不必要的查询)
for _, roleID := range roleIDs {
if _, err := s.roleRepo.GetByID(ctx, roleID); err != nil {
return fmt.Errorf("角色 %d 不存在", roleID)
}
}
// 删除用户现有角色
if err := s.userRoleRepo.DeleteByUserID(ctx, userID); err != nil {
return err
}
// 创建新的用户角色关联
// 构建新的用户角色关联
var userRoles []*domain.UserRole
for _, roleID := range roleIDs {
userRoles = append(userRoles, &domain.UserRole{
@@ -276,7 +267,13 @@ func (s *UserService) AssignRoles(ctx context.Context, userID int64, roleIDs []i
})
}
return s.userRoleRepo.BatchCreate(ctx, userRoles)
// 使用事务包装删旧建新操作,确保原子性
return s.userRoleRepo.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := s.userRoleRepo.WithTx(tx).DeleteByUserID(ctx, userID); err != nil {
return err
}
return s.userRoleRepo.WithTx(tx).BatchCreate(ctx, userRoles)
})
}
// getAdminRoleID looks up the admin role ID by code to avoid hardcoded magic numbers.
@@ -304,14 +301,10 @@ func (s *UserService) ListAdmins(ctx context.Context) ([]*domain.User, error) {
return []*domain.User{}, nil
}
// 获取所有管理员用户
var admins []*domain.User
for _, adminID := range adminUserIDs {
user, err := s.userRepo.GetByID(ctx, adminID)
// 批量获取所有管理员用户(消除 N+1 查询)
admins, err := s.userRepo.GetByIDs(ctx, adminUserIDs)
if err != nil {
continue // 跳过不存在的用户
}
admins = append(admins, user)
return nil, fmt.Errorf("failed to fetch admin users: %w", err)
}
return admins, nil