213 lines
5.8 KiB
Go
213 lines
5.8 KiB
Go
|
|
package database
|
||
|
|
|
||
|
|
import (
	"fmt"
	"log"
	"os"
	"path/filepath"
	"strings"

	"github.com/glebarez/sqlite"
	"gorm.io/gorm"

	"github.com/user-management-system/internal/auth"
	"github.com/user-management-system/internal/config"
	"github.com/user-management-system/internal/domain"
)
|
||
|
|
|
||
|
|
type DB struct {
|
||
|
|
*gorm.DB
|
||
|
|
}
|
||
|
|
|
||
|
|
func NewDB(cfg *config.Config) (*DB, error) {
|
||
|
|
// 当前仅支持 SQLite
|
||
|
|
// 如果配置中指定了数据库路径则使用它,否则使用默认路径
|
||
|
|
dbPath := "./data/user_management.db"
|
||
|
|
if cfg != nil && cfg.Database.DBName != "" {
|
||
|
|
dbPath = cfg.Database.DBName
|
||
|
|
}
|
||
|
|
dialector := sqlite.Open(dbPath)
|
||
|
|
|
||
|
|
db, err := gorm.Open(dialector, &gorm.Config{})
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("connect database failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return &DB{DB: db}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (db *DB) AutoMigrate(cfg *config.Config) error {
|
||
|
|
log.Println("starting database migration")
|
||
|
|
if err := db.DB.AutoMigrate(
|
||
|
|
&domain.User{},
|
||
|
|
&domain.Role{},
|
||
|
|
&domain.Permission{},
|
||
|
|
&domain.UserRole{},
|
||
|
|
&domain.RolePermission{},
|
||
|
|
&domain.Device{},
|
||
|
|
&domain.LoginLog{},
|
||
|
|
&domain.OperationLog{},
|
||
|
|
&domain.SocialAccount{},
|
||
|
|
&domain.Webhook{},
|
||
|
|
&domain.WebhookDelivery{},
|
||
|
|
&domain.PasswordHistory{},
|
||
|
|
); err != nil {
|
||
|
|
return fmt.Errorf("database migration failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := db.initDefaultData(cfg); err != nil {
|
||
|
|
return fmt.Errorf("initialize default data failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (db *DB) initDefaultData(cfg *config.Config) error {
|
||
|
|
var count int64
|
||
|
|
if err := db.DB.Model(&domain.Role{}).Count(&count).Error; err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if count > 0 {
|
||
|
|
// 角色已存在,仍需补充权限数据(升级场景)
|
||
|
|
if err := db.ensurePermissions(); err != nil {
|
||
|
|
log.Printf("warn: ensure permissions failed: %v", err)
|
||
|
|
}
|
||
|
|
log.Println("default data already exists, skipping bootstrap")
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
log.Println("bootstrapping default roles and permissions")
|
||
|
|
|
||
|
|
// 1. 创建角色
|
||
|
|
var adminRoleID int64
|
||
|
|
var userRoleID int64
|
||
|
|
for _, predefined := range domain.PredefinedRoles {
|
||
|
|
role := predefined
|
||
|
|
if err := db.DB.Create(&role).Error; err != nil {
|
||
|
|
return fmt.Errorf("create role failed: %w", err)
|
||
|
|
}
|
||
|
|
if role.Code == "admin" {
|
||
|
|
adminRoleID = role.ID
|
||
|
|
}
|
||
|
|
if role.Code == "user" {
|
||
|
|
userRoleID = role.ID
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 2. 创建权限
|
||
|
|
permIDs, err := db.createDefaultPermissions()
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("create permissions failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
// 3. 给 admin 角色绑定所有权限
|
||
|
|
if adminRoleID > 0 {
|
||
|
|
for _, permID := range permIDs {
|
||
|
|
db.DB.Create(&domain.RolePermission{RoleID: adminRoleID, PermissionID: permID})
|
||
|
|
}
|
||
|
|
log.Printf("assigned %d permissions to admin role", len(permIDs))
|
||
|
|
}
|
||
|
|
|
||
|
|
// 4. 给普通用户角色绑定基础权限
|
||
|
|
if userRoleID > 0 {
|
||
|
|
userPermCodes := []string{"profile:view", "profile:edit", "log:view_own"}
|
||
|
|
for _, code := range userPermCodes {
|
||
|
|
var perm domain.Permission
|
||
|
|
if err := db.DB.Where("code = ?", code).First(&perm).Error; err == nil {
|
||
|
|
db.DB.Create(&domain.RolePermission{RoleID: userRoleID, PermissionID: perm.ID})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// 5. 创建 admin 用户
|
||
|
|
adminUsername := cfg.Default.AdminEmail
|
||
|
|
adminPassword := cfg.Default.AdminPassword
|
||
|
|
if adminUsername == "" || adminPassword == "" {
|
||
|
|
log.Println("admin bootstrap skipped: default.admin_email/admin_password not configured")
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
passwordHash, err := auth.HashPassword(adminPassword)
|
||
|
|
if err != nil {
|
||
|
|
return fmt.Errorf("hash admin password failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
adminUser := &domain.User{
|
||
|
|
Username: adminUsername,
|
||
|
|
Email: domain.StrPtr(adminUsername),
|
||
|
|
Password: passwordHash,
|
||
|
|
Nickname: "系统管理员",
|
||
|
|
Status: domain.UserStatusActive,
|
||
|
|
}
|
||
|
|
if err := db.DB.Create(adminUser).Error; err != nil {
|
||
|
|
return fmt.Errorf("create admin user failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if adminRoleID == 0 {
|
||
|
|
return fmt.Errorf("admin role missing during bootstrap")
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := db.DB.Create(&domain.UserRole{
|
||
|
|
UserID: adminUser.ID,
|
||
|
|
RoleID: adminRoleID,
|
||
|
|
}).Error; err != nil {
|
||
|
|
return fmt.Errorf("assign admin role failed: %w", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
log.Printf("bootstrap completed: admin user=%s, roles=%d, permissions=%d",
|
||
|
|
adminUser.Username, 2, len(permIDs))
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ensurePermissions 在升级场景中补充缺失的权限数据
|
||
|
|
func (db *DB) ensurePermissions() error {
|
||
|
|
var permCount int64
|
||
|
|
db.DB.Model(&domain.Permission{}).Count(&permCount)
|
||
|
|
if permCount > 0 {
|
||
|
|
return nil // 已有权限数据
|
||
|
|
}
|
||
|
|
|
||
|
|
log.Println("permissions table is empty, seeding default permissions")
|
||
|
|
permIDs, err := db.createDefaultPermissions()
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// 找到 admin 角色并绑定所有权限
|
||
|
|
var adminRole domain.Role
|
||
|
|
if err := db.DB.Where("code = ?", "admin").First(&adminRole).Error; err == nil {
|
||
|
|
for _, permID := range permIDs {
|
||
|
|
db.DB.Create(&domain.RolePermission{RoleID: adminRole.ID, PermissionID: permID})
|
||
|
|
}
|
||
|
|
log.Printf("assigned %d permissions to admin role (upgrade)", len(permIDs))
|
||
|
|
}
|
||
|
|
|
||
|
|
// 找到普通用户角色并绑定基础权限
|
||
|
|
var userRole domain.Role
|
||
|
|
if err := db.DB.Where("code = ?", "user").First(&userRole).Error; err == nil {
|
||
|
|
userPermCodes := []string{"profile:view", "profile:edit", "log:view_own"}
|
||
|
|
for _, code := range userPermCodes {
|
||
|
|
var perm domain.Permission
|
||
|
|
if err := db.DB.Where("code = ?", code).First(&perm).Error; err == nil {
|
||
|
|
db.DB.Create(&domain.RolePermission{RoleID: userRole.ID, PermissionID: perm.ID})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// createDefaultPermissions 创建默认权限列表,返回所有权限 ID
|
||
|
|
func (db *DB) createDefaultPermissions() ([]int64, error) {
|
||
|
|
permissions := domain.DefaultPermissions()
|
||
|
|
var ids []int64
|
||
|
|
for i := range permissions {
|
||
|
|
p := permissions[i]
|
||
|
|
// 使用 FirstOrCreate 防止重复插入(幂等)
|
||
|
|
result := db.DB.Where("code = ?", p.Code).FirstOrCreate(&p)
|
||
|
|
if result.Error != nil {
|
||
|
|
log.Printf("warn: create permission %s failed: %v", p.Code, result.Error)
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
ids = append(ids, p.ID)
|
||
|
|
}
|
||
|
|
return ids, nil
|
||
|
|
}
|