diff --git a/internal/repository/social_account_repo.go b/internal/repository/social_account_repo.go index 1cbc6c7..56ceb8d 100644 --- a/internal/repository/social_account_repo.go +++ b/internal/repository/social_account_repo.go @@ -5,8 +5,10 @@ import ( "database/sql" "fmt" - "github.com/user-management-system/internal/domain" + gormsqlite "gorm.io/driver/sqlite" "gorm.io/gorm" + + "github.com/user-management-system/internal/domain" ) // SocialAccountRepository 社交账号仓库接口 @@ -23,142 +25,78 @@ type SocialAccountRepository interface { // SocialAccountRepositoryImpl 社交账号仓库实现 type SocialAccountRepositoryImpl struct { - db *sql.DB + db *gorm.DB } -// NewSocialAccountRepository 创建社交账号仓库(支持 gorm.DB 或 *sql.DB) +// NewSocialAccountRepository 创建社交账号仓库。 +// 仓库主实现统一基于 GORM;保留 *sql.DB 构造兼容仅用于当前仓库的 SQLite 测试场景。 func NewSocialAccountRepository(db interface{}) (SocialAccountRepository, error) { - var sqlDB *sql.DB switch d := db.(type) { case *gorm.DB: - var err error - sqlDB, err = d.DB() - if err != nil { - return nil, fmt.Errorf("resolve sql db from gorm db failed: %w", err) + if d == nil { + return nil, fmt.Errorf("gorm db is nil") } + return &SocialAccountRepositoryImpl{db: d}, nil case *sql.DB: - sqlDB = d + if d == nil { + return nil, fmt.Errorf("sql db is nil") + } + gormDB, err := gorm.Open(gormsqlite.New(gormsqlite.Config{ + Conn: d, + DriverName: "sqlite", + }), &gorm.Config{}) + if err != nil { + return nil, fmt.Errorf("wrap sql db with gorm failed: %w", err) + } + return &SocialAccountRepositoryImpl{db: gormDB}, nil default: return nil, fmt.Errorf("unsupported db type: %T", db) } - if sqlDB == nil { - return nil, fmt.Errorf("sql db is nil") - } - return &SocialAccountRepositoryImpl{db: sqlDB}, nil } // Create 创建社交账号 func (r *SocialAccountRepositoryImpl) Create(ctx context.Context, account *domain.SocialAccount) error { - query := ` - INSERT INTO user_social_accounts (user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - ` - - result, err := r.db.ExecContext(ctx, query, - account.UserID, - account.Provider, - account.OpenID, - account.UnionID, - account.Nickname, - account.Avatar, - account.Gender, - account.Email, - account.Phone, - account.Extra, - account.Status, - ) - if err != nil { - return fmt.Errorf("failed to create social account: %w", err) - } - - id, err := result.LastInsertId() - if err != nil { - return err - } - - account.ID = id - return nil + return r.db.WithContext(ctx).Create(account).Error } // Update 更新社交账号 func (r *SocialAccountRepositoryImpl) Update(ctx context.Context, account *domain.SocialAccount) error { - query := ` - UPDATE user_social_accounts - SET union_id = ?, nickname = ?, avatar = ?, gender = ?, email = ?, phone = ?, extra = ?, status = ?, updated_at = CURRENT_TIMESTAMP - WHERE id = ? - ` - - _, err := r.db.ExecContext(ctx, query, - account.UnionID, - account.Nickname, - account.Avatar, - account.Gender, - account.Email, - account.Phone, - account.Extra, - account.Status, - account.ID, - ) - if err != nil { - return fmt.Errorf("failed to update social account: %w", err) + updates := map[string]interface{}{ + "union_id": account.UnionID, + "nickname": account.Nickname, + "avatar": account.Avatar, + "gender": account.Gender, + "email": account.Email, + "phone": account.Phone, + "extra": account.Extra, + "status": account.Status, } - return nil + return r.db.WithContext(ctx). + Model(&domain.SocialAccount{}). + Where("id = ?", account.ID). + Updates(updates).Error } // Delete 删除社交账号 func (r *SocialAccountRepositoryImpl) Delete(ctx context.Context, id int64) error { - query := `DELETE FROM user_social_accounts WHERE id = ?` - - _, err := r.db.ExecContext(ctx, query, id) - if err != nil { - return fmt.Errorf("failed to delete social account: %w", err) - } - - return nil + return r.db.WithContext(ctx).Delete(&domain.SocialAccount{}, id).Error } // DeleteByProviderAndUserID 删除指定用户和提供商的社交账号 func (r *SocialAccountRepositoryImpl) DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error { - query := `DELETE FROM user_social_accounts WHERE provider = ? AND user_id = ?` - - _, err := r.db.ExecContext(ctx, query, provider, userID) - if err != nil { - return fmt.Errorf("failed to delete social account: %w", err) - } - - return nil + return r.db.WithContext(ctx). + Where("provider = ? AND user_id = ?", provider, userID). + Delete(&domain.SocialAccount{}).Error } // GetByID 根据ID获取社交账号 func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error) { - query := ` - SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at - FROM user_social_accounts - WHERE id = ? - ` - var account domain.SocialAccount - err := r.db.QueryRowContext(ctx, query, id).Scan( - &account.ID, - &account.UserID, - &account.Provider, - &account.OpenID, - &account.UnionID, - &account.Nickname, - &account.Avatar, - &account.Gender, - &account.Email, - &account.Phone, - &account.Extra, - &account.Status, - &account.CreatedAt, - &account.UpdatedAt, - ) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { + if err := r.db.WithContext(ctx).First(&account, id).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return nil, nil + } return nil, fmt.Errorf("failed to get social account: %w", err) } @@ -167,45 +105,12 @@ func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*d // GetByUserID 根据用户ID获取社交账号列表 func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) { - query := ` - SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at - FROM user_social_accounts - WHERE user_id = ? - ORDER BY created_at DESC - ` - - rows, err := r.db.QueryContext(ctx, query, userID) - if err != nil { - return nil, fmt.Errorf("failed to query social accounts: %w", err) - } - defer rows.Close() - var accounts []*domain.SocialAccount - for rows.Next() { - var account domain.SocialAccount - err := rows.Scan( - &account.ID, - &account.UserID, - &account.Provider, - &account.OpenID, - &account.UnionID, - &account.Nickname, - &account.Avatar, - &account.Gender, - &account.Email, - &account.Phone, - &account.Extra, - &account.Status, - &account.CreatedAt, - &account.UpdatedAt, - ) - if err != nil { - return nil, err - } - accounts = append(accounts, &account) - } - if err := rows.Err(); err != nil { - return nil, err + if err := r.db.WithContext(ctx). + Where("user_id = ?", userID). + Order("created_at DESC"). + Find(&accounts).Error; err != nil { + return nil, fmt.Errorf("failed to query social accounts: %w", err) } return accounts, nil @@ -213,33 +118,13 @@ func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID in // GetByProviderAndOpenID 根据提供商和OpenID获取社交账号 func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error) { - query := ` - SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at - FROM user_social_accounts - WHERE provider = ? AND open_id = ? - ` - var account domain.SocialAccount - err := r.db.QueryRowContext(ctx, query, provider, openID).Scan( - &account.ID, - &account.UserID, - &account.Provider, - &account.OpenID, - &account.UnionID, - &account.Nickname, - &account.Avatar, - &account.Gender, - &account.Email, - &account.Phone, - &account.Extra, - &account.Status, - &account.CreatedAt, - &account.UpdatedAt, - ) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { + if err := r.db.WithContext(ctx). + Where("provider = ? AND open_id = ?", provider, openID). + First(&account).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return nil, nil + } return nil, fmt.Errorf("failed to get social account: %w", err) } @@ -248,54 +133,16 @@ func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context // List 分页获取社交账号列表 func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error) { - // 获取总数 + var accounts []*domain.SocialAccount var total int64 - countQuery := `SELECT COUNT(*) FROM user_social_accounts` - if err := r.db.QueryRowContext(ctx, countQuery).Scan(&total); err != nil { + query := r.db.WithContext(ctx).Model(&domain.SocialAccount{}) + + if err := query.Count(&total).Error; err != nil { return nil, 0, fmt.Errorf("failed to count social accounts: %w", err) } - - // 获取列表 - query := ` - SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at - FROM user_social_accounts - ORDER BY created_at DESC - LIMIT ? OFFSET ? - ` - - rows, err := r.db.QueryContext(ctx, query, limit, offset) - if err != nil { + if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&accounts).Error; err != nil { return nil, 0, fmt.Errorf("failed to query social accounts: %w", err) } - defer rows.Close() - - var accounts []*domain.SocialAccount - for rows.Next() { - var account domain.SocialAccount - err := rows.Scan( - &account.ID, - &account.UserID, - &account.Provider, - &account.OpenID, - &account.UnionID, - &account.Nickname, - &account.Avatar, - &account.Gender, - &account.Email, - &account.Phone, - &account.Extra, - &account.Status, - &account.CreatedAt, - &account.UpdatedAt, - ) - if err != nil { - return nil, 0, err - } - accounts = append(accounts, &account) - } - if err := rows.Err(); err != nil { - return nil, 0, err - } return accounts, total, nil } diff --git a/internal/repository/social_account_repository_test.go b/internal/repository/social_account_repository_test.go index b8bdd58..fd32251 100644 --- a/internal/repository/social_account_repository_test.go +++ b/internal/repository/social_account_repository_test.go @@ -182,6 +182,54 @@ func TestSocialAccountRepository_Update(t *testing.T) { } } +func TestSocialAccountRepository_Update_DoesNotRewriteBindingIdentityFields(t *testing.T) { + db := setupSocialAccountTestDB(t) + repo, err := NewSocialAccountRepository(db) + if err != nil { + t.Fatalf("NewSocialAccountRepository() error = %v", err) + } + ctx := context.Background() + + account := &domain.SocialAccount{ + UserID: 1, + Provider: "github", + OpenID: "openid-identity", + Nickname: "before-update", + Status: domain.SocialAccountStatusActive, + } + if err := repo.Create(ctx, account); err != nil { + t.Fatalf("Create() error = %v", err) + } + + account.UserID = 999 + account.Provider = "wechat" + account.OpenID = "rewritten-openid" + account.Nickname = "after-update" + if err := repo.Update(ctx, account); err != nil { + t.Fatalf("Update() error = %v", err) + } + + found, err := repo.GetByID(ctx, account.ID) + if err != nil { + t.Fatalf("GetByID() error = %v", err) + } + if found == nil { + t.Fatal("expected social account after update") + } + if found.UserID != 1 { + t.Fatalf("UserID = %d, want 1", found.UserID) + } + if found.Provider != "github" { + t.Fatalf("Provider = %q, want github", found.Provider) + } + if found.OpenID != "openid-identity" { + t.Fatalf("OpenID = %q, want openid-identity", found.OpenID) + } + if found.Nickname != "after-update" { + t.Fatalf("Nickname = %q, want after-update", found.Nickname) + } +} + func TestSocialAccountRepository_Delete(t *testing.T) { db := setupSocialAccountTestDB(t) repo, err := NewSocialAccountRepository(db)