Files
user-system/internal/repository/social_account_repo.go

296 lines
7.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"database/sql"
"fmt"
"github.com/user-management-system/internal/domain"
"gorm.io/gorm"
)
// SocialAccountRepository defines the persistence operations for accounts
// from external providers, each identified by a provider name plus an open ID
// and linked to a local user ID.
type SocialAccountRepository interface {
	// Create persists a new account.
	Create(ctx context.Context, account *domain.SocialAccount) error
	// Update saves changes to an existing account, located by its ID.
	Update(ctx context.Context, account *domain.SocialAccount) error
	// Delete removes the account with the given ID.
	Delete(ctx context.Context, id int64) error
	// DeleteByProviderAndUserID removes the given user's account(s) for one provider.
	DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error

	// GetByID fetches one account by ID. The SQL implementation returns
	// (nil, nil) when no row matches.
	GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error)
	// GetByUserID lists every account bound to the given user.
	GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error)
	// GetByProviderAndOpenID finds the account a provider identifies by openID.
	// The SQL implementation returns (nil, nil) when no row matches.
	GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error)
	// List returns one page of accounts (offset/limit) plus the total count.
	List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error)
}
// SocialAccountRepositoryImpl is the database/sql-backed implementation of
// SocialAccountRepository. It issues hand-written SQL against the
// user_social_accounts table.
type SocialAccountRepositoryImpl struct {
	// db is the handle every statement runs on. NewSocialAccountRepository
	// guarantees it is non-nil; a zero-value struct is not usable.
	db *sql.DB
}
// NewSocialAccountRepository builds a SocialAccountRepository backed by
// database/sql. db may be either a *gorm.DB (whose underlying *sql.DB is
// extracted) or a *sql.DB (used as-is); any other type is rejected.
//
// A nil handle of either supported type is reported as an error rather than
// causing a panic.
func NewSocialAccountRepository(db interface{}) (SocialAccountRepository, error) {
	var sqlDB *sql.DB
	switch d := db.(type) {
	case *gorm.DB:
		// (*gorm.DB).DB reads fields of its receiver, so a typed-nil *gorm.DB
		// would panic inside gorm; reject it before calling.
		if d == nil {
			return nil, fmt.Errorf("gorm db is nil")
		}
		resolved, err := d.DB()
		if err != nil {
			return nil, fmt.Errorf("resolve sql db from gorm db failed: %w", err)
		}
		sqlDB = resolved
	case *sql.DB:
		sqlDB = d
	default:
		// An untyped nil interface also lands here (%T prints "<nil>").
		return nil, fmt.Errorf("unsupported db type: %T", db)
	}
	// Catches a typed-nil *sql.DB, and a nil result from gorm just in case.
	if sqlDB == nil {
		return nil, fmt.Errorf("sql db is nil")
	}
	return &SocialAccountRepositoryImpl{db: sqlDB}, nil
}
// Create inserts account into user_social_accounts. On success it stores the
// database-generated primary key in account.ID.
//
// created_at/updated_at are not written by this statement; presumably the
// schema supplies defaults for them — NOTE(review): confirm against the DDL.
// A nil account is rejected with an error instead of panicking.
func (r *SocialAccountRepositoryImpl) Create(ctx context.Context, account *domain.SocialAccount) error {
	if account == nil {
		return fmt.Errorf("failed to create social account: account is nil")
	}
	query := `
INSERT INTO user_social_accounts (user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	result, err := r.db.ExecContext(ctx, query,
		account.UserID,
		account.Provider,
		account.OpenID,
		account.UnionID,
		account.Nickname,
		account.Avatar,
		account.Gender,
		account.Email,
		account.Phone,
		account.Extra,
		account.Status,
	)
	if err != nil {
		return fmt.Errorf("failed to create social account: %w", err)
	}
	// LastInsertId support is driver-dependent; wrap its error so callers
	// can tell the insert ran but the new ID could not be read.
	id, err := result.LastInsertId()
	if err != nil {
		return fmt.Errorf("failed to read created social account id: %w", err)
	}
	account.ID = id
	return nil
}
// Update overwrites the mutable profile columns (union_id, nickname, avatar,
// gender, email, phone, extra, status) of the row whose id is account.ID and
// sets updated_at to CURRENT_TIMESTAMP. user_id, provider and open_id are
// never changed. The affected-row count is not inspected, so an unknown ID
// does not produce an error.
func (r *SocialAccountRepositoryImpl) Update(ctx context.Context, account *domain.SocialAccount) error {
	// Argument order must match the placeholders: SET columns first, then id.
	args := []interface{}{
		account.UnionID,
		account.Nickname,
		account.Avatar,
		account.Gender,
		account.Email,
		account.Phone,
		account.Extra,
		account.Status,
		account.ID,
	}
	if _, err := r.db.ExecContext(ctx, `
UPDATE user_social_accounts
SET union_id = ?, nickname = ?, avatar = ?, gender = ?, email = ?, phone = ?, extra = ?, status = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = ?
`, args...); err != nil {
		return fmt.Errorf("failed to update social account: %w", err)
	}
	return nil
}
// Delete removes the social account whose primary key is id. The result of
// the statement is ignored, so deleting a non-existent id is not an error.
func (r *SocialAccountRepositoryImpl) Delete(ctx context.Context, id int64) error {
	const stmt = `DELETE FROM user_social_accounts WHERE id = ?`
	if _, err := r.db.ExecContext(ctx, stmt, id); err != nil {
		return fmt.Errorf("failed to delete social account: %w", err)
	}
	return nil
}
// DeleteByProviderAndUserID removes every social account row that belongs to
// userID for the given provider. Matching zero rows is not an error.
func (r *SocialAccountRepositoryImpl) DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error {
	const stmt = `DELETE FROM user_social_accounts WHERE provider = ? AND user_id = ?`
	if _, err := r.db.ExecContext(ctx, stmt, provider, userID); err != nil {
		return fmt.Errorf("failed to delete social account: %w", err)
	}
	return nil
}
// GetByID loads the social account with primary key id.
//
// It returns (nil, nil) when no row matches, so callers must nil-check the
// account even when err is nil. Any other failure is wrapped.
func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error) {
	const selectByID = `
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
FROM user_social_accounts
WHERE id = ?
`
	account := new(domain.SocialAccount)
	row := r.db.QueryRowContext(ctx, selectByID, id)
	switch err := row.Scan(
		&account.ID,
		&account.UserID,
		&account.Provider,
		&account.OpenID,
		&account.UnionID,
		&account.Nickname,
		&account.Avatar,
		&account.Gender,
		&account.Email,
		&account.Phone,
		&account.Extra,
		&account.Status,
		&account.CreatedAt,
		&account.UpdatedAt,
	); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("failed to get social account: %w", err)
	}
	return account, nil
}
// GetByUserID returns every social account bound to userID, newest first.
// Ties on created_at are broken by id so the order is deterministic.
//
// A user with no accounts yields a nil slice and a nil error. Query, scan and
// iteration failures are all wrapped and reported; a partial list is never
// returned as success.
func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
	query := `
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
FROM user_social_accounts
WHERE user_id = ?
ORDER BY created_at DESC, id DESC
`
	rows, err := r.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to query social accounts: %w", err)
	}
	defer rows.Close()

	var accounts []*domain.SocialAccount
	for rows.Next() {
		var account domain.SocialAccount
		if err := rows.Scan(
			&account.ID,
			&account.UserID,
			&account.Provider,
			&account.OpenID,
			&account.UnionID,
			&account.Nickname,
			&account.Avatar,
			&account.Gender,
			&account.Email,
			&account.Phone,
			&account.Extra,
			&account.Status,
			&account.CreatedAt,
			&account.UpdatedAt,
		); err != nil {
			return nil, fmt.Errorf("failed to scan social account: %w", err)
		}
		accounts = append(accounts, &account)
	}
	// rows.Next returns false both at the end of the result set and on error;
	// without this check a mid-stream failure would look like a short list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate social accounts: %w", err)
	}
	return accounts, nil
}
// GetByProviderAndOpenID looks up the social account that the given provider
// identifies by openID.
//
// It returns (nil, nil) when no row matches; other failures are wrapped.
func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error) {
	const q = `
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
FROM user_social_accounts
WHERE provider = ? AND open_id = ?
`
	var acct domain.SocialAccount
	// Destinations in the same order as the SELECT column list.
	dest := []interface{}{
		&acct.ID,
		&acct.UserID,
		&acct.Provider,
		&acct.OpenID,
		&acct.UnionID,
		&acct.Nickname,
		&acct.Avatar,
		&acct.Gender,
		&acct.Email,
		&acct.Phone,
		&acct.Extra,
		&acct.Status,
		&acct.CreatedAt,
		&acct.UpdatedAt,
	}
	err := r.db.QueryRowContext(ctx, q, provider, openID).Scan(dest...)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get social account: %w", err)
	}
	return &acct, nil
}
// List returns one page of social accounts across all users, newest first,
// together with the total number of rows in user_social_accounts.
//
// offset and limit are passed to the database unvalidated. The count and the
// page are read by two separate statements (no transaction), so under
// concurrent writes total may not exactly match the page contents. An empty
// page is returned as a nil slice. Query, scan and iteration failures are all
// wrapped and reported.
func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error) {
	var total int64
	countQuery := `SELECT COUNT(*) FROM user_social_accounts`
	if err := r.db.QueryRowContext(ctx, countQuery).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("failed to count social accounts: %w", err)
	}

	// id DESC breaks created_at ties. Without it, rows sharing a timestamp have
	// no defined relative order and can repeat or vanish across pages.
	query := `
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
FROM user_social_accounts
ORDER BY created_at DESC, id DESC
LIMIT ? OFFSET ?
`
	rows, err := r.db.QueryContext(ctx, query, limit, offset)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to query social accounts: %w", err)
	}
	defer rows.Close()

	var accounts []*domain.SocialAccount
	for rows.Next() {
		var account domain.SocialAccount
		if err := rows.Scan(
			&account.ID,
			&account.UserID,
			&account.Provider,
			&account.OpenID,
			&account.UnionID,
			&account.Nickname,
			&account.Avatar,
			&account.Gender,
			&account.Email,
			&account.Phone,
			&account.Extra,
			&account.Status,
			&account.CreatedAt,
			&account.UpdatedAt,
		); err != nil {
			return nil, 0, fmt.Errorf("failed to scan social account: %w", err)
		}
		accounts = append(accounts, &account)
	}
	// Surface errors that terminated iteration early instead of returning a
	// silently truncated page.
	if err := rows.Err(); err != nil {
		return nil, 0, fmt.Errorf("failed to iterate social accounts: %w", err)
	}
	return accounts, total, nil
}