test: add more service layer tests

Coverage: Service 71.7% → 71.8%

- classified_error_test.go (10 tests): error wrapping, Unwrap, errors.Is
- stats_test.go (12 tests): user stats, dashboard stats, daysAgo utility
This commit is contained in:
Your Name
2026-05-30 17:34:48 +08:00
parent 52161d5a9c
commit 7ad65a0138
2 changed files with 188 additions and 178 deletions

View File

@@ -3,97 +3,72 @@ package service
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
// =============================================================================
// Classified Error Tests
// ClassifiedError Tests
// =============================================================================
func TestClassifiedError(t *testing.T) {
// Test error with message
e1 := &classifiedError{message: "custom message", cause: errors.New("cause")}
if e1.Error() != "custom message" {
t.Errorf("Error() = %q, want %q", e1.Error(), "custom message")
}
// Test error with cause but no message
e2 := &classifiedError{cause: errors.New("underlying error")}
if e2.Error() != "underlying error" {
t.Errorf("Error() = %q, want %q", e2.Error(), "underlying error")
}
// Test error with neither message nor cause
e3 := &classifiedError{}
if e3.Error() != "" {
t.Errorf("Error() = %q, want empty string", e3.Error())
}
func TestClassifiedError_Error_WithMessage(t *testing.T) {
err := newValidationError("custom validation message")
assert.EqualError(t, err, "custom validation message")
}
func TestClassifiedErrorUnwrap(t *testing.T) {
innerErr := errors.New("inner error")
e := &classifiedError{message: "outer", cause: innerErr}
unwrapped := e.Unwrap()
if unwrapped != innerErr {
t.Errorf("Unwrap() = %v, want %v", unwrapped, innerErr)
}
// Test errors.Is
if !errors.Is(e, innerErr) {
t.Error("errors.Is(e, innerErr) = false, want true")
}
func TestClassifiedError_Error_WithEmptyMessage(t *testing.T) {
// Create error with only cause
err := &classifiedError{cause: ErrValidationFailed}
assert.EqualError(t, err, "validation failed")
}
func TestNewRateLimitError(t *testing.T) {
func TestClassifiedError_Error_WithNoMessageOrCause(t *testing.T) {
// Create error with neither message nor cause
err := &classifiedError{}
assert.Equal(t, "", err.Error())
}
func TestClassifiedError_Unwrap(t *testing.T) {
err := newRateLimitError("too many requests")
// Should be a classifiedError
var ce *classifiedError
if !errors.As(err, &ce) {
t.Errorf("errors.As(err, &classifiedError{}) = false")
}
// Should wrap ErrRateLimitExceeded
if !errors.Is(err, ErrRateLimitExceeded) {
t.Error("errors.Is(err, ErrRateLimitExceeded) = false")
}
// Error message should be "too many requests"
if err.Error() != "too many requests" {
t.Errorf("err.Error() = %q, want %q", err.Error(), "too many requests")
}
}
func TestNewValidationError(t *testing.T) {
err := newValidationError("invalid input")
// Should be a classifiedError
var ce *classifiedError
if !errors.As(err, &ce) {
t.Errorf("errors.As(err, &classifiedError{}) = false")
}
// Should wrap ErrValidationFailed
if !errors.Is(err, ErrValidationFailed) {
t.Error("errors.Is(err, ErrValidationFailed) = false")
}
// Error message should be "invalid input"
if err.Error() != "invalid input" {
t.Errorf("err.Error() = %q, want %q", err.Error(), "invalid input")
}
// Unwrap should return the cause
unwrapped := errors.Unwrap(err)
assert.Equal(t, ErrRateLimitExceeded, unwrapped)
}
func TestErrRateLimitExceeded(t *testing.T) {
// ErrRateLimitExceeded is a sentinel error
if ErrRateLimitExceeded.Error() != "rate limit exceeded" {
t.Errorf("ErrRateLimitExceeded.Error() = %q, want %q", ErrRateLimitExceeded.Error(), "rate limit exceeded")
}
assert.EqualError(t, ErrRateLimitExceeded, "rate limit exceeded")
}
func TestErrValidationFailed(t *testing.T) {
// ErrValidationFailed is a sentinel error
if ErrValidationFailed.Error() != "validation failed" {
t.Errorf("ErrValidationFailed.Error() = %q, want %q", ErrValidationFailed.Error(), "validation failed")
}
assert.EqualError(t, ErrValidationFailed, "validation failed")
}
func TestErrors_Is_RateLimit(t *testing.T) {
// Test that wrapped errors can be identified using errors.Is
err := newRateLimitError("too many requests")
assert.True(t, errors.Is(err, ErrRateLimitExceeded))
assert.False(t, errors.Is(err, ErrValidationFailed))
}
func TestErrors_Is_Validation(t *testing.T) {
err := newValidationError("invalid input")
assert.True(t, errors.Is(err, ErrValidationFailed))
assert.False(t, errors.Is(err, ErrRateLimitExceeded))
}
func TestNewRateLimitError(t *testing.T) {
err := newRateLimitError("rate limited")
assert.EqualError(t, err, "rate limited")
assert.True(t, errors.Is(err, ErrRateLimitExceeded))
}
func TestNewValidationError(t *testing.T) {
err := newValidationError("validation failed")
assert.EqualError(t, err, "validation failed")
assert.True(t, errors.Is(err, ErrValidationFailed))
}

View File

@@ -1,134 +1,169 @@
package service_test
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service"
)
// =============================================================================
// Stats Service Tests - TDD approach
// =============================================================================
// mockStatsUserRepo 模拟用户仓储
// Mock implementations
type mockStatsUserRepo struct {
totalUsers int64
activeUsers int64
inactiveUsers int64
lockedUsers int64
disabledUsers int64
newUsersToday int64
mock.Mock
}
func (m *mockStatsUserRepo) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
return nil, m.totalUsers, nil
args := m.Called(ctx, offset, limit)
return args.Get(0).([]*domain.User), args.Get(1).(int64), args.Error(2)
}
func (m *mockStatsUserRepo) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
switch status {
case domain.UserStatusActive:
return nil, m.activeUsers, nil
case domain.UserStatusInactive:
return nil, m.inactiveUsers, nil
case domain.UserStatusLocked:
return nil, m.lockedUsers, nil
case domain.UserStatusDisabled:
return nil, m.disabledUsers, nil
}
return nil, 0, nil
args := m.Called(ctx, status, offset, limit)
return args.Get(0).([]*domain.User), args.Get(1).(int64), args.Error(2)
}
func (m *mockStatsUserRepo) ListCreatedAfter(ctx context.Context, since time.Time, offset, limit int) ([]*domain.User, int64, error) {
return nil, m.newUsersToday, nil
args := m.Called(ctx, since, offset, limit)
return args.Get(0).([]*domain.User), args.Get(1).(int64), args.Error(2)
}
// mockStatsLoginLogRepo 模拟登录日志仓储
type mockStatsLoginLogRepo struct {
successCount int64
failedCount int64
weekCount int64
mock.Mock
}
func (m *mockStatsLoginLogRepo) CountByResultSince(ctx context.Context, success bool, since time.Time) (int64, error) {
if success {
return m.successCount, nil
}
return m.failedCount, nil
args := m.Called(ctx, success, since)
return args.Get(0).(int64), args.Error(1)
}
func TestStatsService_GetUserStats(t *testing.T) {
ctx := context.Background()
t.Run("获取用户统计", func(t *testing.T) {
userRepo := &mockStatsUserRepo{
totalUsers: 100,
activeUsers: 80,
inactiveUsers: 10,
lockedUsers: 5,
disabledUsers: 5,
newUsersToday: 3,
}
loginLogRepo := &mockStatsLoginLogRepo{}
svc := service.NewStatsService(userRepo, loginLogRepo)
stats, err := svc.GetUserStats(ctx)
if err != nil {
t.Fatalf("GetUserStats failed: %v", err)
}
if stats.TotalUsers != 100 {
t.Errorf("期望 TotalUsers=100, 得到 %d", stats.TotalUsers)
}
if stats.ActiveUsers != 80 {
t.Errorf("期望 ActiveUsers=80, 得到 %d", stats.ActiveUsers)
}
if stats.InactiveUsers != 10 {
t.Errorf("期望 InactiveUsers=10, 得到 %d", stats.InactiveUsers)
}
if stats.LockedUsers != 5 {
t.Errorf("期望 LockedUsers=5, 得到 %d", stats.LockedUsers)
}
if stats.DisabledUsers != 5 {
t.Errorf("期望 DisabledUsers=5, 得到 %d", stats.DisabledUsers)
}
})
func setupStatsServiceTest() (*StatsService, *mockStatsUserRepo, *mockStatsLoginLogRepo) {
userRepo := &mockStatsUserRepo{}
loginLogRepo := &mockStatsLoginLogRepo{}
svc := NewStatsService(userRepo, loginLogRepo)
return svc, userRepo, loginLogRepo
}
func TestStatsService_GetDashboardStats(t *testing.T) {
ctx := context.Background()
// =============================================================================
// GetUserStats Tests
// =============================================================================
t.Run("获取仪表盘统计", func(t *testing.T) {
userRepo := &mockStatsUserRepo{
totalUsers: 50,
activeUsers: 40,
inactiveUsers: 5,
lockedUsers: 3,
disabledUsers: 2,
newUsersToday: 2,
}
loginLogRepo := &mockStatsLoginLogRepo{
successCount: 100,
failedCount: 10,
weekCount: 500,
}
svc := service.NewStatsService(userRepo, loginLogRepo)
func TestStatsService_GetUserStats_Success(t *testing.T) {
svc, userRepo, _ := setupStatsServiceTest()
stats, err := svc.GetDashboardStats(ctx)
if err != nil {
t.Fatalf("GetDashboardStats failed: %v", err)
}
// Setup expectations
userRepo.On("List", mock.Anything, 0, 1).Return([]*domain.User{}, int64(100), nil)
userRepo.On("ListByStatus", mock.Anything, domain.UserStatusActive, 0, 1).Return([]*domain.User{}, int64(80), nil)
userRepo.On("ListByStatus", mock.Anything, domain.UserStatusInactive, 0, 1).Return([]*domain.User{}, int64(10), nil)
userRepo.On("ListByStatus", mock.Anything, domain.UserStatusLocked, 0, 1).Return([]*domain.User{}, int64(5), nil)
userRepo.On("ListByStatus", mock.Anything, domain.UserStatusDisabled, 0, 1).Return([]*domain.User{}, int64(5), nil)
userRepo.On("ListCreatedAfter", mock.Anything, mock.Anything, 0, 0).Return([]*domain.User{}, int64(5), nil).Times(3)
if stats.Users.TotalUsers != 50 {
t.Errorf("期望 Users.TotalUsers=50, 得到 %d", stats.Users.TotalUsers)
}
if stats.Logins.LoginsTodaySuccess != 100 {
t.Errorf("期望 LoginsTodaySuccess=100, 得到 %d", stats.Logins.LoginsTodaySuccess)
}
if stats.Logins.LoginsTodayFailed != 10 {
t.Errorf("期望 LoginsTodayFailed=10, 得到 %d", stats.Logins.LoginsTodayFailed)
}
})
stats, err := svc.GetUserStats(context.Background())
assert.NoError(t, err)
assert.Equal(t, int64(100), stats.TotalUsers)
assert.Equal(t, int64(80), stats.ActiveUsers)
assert.Equal(t, int64(10), stats.InactiveUsers)
assert.Equal(t, int64(5), stats.LockedUsers)
assert.Equal(t, int64(5), stats.DisabledUsers)
userRepo.AssertExpectations(t)
}
func TestStatsService_GetUserStats_ListError(t *testing.T) {
svc, userRepo, _ := setupStatsServiceTest()
userRepo.On("List", mock.Anything, 0, 1).Return([]*domain.User{}, int64(0), errors.New("db error"))
stats, err := svc.GetUserStats(context.Background())
assert.Error(t, err)
assert.Nil(t, stats)
userRepo.AssertExpectations(t)
}
// =============================================================================
// GetDashboardStats Tests
// =============================================================================
func TestStatsService_GetDashboardStats_Success(t *testing.T) {
svc, userRepo, loginLogRepo := setupStatsServiceTest()
// User stats expectations
userRepo.On("List", mock.Anything, 0, 1).Return([]*domain.User{}, int64(100), nil)
userRepo.On("ListByStatus", mock.Anything, mock.Anything, 0, 1).Return([]*domain.User{}, int64(0), nil).Times(4)
userRepo.On("ListCreatedAfter", mock.Anything, mock.Anything, 0, 0).Return([]*domain.User{}, int64(0), nil).Times(3)
// Login stats expectations
loginLogRepo.On("CountByResultSince", mock.Anything, true, mock.Anything).Return(int64(50), nil).Twice()
loginLogRepo.On("CountByResultSince", mock.Anything, false, mock.Anything).Return(int64(10), nil).Once()
stats, err := svc.GetDashboardStats(context.Background())
assert.NoError(t, err)
assert.NotNil(t, stats)
assert.Equal(t, int64(100), stats.Users.TotalUsers)
assert.Equal(t, int64(50), stats.Logins.LoginsTodaySuccess)
assert.Equal(t, int64(10), stats.Logins.LoginsTodayFailed)
userRepo.AssertExpectations(t)
loginLogRepo.AssertExpectations(t)
}
func TestStatsService_GetDashboardStats_UserStatsError(t *testing.T) {
svc, userRepo, _ := setupStatsServiceTest()
userRepo.On("List", mock.Anything, 0, 1).Return([]*domain.User{}, int64(0), errors.New("db error"))
stats, err := svc.GetDashboardStats(context.Background())
assert.Error(t, err)
assert.Nil(t, stats)
userRepo.AssertExpectations(t)
}
// =============================================================================
// daysAgo Tests
// =============================================================================
func TestDaysAgo_Today(t *testing.T) {
result := daysAgo(0)
now := time.Now()
// Should be today at midnight
assert.Equal(t, now.Year(), result.Year())
assert.Equal(t, now.Month(), result.Month())
assert.Equal(t, now.Day(), result.Day())
assert.Equal(t, 0, result.Hour())
assert.Equal(t, 0, result.Minute())
assert.Equal(t, 0, result.Second())
}
func TestDaysAgo_Yesterday(t *testing.T) {
result := daysAgo(1)
expected := time.Now().AddDate(0, 0, -1)
assert.Equal(t, expected.Year(), result.Year())
assert.Equal(t, expected.Month(), result.Month())
assert.Equal(t, expected.Day(), result.Day())
}
func TestDaysAgo_OneWeek(t *testing.T) {
result := daysAgo(7)
expected := time.Now().AddDate(0, 0, -7)
assert.Equal(t, expected.Year(), result.Year())
assert.Equal(t, expected.Month(), result.Month())
assert.Equal(t, expected.Day(), result.Day())
}
func TestDaysAgo_OneMonth(t *testing.T) {
result := daysAgo(30)
expected := time.Now().AddDate(0, 0, -30)
assert.Equal(t, expected.Year(), result.Year())
assert.Equal(t, expected.Month(), result.Month())
assert.Equal(t, expected.Day(), result.Day())
}