Files

257 lines
7.8 KiB
Go

package repository
import (
"context"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// DeviceRepository is the data-access layer for device records. Every method
// runs its queries through the wrapped *gorm.DB handle.
type DeviceRepository struct {
	db *gorm.DB
}

// NewDeviceRepository builds a DeviceRepository backed by db. The handle is
// stored as-is; no validation or connection check happens here.
func NewDeviceRepository(db *gorm.DB) *DeviceRepository {
	repo := new(DeviceRepository)
	repo.db = db
	return repo
}
// Create inserts device inside a single transaction.
//
// GORM omits zero values on insert for fields that carry a DB default, so a
// device created with Status == DeviceStatusInactive (status 0) would otherwise
// be stored with the column default. To persist the caller's requested status,
// the insert is followed by an explicit UPDATE of the status column, keyed by
// the ID that tx.Create filled into device. device.Status is then reset to the
// requested value. Any error rolls back the whole transaction.
func (r *DeviceRepository) Create(ctx context.Context, device *domain.Device) error {
	wantStatus := device.Status

	insertThenFixStatus := func(tx *gorm.DB) error {
		if err := tx.Create(device).Error; err != nil {
			return err
		}
		// Only the inactive (zero) status is at risk of being replaced by the DB default.
		if wantStatus != domain.DeviceStatusInactive {
			return nil
		}
		fix := tx.Model(&domain.Device{}).Where("id = ?", device.ID).Update("status", wantStatus)
		if fix.Error != nil {
			return fix.Error
		}
		device.Status = wantStatus
		return nil
	}

	return r.db.WithContext(ctx).Transaction(insertThenFixStatus)
}
// Update persists device with GORM's Save, which writes all of its fields.
// NOTE(review): per GORM docs, Save inserts instead when the primary key is
// zero, so callers should pass a device that already has an ID.
func (r *DeviceRepository) Update(ctx context.Context, device *domain.Device) error {
	result := r.db.WithContext(ctx).Save(device)
	return result.Error
}
// Delete removes the device whose primary key is id. Whether this is a hard or
// soft delete depends on the domain.Device model definition, which is not
// visible here.
func (r *DeviceRepository) Delete(ctx context.Context, id int64) error {
	result := r.db.WithContext(ctx).Delete(&domain.Device{}, id)
	return result.Error
}
// GetByID loads the device whose primary key is id. The error from GORM's First
// is returned unchanged (GORM reports a missing row as gorm.ErrRecordNotFound),
// and the returned pointer is nil whenever err is non-nil.
func (r *DeviceRepository) GetByID(ctx context.Context, id int64) (*domain.Device, error) {
	found := new(domain.Device)
	if err := r.db.WithContext(ctx).First(found, id).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// GetByDeviceID loads the device owned by userID whose device_id column equals
// deviceID. deviceID is the string device identifier, not the primary key.
// GORM's First error is passed through unchanged, and the returned pointer is
// nil whenever err is non-nil.
func (r *DeviceRepository) GetByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
	found := new(domain.Device)
	err := r.db.WithContext(ctx).
		Where("user_id = ? AND device_id = ?", userID, deviceID).
		First(found).
		Error
	if err != nil {
		return nil, err
	}
	return found, nil
}
// List returns one page of all devices together with the total number of
// devices, ignoring offset and limit.
//
// offset and limit are passed straight to GORM's Offset/Limit. Rows are ordered
// by primary key. SQL gives no ordering guarantee for LIMIT/OFFSET without an
// ORDER BY, so consecutive pages could otherwise overlap or skip rows.
func (r *DeviceRepository) List(ctx context.Context, offset, limit int) ([]*domain.Device, int64, error) {
	var devices []*domain.Device
	var total int64
	query := r.db.WithContext(ctx).Model(&domain.Device{})

	// Total row count, before pagination.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	// Requested page, in a deterministic order.
	if err := query.Order("id ASC").Offset(offset).Limit(limit).Find(&devices).Error; err != nil {
		return nil, 0, err
	}
	return devices, total, nil
}
// ListByUserID returns one page of userID's devices, most recently active
// first, together with the total number of devices owned by userID.
//
// id DESC is a secondary sort key. Without it, devices that share the same
// last_active_time have no defined relative order, and LIMIT/OFFSET pages can
// repeat or skip them. offset and limit are passed straight to GORM.
func (r *DeviceRepository) ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.Device, int64, error) {
	var devices []*domain.Device
	var total int64
	query := r.db.WithContext(ctx).Model(&domain.Device{}).Where("user_id = ?", userID)

	// Total count for this user, before pagination.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	// Requested page, in a deterministic order.
	if err := query.Order("last_active_time DESC, id DESC").Offset(offset).Limit(limit).Find(&devices).Error; err != nil {
		return nil, 0, err
	}
	return devices, total, nil
}
// ListByStatus returns one page of devices whose status equals status,
// together with the total number of such devices.
//
// Rows are ordered by primary key so that LIMIT/OFFSET paging is
// deterministic. Without an ORDER BY the database may return rows in any
// order. offset and limit are passed straight to GORM.
func (r *DeviceRepository) ListByStatus(ctx context.Context, status domain.DeviceStatus, offset, limit int) ([]*domain.Device, int64, error) {
	var devices []*domain.Device
	var total int64
	query := r.db.WithContext(ctx).Model(&domain.Device{}).Where("status = ?", status)

	// Total count for this status, before pagination.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	// Requested page, in a deterministic order.
	if err := query.Order("id ASC").Offset(offset).Limit(limit).Find(&devices).Error; err != nil {
		return nil, 0, err
	}
	return devices, total, nil
}
// UpdateStatus sets the status column of the device whose primary key is id.
// RowsAffected is not inspected, so an unknown id is not reported as an error.
func (r *DeviceRepository) UpdateStatus(ctx context.Context, id int64, status domain.DeviceStatus) error {
	target := r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", id)
	return target.Update("status", status).Error
}
// UpdateLastActiveTime writes the current time (time.Now) into the
// last_active_time column of the device whose primary key is id. RowsAffected
// is not inspected, so an unknown id is not reported as an error.
func (r *DeviceRepository) UpdateLastActiveTime(ctx context.Context, id int64) error {
	stamp := time.Now()
	return r.db.WithContext(ctx).
		Model(&domain.Device{}).
		Where("id = ?", id).
		Update("last_active_time", stamp).
		Error
}
// Exists reports whether userID has at least one device whose device_id equals
// deviceID. When err is non-nil, callers should ignore the boolean.
func (r *DeviceRepository) Exists(ctx context.Context, userID int64, deviceID string) (bool, error) {
	var matches int64
	result := r.db.WithContext(ctx).
		Model(&domain.Device{}).
		Where("user_id = ? AND device_id = ?", userID, deviceID).
		Count(&matches)
	if result.Error != nil {
		return matches > 0, result.Error
	}
	return matches > 0, nil
}
// DeleteByUserID removes every device belonging to userID. RowsAffected is not
// inspected, so a user with no devices is not an error.
func (r *DeviceRepository) DeleteByUserID(ctx context.Context, userID int64) error {
	owned := r.db.WithContext(ctx).Where("user_id = ?", userID)
	return owned.Delete(&domain.Device{}).Error
}
// GetActiveDevices returns userID's devices whose last_active_time is strictly
// later than 30 days before now, ordered by last_active_time descending.
// The result is not paginated.
func (r *DeviceRepository) GetActiveDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
	const activityWindow = 30 * 24 * time.Hour
	cutoff := time.Now().Add(-activityWindow)

	var active []*domain.Device
	result := r.db.WithContext(ctx).
		Where("user_id = ? AND last_active_time > ?", userID, cutoff).
		Order("last_active_time DESC").
		Find(&active)
	if result.Error != nil {
		return nil, result.Error
	}
	return active, nil
}
// TrustDevice marks the device whose primary key is deviceID as trusted and
// stores expiresAt in trust_expires_at. A nil expiresAt presumably stores NULL,
// which GetTrustedDevices treats as "never expires" — confirm against the
// driver.
func (r *DeviceRepository) TrustDevice(ctx context.Context, deviceID int64, expiresAt *time.Time) error {
	return r.db.WithContext(ctx).
		Model(&domain.Device{}).
		Where("id = ?", deviceID).
		Updates(map[string]interface{}{
			"trust_expires_at": expiresAt,
			"is_trusted":       true,
		}).
		Error
}
// UntrustDevice clears the trusted flag on the device whose primary key is
// deviceID and sets its trust_expires_at column to nil (NULL).
func (r *DeviceRepository) UntrustDevice(ctx context.Context, deviceID int64) error {
	return r.db.WithContext(ctx).
		Model(&domain.Device{}).
		Where("id = ?", deviceID).
		Updates(map[string]interface{}{
			"trust_expires_at": nil,
			"is_trusted":       false,
		}).
		Error
}
// DeleteAllByUserIDExcept removes every device belonging to userID except the
// one whose primary key is exceptDeviceID.
func (r *DeviceRepository) DeleteAllByUserIDExcept(ctx context.Context, userID int64, exceptDeviceID int64) error {
	others := r.db.WithContext(ctx).Where("user_id = ? AND id != ?", userID, exceptDeviceID)
	return others.Delete(&domain.Device{}).Error
}
// GetTrustedDevices returns userID's devices that have is_trusted set and whose
// trust has not lapsed: trust_expires_at is either NULL or later than now.
// Results are ordered by last_active_time descending and are not paginated.
func (r *DeviceRepository) GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
	const cond = "user_id = ? AND is_trusted = ? AND (trust_expires_at IS NULL OR trust_expires_at > ?)"
	cutoff := time.Now()

	var trusted []*domain.Device
	result := r.db.WithContext(ctx).
		Where(cond, userID, true, cutoff).
		Order("last_active_time DESC").
		Find(&trusted)
	if result.Error != nil {
		return nil, result.Error
	}
	return trusted, nil
}
// ListDevicesParams holds the filters and pagination settings for
// DeviceRepository.ListAll.
type ListDevicesParams struct {
	// UserID restricts results to one user's devices when > 0; otherwise the
	// filter is skipped.
	UserID int64
	// Status restricts results to this status whenever Status >= 0; values
	// below 0 skip the filter.
	// NOTE(review): the zero value presumably equals DeviceStatusInactive
	// (status 0), so a zero-valued params filters to inactive devices only —
	// confirm callers always set Status explicitly.
	Status domain.DeviceStatus
	// IsTrusted restricts results by the is_trusted flag when non-nil.
	IsTrusted *bool
	// Keyword, when non-empty, is substring-matched with SQL LIKE against
	// device_name, ip and location.
	Keyword string
	// Offset and Limit are passed straight to GORM's Offset/Limit.
	Offset int
	Limit  int
}

// ListAll returns one page of devices matching params, together with the total
// number of matching devices, ignoring Offset/Limit.
//
// params must be non-nil. Results are ordered by last_active_time descending,
// with id descending as a tie-breaker. Without the tie-breaker, rows that share
// a last_active_time have no defined order and could repeat or be skipped
// across pages.
func (r *DeviceRepository) ListAll(ctx context.Context, params *ListDevicesParams) ([]*domain.Device, int64, error) {
	var devices []*domain.Device
	var total int64
	query := r.db.WithContext(ctx).Model(&domain.Device{})

	if params.UserID > 0 {
		query = query.Where("user_id = ?", params.UserID)
	}
	if params.Status >= 0 {
		query = query.Where("status = ?", params.Status)
	}
	if params.IsTrusted != nil {
		query = query.Where("is_trusted = ?", *params.IsTrusted)
	}
	if params.Keyword != "" {
		// The OR group is parenthesized explicitly so it is ANDed with the other
		// filters as a unit, rather than relying on GORM's implicit wrapping.
		// NOTE(review): '%' and '_' inside Keyword are not escaped and act as
		// LIKE wildcards.
		search := "%" + params.Keyword + "%"
		query = query.Where("(device_name LIKE ? OR ip LIKE ? OR location LIKE ?)", search, search, search)
	}

	// Total count of matching rows, before pagination.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	// Requested page, in a deterministic order.
	if err := query.Order("last_active_time DESC, id DESC").
		Offset(params.Offset).Limit(params.Limit).
		Find(&devices).Error; err != nil {
		return nil, 0, err
	}
	return devices, total, nil
}