Files
user-system/internal/database/db.go
long-agent 582ad7a069 test: add comprehensive test coverage and improve code quality
- Add new test files for auth, service, and handler modules
- Improve test organization and coverage
- Refactor code for better maintainability
- Add captcha, settings, stats, and theme handler tests
- Add auth module tests (CAS, OAuth, password, SSO, state)
- Add service layer tests for auth, export, permissions, roles
- All Go tests pass (exit code 0)
- All frontend tests pass (325 tests in 59 files)
2026-04-17 20:43:50 +08:00

250 lines
7.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package database
import (
"fmt"
"log"
"time"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/config"
"github.com/user-management-system/internal/domain"
)
// DB is the application's database handle. It embeds *gorm.DB, so every GORM
// method is available directly on a DB value, and it carries the migration and
// seeding helpers defined in this package.
type DB struct {
*gorm.DB
}
// NewDB opens the SQLite database selected by cfg and configures its
// connection pool. Only SQLite is supported at present.
//
// The database path is cfg.Database.DBName when cfg is non-nil and the name is
// non-empty; otherwise "./data/user_management.db" is used.
//
// Connection-scoped settings (foreign_keys, busy_timeout, synchronous,
// cache_size) are passed as "_pragma" DSN parameters so the driver applies
// them to every connection it opens. Issuing them once through sqlDB.Exec
// would configure only the single pooled connection that happened to run the
// statement, leaving the remaining connections with SQLite's defaults
// (foreign keys off, no busy timeout).
//
// It returns an error when the database cannot be opened or the underlying
// *sql.DB cannot be obtained. Failure to switch to WAL is logged, not returned.
func NewDB(cfg *config.Config) (*DB, error) {
	dbPath := "./data/user_management.db"
	if cfg != nil && cfg.Database.DBName != "" {
		dbPath = cfg.Database.DBName
	}

	dsn := withSQLitePragmas(dbPath,
		"foreign_keys(1)",     // SQLite leaves foreign-key enforcement off by default
		"busy_timeout(5000)",  // wait up to 5 s on a locked database instead of failing with SQLITE_BUSY
		"synchronous(NORMAL)", // sufficient durability under WAL and much faster than FULL
		"cache_size(-8192)",   // a negative value is a size in KiB, i.e. 8 MiB
	)

	db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("connect database failed: %w", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("get underlying sql.DB failed: %w", err)
	}

	// WAL (write-ahead logging) lets readers proceed while a write is in
	// progress. The journal mode is recorded in the database file, so one
	// best-effort statement is enough; a failure is only a warning.
	walEnabled := true
	if _, err := sqlDB.Exec("PRAGMA journal_mode=WAL"); err != nil {
		walEnabled = false
		log.Printf("warn: enable WAL mode failed: %v", err)
	}

	// SQLite has a single writer, but the number of open connections still
	// needs an upper bound.
	sqlDB.SetMaxOpenConns(10)
	sqlDB.SetMaxIdleConns(5)
	sqlDB.SetConnMaxLifetime(30 * time.Minute)
	sqlDB.SetConnMaxIdleTime(10 * time.Minute)

	if walEnabled {
		log.Println("database: SQLite WAL mode enabled, connection pool configured")
	} else {
		log.Println("database: connection pool configured (WAL mode not enabled)")
	}
	return &DB{DB: db}, nil
}

// withSQLitePragmas returns dsn with each pragma appended as a "_pragma" query
// parameter, starting a query string when dsn does not already contain one.
func withSQLitePragmas(dsn string, pragmas ...string) string {
	sep := "?"
	for _, r := range dsn {
		if r == '?' {
			sep = "&"
			break
		}
	}
	for _, p := range pragmas {
		dsn += sep + "_pragma=" + p
		sep = "&"
	}
	return dsn
}
// AutoMigrate runs GORM's AutoMigrate for every domain model listed below and
// then seeds the default data through initDefaultData.
//
// cfg is forwarded unchanged to initDefaultData. An error from either step is
// returned wrapped with the step that failed.
func (db *DB) AutoMigrate(cfg *config.Config) error {
	log.Println("starting database migration")

	// The order of the models is kept exactly as it was originally declared.
	models := []interface{}{
		&domain.User{},
		&domain.Role{},
		&domain.Permission{},
		&domain.UserRole{},
		&domain.RolePermission{},
		&domain.Device{},
		&domain.LoginLog{},
		&domain.OperationLog{},
		&domain.SocialAccount{},
		&domain.Webhook{},
		&domain.WebhookDelivery{},
		&domain.PasswordHistory{},
	}

	err := db.DB.AutoMigrate(models...)
	if err != nil {
		return fmt.Errorf("database migration failed: %w", err)
	}

	err = db.initDefaultData(cfg)
	if err != nil {
		return fmt.Errorf("initialize default data failed: %w", err)
	}

	return nil
}
// initDefaultData seeds the predefined roles, the default permissions, the
// role/permission bindings and, when configured, the initial admin account.
//
// When at least one role already exists the bootstrap is skipped; only
// ensurePermissions is run (its failure is logged, not returned) so that
// upgraded installs get their permission data backfilled.
//
// The bootstrap itself runs inside one transaction. Without it, a failure
// after the roles were inserted would leave them behind, and every later
// start would see "roles exist" and skip the bootstrap for good.
//
// cfg may be nil; that is treated the same as an unconfigured admin account.
func (db *DB) initDefaultData(cfg *config.Config) error {
	var count int64
	if err := db.DB.Model(&domain.Role{}).Count(&count).Error; err != nil {
		return err
	}
	if count > 0 {
		// Roles already exist; permissions may still be missing after an upgrade.
		if err := db.ensurePermissions(); err != nil {
			log.Printf("warn: ensure permissions failed: %v", err)
		}
		log.Println("default data already exists, skipping bootstrap")
		return nil
	}

	log.Println("bootstrapping default roles and permissions")
	return db.DB.Transaction(func(tx *gorm.DB) error {
		return (&DB{DB: tx}).bootstrapDefaultData(cfg)
	})
}

// bootstrapDefaultData performs the actual seeding for initDefaultData. It is
// expected to be called on a DB whose handle is a transaction, so that
// returning an error rolls back everything written here.
func (db *DB) bootstrapDefaultData(cfg *config.Config) error {
	// 1. Create the predefined roles and remember the IDs needed below.
	var adminRoleID, userRoleID int64
	for _, predefined := range domain.PredefinedRoles {
		role := predefined
		if err := db.DB.Create(&role).Error; err != nil {
			return fmt.Errorf("create role failed: %w", err)
		}
		switch role.Code {
		case "admin":
			adminRoleID = role.ID
		case "user":
			userRoleID = role.ID
		}
	}

	// 2. Create the default permissions.
	permIDs, err := db.createDefaultPermissions()
	if err != nil {
		return fmt.Errorf("create permissions failed: %w", err)
	}

	// 3. Bind every permission to the admin role.
	if adminRoleID > 0 {
		for _, permID := range permIDs {
			binding := domain.RolePermission{RoleID: adminRoleID, PermissionID: permID}
			if err := db.DB.Create(&binding).Error; err != nil {
				return fmt.Errorf("assign permission %d to admin role failed: %w", permID, err)
			}
		}
		log.Printf("assigned %d permissions to admin role", len(permIDs))
	}

	// 4. Bind the basic permissions to the regular user role.
	if userRoleID > 0 {
		userPermCodes := []string{"profile:view", "profile:edit", "log:view_own"}
		for _, code := range userPermCodes {
			var perm domain.Permission
			if err := db.DB.Where("code = ?", code).First(&perm).Error; err != nil {
				// A code that cannot be looked up is skipped, as before, but no longer silently.
				log.Printf("warn: lookup permission %s for user role failed: %v", code, err)
				continue
			}
			binding := domain.RolePermission{RoleID: userRoleID, PermissionID: perm.ID}
			if err := db.DB.Create(&binding).Error; err != nil {
				return fmt.Errorf("assign permission %s to user role failed: %w", code, err)
			}
		}
	}

	// 5. Create the admin user. The configured admin e-mail doubles as the username.
	if cfg == nil || cfg.Default.AdminEmail == "" || cfg.Default.AdminPassword == "" {
		log.Println("admin bootstrap skipped: default.admin_email/admin_password not configured")
		return nil
	}
	if adminRoleID == 0 {
		return fmt.Errorf("admin role missing during bootstrap")
	}
	adminUsername := cfg.Default.AdminEmail

	passwordHash, err := auth.HashPassword(cfg.Default.AdminPassword)
	if err != nil {
		return fmt.Errorf("hash admin password failed: %w", err)
	}

	adminUser := &domain.User{
		Username: adminUsername,
		Email:    domain.StrPtr(adminUsername),
		Password: passwordHash,
		Nickname: "系统管理员",
		Status:   domain.UserStatusActive,
	}
	if err := db.DB.Create(adminUser).Error; err != nil {
		return fmt.Errorf("create admin user failed: %w", err)
	}

	if err := db.DB.Create(&domain.UserRole{
		UserID: adminUser.ID,
		RoleID: adminRoleID,
	}).Error; err != nil {
		return fmt.Errorf("assign admin role failed: %w", err)
	}

	log.Printf("bootstrap completed: admin user=%s, roles=%d, permissions=%d",
		adminUser.Username, len(domain.PredefinedRoles), len(permIDs))
	return nil
}
// ensurePermissions backfills permission data for installs that already have
// roles but an empty permissions table (the upgrade scenario).
//
// When the permissions table already holds rows nothing is done. Otherwise the
// default permissions are created, all of them are bound to the "admin" role
// and the basic profile/log permissions are bound to the "user" role.
//
// A failed count or a failure reported by createDefaultPermissions is
// returned. Individual lookup and binding failures are logged and skipped so
// that one bad row does not block the rest of the backfill.
func (db *DB) ensurePermissions() error {
	var permCount int64
	// A failed count must not be mistaken for an empty table.
	if err := db.DB.Model(&domain.Permission{}).Count(&permCount).Error; err != nil {
		return fmt.Errorf("count permissions failed: %w", err)
	}
	if permCount > 0 {
		return nil // permission data already present
	}

	log.Println("permissions table is empty, seeding default permissions")
	permIDs, err := db.createDefaultPermissions()
	if err != nil {
		return err
	}

	// Bind every permission to the admin role, if that role exists.
	var adminRole domain.Role
	if err := db.DB.Where("code = ?", "admin").First(&adminRole).Error; err != nil {
		log.Printf("warn: lookup admin role failed: %v", err)
	} else {
		assigned := 0
		for _, permID := range permIDs {
			binding := domain.RolePermission{RoleID: adminRole.ID, PermissionID: permID}
			if err := db.DB.Create(&binding).Error; err != nil {
				log.Printf("warn: assign permission %d to admin role failed: %v", permID, err)
				continue
			}
			assigned++
		}
		log.Printf("assigned %d permissions to admin role (upgrade)", assigned)
	}

	// Bind the basic permissions to the regular user role, if that role exists.
	var userRole domain.Role
	if err := db.DB.Where("code = ?", "user").First(&userRole).Error; err != nil {
		log.Printf("warn: lookup user role failed: %v", err)
		return nil
	}
	userPermCodes := []string{"profile:view", "profile:edit", "log:view_own"}
	for _, code := range userPermCodes {
		var perm domain.Permission
		if err := db.DB.Where("code = ?", code).First(&perm).Error; err != nil {
			log.Printf("warn: lookup permission %s for user role failed: %v", code, err)
			continue
		}
		binding := domain.RolePermission{RoleID: userRole.ID, PermissionID: perm.ID}
		if err := db.DB.Create(&binding).Error; err != nil {
			log.Printf("warn: assign permission %s to user role failed: %v", code, err)
		}
	}
	return nil
}
// createDefaultPermissions makes sure each entry of domain.DefaultPermissions
// has a row, looked up by its code, and returns the IDs of the entries that
// were found or created.
//
// An entry whose lookup/insert fails is logged and left out of the result; the
// returned error is always nil.
func (db *DB) createDefaultPermissions() ([]int64, error) {
	var ids []int64
	for _, def := range domain.DefaultPermissions() {
		perm := def
		// FirstOrCreate keyed on the code avoids inserting a duplicate when the
		// permission already exists.
		if err := db.DB.Where("code = ?", perm.Code).FirstOrCreate(&perm).Error; err != nil {
			log.Printf("warn: create permission %s failed: %v", perm.Code, err)
			continue
		}
		ids = append(ids, perm.ID)
	}
	return ids, nil
}