Files

214 lines
5.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// RoleRepository is the data-access layer for roles.
type RoleRepository struct {
	// db is the GORM handle through which every query in this type is issued.
	db *gorm.DB
}
// NewRoleRepository builds a RoleRepository that issues its queries through db.
func NewRoleRepository(db *gorm.DB) *RoleRepository {
	repo := new(RoleRepository)
	repo.db = db
	return repo
}
// Create inserts role and makes sure a requested disabled status is stored.
//
// GORM leaves zero-valued fields out of the INSERT when the column has a
// database default, so a role created with the disabled status (status=0)
// would otherwise be stored with the default. The insert and the follow-up
// status write run in one transaction; any error from either step is returned
// as-is.
func (r *RoleRepository) Create(ctx context.Context, role *domain.Role) error {
	// Capture the caller's status before the insert runs.
	wantStatus := role.Status

	persist := func(tx *gorm.DB) error {
		if err := tx.Create(role).Error; err != nil {
			return err
		}
		if wantStatus != domain.RoleStatusDisabled {
			return nil
		}

		// Write the disabled status explicitly onto the freshly inserted row.
		fix := tx.Model(&domain.Role{}).Where("id = ?", role.ID).Update("status", wantStatus)
		if fix.Error != nil {
			return fix.Error
		}
		// Keep the in-memory role in sync with what was written.
		role.Status = wantStatus
		return nil
	}

	return r.db.WithContext(ctx).Transaction(persist)
}
// Update writes role back through GORM's Save and returns the resulting error,
// if any.
func (r *RoleRepository) Update(ctx context.Context, role *domain.Role) error {
	saved := r.db.WithContext(ctx).Save(role)
	return saved.Error
}
// Delete removes the role identified by the primary key id and returns the
// error reported by GORM, if any.
func (r *RoleRepository) Delete(ctx context.Context, id int64) error {
	tx := r.db.WithContext(ctx)
	return tx.Delete(&domain.Role{}, id).Error
}
// GetByID loads the role whose primary key is id.
//
// On failure it returns a nil role together with the unwrapped GORM error.
func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) {
	found := new(domain.Role)
	if err := r.db.WithContext(ctx).First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetByCode loads the first role whose code column equals code.
//
// On failure it returns a nil role together with the unwrapped GORM error.
func (r *RoleRepository) GetByCode(ctx context.Context, code string) (*domain.Role, error) {
	found := new(domain.Role)
	lookup := r.db.WithContext(ctx).Where("code = ?", code)
	if err := lookup.First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// List returns one page of roles plus the total number of roles.
//
// offset and limit are passed straight to the query. The page is ordered by
// id ascending: without an explicit ORDER BY the database may return rows in
// any order, so consecutive pages could overlap or skip rows.
//
// On any query error it returns (nil, 0, err).
func (r *RoleRepository) List(ctx context.Context, offset, limit int) ([]*domain.Role, int64, error) {
	base := r.db.WithContext(ctx).Model(&domain.Role{})

	// Total across all pages.
	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Requested page, in a stable order.
	var roles []*domain.Role
	page := base.Order("id ASC").Offset(offset).Limit(limit)
	if err := page.Find(&roles).Error; err != nil {
		return nil, 0, err
	}
	return roles, total, nil
}
// ListByStatus returns one page of the roles whose status column equals
// status, plus the total number of such roles.
//
// offset and limit are passed straight to the query. The page is ordered by
// id ascending: without an explicit ORDER BY the database may return rows in
// any order, so consecutive pages could overlap or skip rows.
//
// On any query error it returns (nil, 0, err).
func (r *RoleRepository) ListByStatus(ctx context.Context, status domain.RoleStatus, offset, limit int) ([]*domain.Role, int64, error) {
	base := r.db.WithContext(ctx).Model(&domain.Role{}).Where("status = ?", status)

	// Total of matching roles across all pages.
	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Requested page, in a stable order.
	var roles []*domain.Role
	page := base.Order("id ASC").Offset(offset).Limit(limit)
	if err := page.Find(&roles).Error; err != nil {
		return nil, 0, err
	}
	return roles, total, nil
}
// GetDefaultRoles returns every role whose is_default column is true.
//
// On a query error it returns a nil slice together with that error.
func (r *RoleRepository) GetDefaultRoles(ctx context.Context) ([]*domain.Role, error) {
	var defaults []*domain.Role
	filter := r.db.WithContext(ctx).Where("is_default = ?", true)
	if err := filter.Find(&defaults).Error; err != nil {
		return nil, err
	}
	return defaults, nil
}
// ExistsByCode reports whether at least one role has the given code.
//
// The boolean comes from a COUNT over matching rows; the error from that
// query is returned alongside it.
func (r *RoleRepository) ExistsByCode(ctx context.Context, code string) (bool, error) {
	var matches int64
	lookup := r.db.WithContext(ctx).Model(&domain.Role{}).Where("code = ?", code)
	err := lookup.Count(&matches).Error
	exists := matches > 0
	return exists, err
}
// UpdateStatus sets the status column of the role with the given id and
// returns the error reported by GORM, if any.
func (r *RoleRepository) UpdateStatus(ctx context.Context, id int64, status domain.RoleStatus) error {
	target := r.db.WithContext(ctx).Model(&domain.Role{}).Where("id = ?", id)
	return target.Update("status", status).Error
}
// Search returns one page of the roles whose name, code or description
// contains keyword, plus the total number of matches.
//
// keyword is wrapped in '%' and bound as a LIKE parameter. It is not escaped,
// so '%' or '_' inside keyword act as SQL wildcards.
//
// offset and limit are passed straight to the query. The page is ordered by
// id ascending: without an explicit ORDER BY the database may return rows in
// any order, so consecutive pages could overlap or skip rows.
//
// On any query error it returns (nil, 0, err).
func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Role, int64, error) {
	// Build the pattern once; all three columns use the same one.
	pattern := "%" + keyword + "%"
	base := r.db.WithContext(ctx).Model(&domain.Role{}).
		Where("name LIKE ? OR code LIKE ? OR description LIKE ?", pattern, pattern, pattern)

	// Total of matching roles across all pages.
	var total int64
	if err := base.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Requested page, in a stable order.
	var roles []*domain.Role
	page := base.Order("id ASC").Offset(offset).Limit(limit)
	if err := page.Find(&roles).Error; err != nil {
		return nil, 0, err
	}
	return roles, total, nil
}
// ListByParentID returns the roles whose parent_id column equals parentID.
//
// On a query error it returns a nil slice together with that error.
func (r *RoleRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Role, error) {
	var children []*domain.Role
	filter := r.db.WithContext(ctx).Where("parent_id = ?", parentID)
	if err := filter.Find(&children).Error; err != nil {
		return nil, err
	}
	return children, nil
}
// GetByIDs loads the roles whose id is in ids with a single IN query.
//
// An empty ids list returns an empty, non-nil slice without querying. The
// query sets no ordering, so results need not follow the order of ids. On a
// query error it returns a nil slice together with that error.
func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Role, error) {
	if len(ids) == 0 {
		return []*domain.Role{}, nil
	}

	var matched []*domain.Role
	lookup := r.db.WithContext(ctx).Where("id IN ?", ids)
	if err := lookup.Find(&matched).Error; err != nil {
		return nil, err
	}
	return matched, nil
}
// GetAncestorIDs returns the IDs of all ancestors of roleID, used for
// permission inheritance.
//
// It follows parent_id upwards one role at a time, so the result runs from
// the direct parent to the most distant ancestor. The walk stops when:
//   - a role has no parent,
//   - a role on the chain is not found (a dangling parent ID has already
//     been appended and stays in the result), or
//   - the next parent ID is already on the chain, i.e. the data has a cycle.
//
// The cycle check keeps corrupt data such as A -> B -> A from looping
// forever; the IDs collected up to that point are returned, each once. Any
// query error other than "record not found" returns (nil, err).
func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
	var ancestorIDs []int64

	// visited holds every ID already on the chain, including the start role,
	// so a parent pointing back into the chain is detected.
	visited := map[int64]struct{}{roleID: {}}

	currentID := roleID
	for {
		var role domain.Role
		err := r.db.WithContext(ctx).Select("id", "parent_id").First(&role, currentID).Error
		if errors.Is(err, gorm.ErrRecordNotFound) {
			break
		}
		if err != nil {
			return nil, err
		}
		if role.ParentID == nil {
			break
		}

		parentID := *role.ParentID
		if _, seen := visited[parentID]; seen {
			// Cycle in the hierarchy: stop instead of looping forever.
			break
		}
		visited[parentID] = struct{}{}
		ancestorIDs = append(ancestorIDs, parentID)
		currentID = parentID
	}
	return ancestorIDs, nil
}
// GetAncestors loads the full role records for every ancestor of roleID.
//
// It resolves the ancestor IDs with GetAncestorIDs and fetches them in one
// batch with GetByIDs. A role with no ancestors yields an empty, non-nil
// slice. The result comes back in whatever order GetByIDs returns; this
// method does not re-sort it to follow the inheritance chain.
func (r *RoleRepository) GetAncestors(ctx context.Context, roleID int64) ([]*domain.Role, error) {
	ids, err := r.GetAncestorIDs(ctx, roleID)
	switch {
	case err != nil:
		return nil, err
	case len(ids) == 0:
		return []*domain.Role{}, nil
	}
	return r.GetByIDs(ctx, ids)
}