Files
user-system/internal/repository/user_role.go

176 lines
5.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// UserRoleRepository is the data-access layer for user-role associations.
// All of its methods issue their queries through the wrapped GORM handle.
type UserRoleRepository struct {
	// db is the GORM handle used for every query; set by NewUserRoleRepository.
	db *gorm.DB
}
// NewUserRoleRepository creates a UserRoleRepository backed by the given GORM handle.
func NewUserRoleRepository(db *gorm.DB) *UserRoleRepository {
	repo := &UserRoleRepository{db: db}
	return repo
}
// Create inserts a single user-role association and returns any error reported by GORM.
func (r *UserRoleRepository) Create(ctx context.Context, userRole *domain.UserRole) error {
	result := r.db.WithContext(ctx).Create(userRole)
	return result.Error
}
// Delete removes the user-role association whose primary key is id.
func (r *UserRoleRepository) Delete(ctx context.Context, id int64) error {
	session := r.db.WithContext(ctx)
	result := session.Delete(&domain.UserRole{}, id)
	return result.Error
}
// DeleteByUserID removes every role assignment belonging to the given user
// (all rows whose user_id equals userID).
func (r *UserRoleRepository) DeleteByUserID(ctx context.Context, userID int64) error {
	scoped := r.db.WithContext(ctx).Where("user_id = ?", userID)
	return scoped.Delete(&domain.UserRole{}).Error
}
// DeleteByRoleID removes every user assignment of the given role
// (all rows whose role_id equals roleID).
func (r *UserRoleRepository) DeleteByRoleID(ctx context.Context, roleID int64) error {
	scoped := r.db.WithContext(ctx).Where("role_id = ?", roleID)
	return scoped.Delete(&domain.UserRole{}).Error
}
// GetByUserID returns all user-role association rows for the given user.
// On error the returned slice is nil.
func (r *UserRoleRepository) GetByUserID(ctx context.Context, userID int64) ([]*domain.UserRole, error) {
	var assignments []*domain.UserRole
	if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&assignments).Error; err != nil {
		return nil, err
	}
	return assignments, nil
}
// GetByRoleID returns all user-role association rows for the given role.
// On error the returned slice is nil.
func (r *UserRoleRepository) GetByRoleID(ctx context.Context, roleID int64) ([]*domain.UserRole, error) {
	var assignments []*domain.UserRole
	if err := r.db.WithContext(ctx).Where("role_id = ?", roleID).Find(&assignments).Error; err != nil {
		return nil, err
	}
	return assignments, nil
}
// GetRoleIDsByUserID returns only the role_id column of the user's role
// assignments. On error the returned slice is nil.
func (r *UserRoleRepository) GetRoleIDsByUserID(ctx context.Context, userID int64) ([]int64, error) {
	var ids []int64
	query := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("user_id = ?", userID)
	if err := query.Pluck("role_id", &ids).Error; err != nil {
		return nil, err
	}
	return ids, nil
}
// GetUserRolesAndPermissions returns the user's roles and the de-duplicated set of
// permissions granted by those roles (PERF-01 optimization: merged into a single
// JOIN query).
//
// Only roles with status = 1 are considered (filter in the SQL below). Roles that
// grant no permissions are still returned because the permission tables are
// LEFT JOINed. Output order is deterministic: roles are in ascending ID order, and
// permissions are in order of first appearance when rows are sorted by role ID
// then permission ID. On error both slices are nil; otherwise both are non-nil
// (possibly empty).
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
	// One row per (role, permission) pair. For a role without permissions the
	// permission columns come back empty; a zero PermissionID is treated as
	// "no permission" below.
	var results []struct {
		RoleID         int64
		RoleName       string
		RoleCode       string
		RoleStatus     int
		PermissionID   int64
		PermissionCode string
		PermissionName string
	}

	// Fetch roles and permissions in a single round trip. ORDER BY makes the
	// row order (and therefore the output order) stable across calls.
	// NOTE(review): status value 1 presumably means "enabled"; consider using
	// the corresponding domain constant instead of the literal.
	err := r.db.WithContext(ctx).
		Raw(`
SELECT DISTINCT r.id as role_id, r.name as role_name, r.code as role_code, r.status as role_status,
p.id as permission_id, p.code as permission_code, p.name as permission_name
FROM user_roles ur
JOIN roles r ON ur.role_id = r.id
LEFT JOIN role_permissions rp ON r.id = rp.role_id
LEFT JOIN permissions p ON rp.permission_id = p.id
WHERE ur.user_id = ? AND r.status = 1
ORDER BY role_id, permission_id
`, userID).
		Scan(&results).Error
	if err != nil {
		return nil, nil, err
	}

	// Build the output slices in first-seen order. The maps serve only as
	// "seen" sets: ranging over a Go map yields a random order, which would
	// make the result non-deterministic.
	roles := make([]*domain.Role, 0)
	perms := make([]*domain.Permission, 0)
	seenRoles := make(map[int64]struct{})
	seenPerms := make(map[int64]struct{})
	for _, row := range results {
		if _, ok := seenRoles[row.RoleID]; !ok {
			seenRoles[row.RoleID] = struct{}{}
			roles = append(roles, &domain.Role{
				ID:     row.RoleID,
				Name:   row.RoleName,
				Code:   row.RoleCode,
				Status: domain.RoleStatus(row.RoleStatus),
			})
		}
		if row.PermissionID <= 0 {
			continue // role row with no joined permission
		}
		if _, ok := seenPerms[row.PermissionID]; !ok {
			seenPerms[row.PermissionID] = struct{}{}
			perms = append(perms, &domain.Permission{
				ID:   row.PermissionID,
				Code: row.PermissionCode,
				Name: row.PermissionName,
			})
		}
	}
	return roles, perms, nil
}
// GetUserIDByRoleID returns the user_id column of every assignment of the given
// role, i.e. the IDs of the users holding it. On error the returned slice is nil.
func (r *UserRoleRepository) GetUserIDByRoleID(ctx context.Context, roleID int64) ([]int64, error) {
	var ids []int64
	query := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("role_id = ?", roleID)
	if err := query.Pluck("user_id", &ids).Error; err != nil {
		return nil, err
	}
	return ids, nil
}
// Exists reports whether at least one association between userID and roleID
// exists. The counting error, if any, is returned alongside the result.
func (r *UserRoleRepository) Exists(ctx context.Context, userID, roleID int64) (bool, error) {
	query := r.db.WithContext(ctx).
		Model(&domain.UserRole{}).
		Where("user_id = ? AND role_id = ?", userID, roleID)

	var total int64
	err := query.Count(&total).Error
	found := total > 0
	return found, err
}
// BatchCreate inserts all given user-role associations with a single GORM Create
// call. An empty slice is a no-op and returns nil without touching the database.
func (r *UserRoleRepository) BatchCreate(ctx context.Context, userRoles []*domain.UserRole) error {
	if len(userRoles) > 0 {
		return r.db.WithContext(ctx).Create(&userRoles).Error
	}
	return nil
}
// BatchDelete 批量删除用户角色关联
func (r *UserRoleRepository) BatchDelete(ctx context.Context, userRoles []*domain.UserRole) error {
if len(userRoles) == 0 {
return nil
}
var ids []int64
for _, ur := range userRoles {
ids = append(ids, ur.ID)
}
return r.db.WithContext(ctx).Delete(&domain.UserRole{}, ids).Error
}