feat(P1/P2): 完成TDD开发及P1/P2设计文档
## 设计文档 - multi_role_permission_design: 多角色权限设计 (CONDITIONAL GO) - audit_log_enhancement_design: 审计日志增强 (CONDITIONAL GO) - routing_strategy_template_design: 路由策略模板 (CONDITIONAL GO) - sso_saml_technical_research: SSO/SAML调研 (CONDITIONAL GO) - compliance_capability_package_design: 合规能力包设计 (CONDITIONAL GO) ## TDD开发成果 - IAM模块: supply-api/internal/iam/ (111个测试) - 审计日志模块: supply-api/internal/audit/ (40+测试) - 路由策略模块: gateway/internal/router/ (33+测试) - 合规能力包: gateway/internal/compliance/ + scripts/ci/compliance/ ## 规范文档 - parallel_agent_output_quality_standards: 并行Agent产出质量规范 - project_experience_summary: 项目经验总结 (v2) - 2026-04-02-p1-p2-tdd-execution-plan: TDD执行计划 ## 评审报告 - 5个CONDITIONAL GO设计文档评审报告 - fix_verification_report: 修复验证报告 - full_verification_report: 全面质量验证报告 - tdd_module_quality_verification: TDD模块质量验证 - tdd_execution_summary: TDD执行总结 依据: Superpowers执行框架 + TDD规范
This commit is contained in:
186
supply-api/internal/audit/events/cred_events.go
Normal file
186
supply-api/internal/audit/events/cred_events.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CRED事件类别常量
|
||||
const (
|
||||
CategoryCRED = "CRED"
|
||||
SubCategoryEXPOSE = "EXPOSE"
|
||||
SubCategoryINGRESS = "INGRESS"
|
||||
SubCategoryROTATE = "ROTATE"
|
||||
SubCategoryREVOKE = "REVOKE"
|
||||
SubCategoryVALIDATE = "VALIDATE"
|
||||
SubCategoryDIRECT = "DIRECT"
|
||||
)
|
||||
|
||||
// CRED事件列表
|
||||
var credEvents = []string{
|
||||
// 凭证暴露事件 (CRED-EXPOSE)
|
||||
"CRED-EXPOSE-RESPONSE", // 响应中暴露凭证
|
||||
"CRED-EXPOSE-LOG", // 日志中暴露凭证
|
||||
"CRED-EXPOSE-EXPORT", // 导出文件中暴露凭证
|
||||
|
||||
// 凭证入站事件 (CRED-INGRESS)
|
||||
"CRED-INGRESS-PLATFORM", // 平台凭证入站
|
||||
"CRED-INGRESS-SUPPLIER", // 供应商凭证入站
|
||||
|
||||
// 凭证轮换事件 (CRED-ROTATE)
|
||||
"CRED-ROTATE",
|
||||
|
||||
// 凭证吊销事件 (CRED-REVOKE)
|
||||
"CRED-REVOKE",
|
||||
|
||||
// 凭证验证事件 (CRED-VALIDATE)
|
||||
"CRED-VALIDATE",
|
||||
|
||||
// 直连绕过事件 (CRED-DIRECT)
|
||||
"CRED-DIRECT-SUPPLIER", // 直连供应商
|
||||
"CRED-DIRECT-BYPASS", // 绕过直连
|
||||
}
|
||||
|
||||
// CRED事件结果码映射
|
||||
var credResultCodes = map[string]string{
|
||||
"CRED-EXPOSE-RESPONSE": "SEC_CRED_EXPOSED",
|
||||
"CRED-EXPOSE-LOG": "SEC_CRED_EXPOSED",
|
||||
"CRED-EXPOSE-EXPORT": "SEC_CRED_EXPOSED",
|
||||
"CRED-INGRESS-PLATFORM": "CRED_INGRESS_OK",
|
||||
"CRED-INGRESS-SUPPLIER": "CRED_INGRESS_OK",
|
||||
"CRED-DIRECT-SUPPLIER": "SEC_DIRECT_BYPASS",
|
||||
"CRED-DIRECT-BYPASS": "SEC_DIRECT_BYPASS",
|
||||
"CRED-ROTATE": "CRED_ROTATE_OK",
|
||||
"CRED-REVOKE": "CRED_REVOKE_OK",
|
||||
"CRED-VALIDATE": "CRED_VALIDATE_OK",
|
||||
}
|
||||
|
||||
// CRED指标名称映射
|
||||
var credMetricNames = map[string]string{
|
||||
"CRED-EXPOSE-RESPONSE": "supplier_credential_exposure_events",
|
||||
"CRED-EXPOSE-LOG": "supplier_credential_exposure_events",
|
||||
"CRED-EXPOSE-EXPORT": "supplier_credential_exposure_events",
|
||||
"CRED-INGRESS-PLATFORM": "platform_credential_ingress_coverage_pct",
|
||||
"CRED-INGRESS-SUPPLIER": "platform_credential_ingress_coverage_pct",
|
||||
"CRED-DIRECT-SUPPLIER": "direct_supplier_call_by_consumer_events",
|
||||
"CRED-DIRECT-BYPASS": "direct_supplier_call_by_consumer_events",
|
||||
}
|
||||
|
||||
// GetCREDEvents 返回所有CRED事件
|
||||
func GetCREDEvents() []string {
|
||||
return credEvents
|
||||
}
|
||||
|
||||
// GetCREDExposeEvents 返回所有凭证暴露事件
|
||||
func GetCREDExposeEvents() []string {
|
||||
return []string{
|
||||
"CRED-EXPOSE-RESPONSE",
|
||||
"CRED-EXPOSE-LOG",
|
||||
"CRED-EXPOSE-EXPORT",
|
||||
}
|
||||
}
|
||||
|
||||
// GetCREDFngressEvents 返回所有凭证入站事件
|
||||
func GetCREDFngressEvents() []string {
|
||||
return []string{
|
||||
"CRED-INGRESS-PLATFORM",
|
||||
"CRED-INGRESS-SUPPLIER",
|
||||
}
|
||||
}
|
||||
|
||||
// GetCREDDnirectEvents 返回所有直连绕过事件
|
||||
func GetCREDDnirectEvents() []string {
|
||||
return []string{
|
||||
"CRED-DIRECT-SUPPLIER",
|
||||
"CRED-DIRECT-BYPASS",
|
||||
}
|
||||
}
|
||||
|
||||
// GetCREDEventCategory 返回CRED事件的类别
|
||||
func GetCREDEventCategory(eventName string) string {
|
||||
if strings.HasPrefix(eventName, "CRED-") {
|
||||
return CategoryCRED
|
||||
}
|
||||
if eventName == "CRED-ROTATE" || eventName == "CRED-REVOKE" || eventName == "CRED-VALIDATE" {
|
||||
return CategoryCRED
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetCREDEventSubCategory 返回CRED事件的子类别
|
||||
func GetCREDEventSubCategory(eventName string) string {
|
||||
if strings.HasPrefix(eventName, "CRED-EXPOSE") {
|
||||
return SubCategoryEXPOSE
|
||||
}
|
||||
if strings.HasPrefix(eventName, "CRED-INGRESS") {
|
||||
return SubCategoryINGRESS
|
||||
}
|
||||
if strings.HasPrefix(eventName, "CRED-DIRECT") {
|
||||
return SubCategoryDIRECT
|
||||
}
|
||||
if strings.HasPrefix(eventName, "CRED-ROTATE") {
|
||||
return SubCategoryROTATE
|
||||
}
|
||||
if strings.HasPrefix(eventName, "CRED-REVOKE") {
|
||||
return SubCategoryREVOKE
|
||||
}
|
||||
if strings.HasPrefix(eventName, "CRED-VALIDATE") {
|
||||
return SubCategoryVALIDATE
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsValidCREDEvent 检查事件名称是否为有效的CRED事件
|
||||
func IsValidCREDEvent(eventName string) bool {
|
||||
for _, e := range credEvents {
|
||||
if e == eventName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsCREDExposeEvent 检查是否为凭证暴露事件(M-013相关)
|
||||
func IsCREDExposeEvent(eventName string) bool {
|
||||
return strings.HasPrefix(eventName, "CRED-EXPOSE")
|
||||
}
|
||||
|
||||
// IsCREDFngressEvent 检查是否为凭证入站事件(M-014相关)
|
||||
func IsCREDFngressEvent(eventName string) bool {
|
||||
return strings.HasPrefix(eventName, "CRED-INGRESS")
|
||||
}
|
||||
|
||||
// IsCREDDnirectEvent 检查是否为直连绕过事件(M-015相关)
|
||||
func IsCREDDnirectEvent(eventName string) bool {
|
||||
return strings.HasPrefix(eventName, "CRED-DIRECT")
|
||||
}
|
||||
|
||||
// GetCREDMetricName 获取CRED事件对应的指标名称
|
||||
func GetCREDMetricName(eventName string) string {
|
||||
if metric, ok := credMetricNames[eventName]; ok {
|
||||
return metric
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetCREDEventResultCode 获取CRED事件对应的结果码
|
||||
func GetCREDEventResultCode(eventName string) string {
|
||||
if code, ok := credResultCodes[eventName]; ok {
|
||||
return code
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsCREDExposeEvent 检查是否为M-013事件(凭证暴露)
|
||||
func IsM013RelatedEvent(eventName string) bool {
|
||||
return IsCREDExposeEvent(eventName)
|
||||
}
|
||||
|
||||
// IsCREDFngressEvent 检查是否为M-014事件(凭证入站)
|
||||
func IsM014RelatedEvent(eventName string) bool {
|
||||
return IsCREDFngressEvent(eventName)
|
||||
}
|
||||
|
||||
// IsCREDDnirectEvent 检查是否为M-015事件(直连绕过)
|
||||
func IsM015RelatedEvent(eventName string) bool {
|
||||
return IsCREDDnirectEvent(eventName)
|
||||
}
|
||||
145
supply-api/internal/audit/events/cred_events_test.go
Normal file
145
supply-api/internal/audit/events/cred_events_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCREDEvents_Categories(t *testing.T) {
|
||||
// 测试 CRED 事件类别
|
||||
events := GetCREDEvents()
|
||||
|
||||
// CRED-EXPOSE-RESPONSE: 响应中暴露凭证
|
||||
assert.Contains(t, events, "CRED-EXPOSE-RESPONSE", "Should contain CRED-EXPOSE-RESPONSE")
|
||||
|
||||
// CRED-INGRESS-PLATFORM: 平台凭证入站
|
||||
assert.Contains(t, events, "CRED-INGRESS-PLATFORM", "Should contain CRED-INGRESS-PLATFORM")
|
||||
|
||||
// CRED-DIRECT-SUPPLIER: 直连供应商
|
||||
assert.Contains(t, events, "CRED-DIRECT-SUPPLIER", "Should contain CRED-DIRECT-SUPPLIER")
|
||||
}
|
||||
|
||||
func TestCREDEvents_ExposeEvents(t *testing.T) {
|
||||
// 测试 CRED-EXPOSE 事件
|
||||
events := GetCREDExposeEvents()
|
||||
|
||||
assert.Contains(t, events, "CRED-EXPOSE-RESPONSE")
|
||||
assert.Contains(t, events, "CRED-EXPOSE-LOG")
|
||||
assert.Contains(t, events, "CRED-EXPOSE-EXPORT")
|
||||
}
|
||||
|
||||
func TestCREDEvents_IngressEvents(t *testing.T) {
|
||||
// 测试 CRED-INGRESS 事件
|
||||
events := GetCREDFngressEvents()
|
||||
|
||||
assert.Contains(t, events, "CRED-INGRESS-PLATFORM")
|
||||
assert.Contains(t, events, "CRED-INGRESS-SUPPLIER")
|
||||
}
|
||||
|
||||
func TestCREDEvents_DirectEvents(t *testing.T) {
|
||||
// 测试 CRED-DIRECT 事件
|
||||
events := GetCREDDnirectEvents()
|
||||
|
||||
assert.Contains(t, events, "CRED-DIRECT-SUPPLIER")
|
||||
assert.Contains(t, events, "CRED-DIRECT-BYPASS")
|
||||
}
|
||||
|
||||
func TestCREDEvents_GetEventCategory(t *testing.T) {
|
||||
// 所有CRED事件的类别应该是CRED
|
||||
events := GetCREDEvents()
|
||||
for _, eventName := range events {
|
||||
category := GetCREDEventCategory(eventName)
|
||||
assert.Equal(t, "CRED", category, "Event %s should have category CRED", eventName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCREDEvents_GetEventSubCategory(t *testing.T) {
|
||||
// 测试CRED事件的子类别
|
||||
testCases := []struct {
|
||||
eventName string
|
||||
expectedSubCategory string
|
||||
}{
|
||||
{"CRED-EXPOSE-RESPONSE", "EXPOSE"},
|
||||
{"CRED-INGRESS-PLATFORM", "INGRESS"},
|
||||
{"CRED-DIRECT-SUPPLIER", "DIRECT"},
|
||||
{"CRED-ROTATE", "ROTATE"},
|
||||
{"CRED-REVOKE", "REVOKE"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.eventName, func(t *testing.T) {
|
||||
subCategory := GetCREDEventSubCategory(tc.eventName)
|
||||
assert.Equal(t, tc.expectedSubCategory, subCategory)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCREDEvents_IsValidEvent(t *testing.T) {
|
||||
// 测试有效事件验证
|
||||
assert.True(t, IsValidCREDEvent("CRED-EXPOSE-RESPONSE"))
|
||||
assert.True(t, IsValidCREDEvent("CRED-INGRESS-PLATFORM"))
|
||||
assert.True(t, IsValidCREDEvent("CRED-DIRECT-SUPPLIER"))
|
||||
assert.False(t, IsValidCREDEvent("INVALID-EVENT"))
|
||||
assert.False(t, IsValidCREDEvent("AUTH-TOKEN-OK"))
|
||||
}
|
||||
|
||||
func TestCREDEvents_IsM013Event(t *testing.T) {
|
||||
// 测试M-013相关事件
|
||||
assert.True(t, IsCREDExposeEvent("CRED-EXPOSE-RESPONSE"))
|
||||
assert.True(t, IsCREDExposeEvent("CRED-EXPOSE-LOG"))
|
||||
assert.False(t, IsCREDExposeEvent("CRED-INGRESS-PLATFORM"))
|
||||
}
|
||||
|
||||
func TestCREDEvents_IsM014Event(t *testing.T) {
|
||||
// 测试M-014相关事件
|
||||
assert.True(t, IsCREDFngressEvent("CRED-INGRESS-PLATFORM"))
|
||||
assert.True(t, IsCREDFngressEvent("CRED-INGRESS-SUPPLIER"))
|
||||
assert.False(t, IsCREDFngressEvent("CRED-EXPOSE-RESPONSE"))
|
||||
}
|
||||
|
||||
func TestCREDEvents_IsM015Event(t *testing.T) {
|
||||
// 测试M-015相关事件
|
||||
assert.True(t, IsCREDDnirectEvent("CRED-DIRECT-SUPPLIER"))
|
||||
assert.True(t, IsCREDDnirectEvent("CRED-DIRECT-BYPASS"))
|
||||
assert.False(t, IsCREDDnirectEvent("CRED-INGRESS-PLATFORM"))
|
||||
}
|
||||
|
||||
func TestCREDEvents_GetMetricName(t *testing.T) {
|
||||
// 测试指标名称映射
|
||||
testCases := []struct {
|
||||
eventName string
|
||||
expectedMetric string
|
||||
}{
|
||||
{"CRED-EXPOSE-RESPONSE", "supplier_credential_exposure_events"},
|
||||
{"CRED-EXPOSE-LOG", "supplier_credential_exposure_events"},
|
||||
{"CRED-INGRESS-PLATFORM", "platform_credential_ingress_coverage_pct"},
|
||||
{"CRED-DIRECT-SUPPLIER", "direct_supplier_call_by_consumer_events"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.eventName, func(t *testing.T) {
|
||||
metric := GetCREDMetricName(tc.eventName)
|
||||
assert.Equal(t, tc.expectedMetric, metric)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCREDEvents_GetResultCode(t *testing.T) {
|
||||
// 测试CRED事件结果码
|
||||
testCases := []struct {
|
||||
eventName string
|
||||
expectedCode string
|
||||
}{
|
||||
{"CRED-EXPOSE-RESPONSE", "SEC_CRED_EXPOSED"},
|
||||
{"CRED-INGRESS-PLATFORM", "CRED_INGRESS_OK"},
|
||||
{"CRED-DIRECT-SUPPLIER", "SEC_DIRECT_BYPASS"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.eventName, func(t *testing.T) {
|
||||
code := GetCREDEventResultCode(tc.eventName)
|
||||
assert.Equal(t, tc.expectedCode, code)
|
||||
})
|
||||
}
|
||||
}
|
||||
195
supply-api/internal/audit/events/security_events.go
Normal file
195
supply-api/internal/audit/events/security_events.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SECURITY事件类别常量
|
||||
const (
|
||||
CategorySECURITY = "SECURITY"
|
||||
SubCategoryVIOLATION = "VIOLATION"
|
||||
SubCategoryALERT = "ALERT"
|
||||
SubCategoryBREACH = "BREACH"
|
||||
)
|
||||
|
||||
// SECURITY事件列表
|
||||
var securityEvents = []string{
|
||||
// 不变量违反事件 (INVARIANT-VIOLATION)
|
||||
"INV-PKG-001", // 供应方资质过期
|
||||
"INV-PKG-002", // 供应方余额为负
|
||||
"INV-PKG-003", // 售价不得低于保护价
|
||||
"INV-SET-001", // processing/completed 不可撤销
|
||||
"INV-SET-002", // 提现金额不得超过可提现余额
|
||||
"INV-SET-003", // 结算单金额与余额流水必须平衡
|
||||
|
||||
// 安全突破事件 (SECURITY-BREACH)
|
||||
"SEC-BREACH-001", // 凭证泄露突破
|
||||
"SEC-BREACH-002", // 权限绕过突破
|
||||
|
||||
// 安全告警事件 (SECURITY-ALERT)
|
||||
"SEC-ALERT-001", // 可疑访问告警
|
||||
"SEC-ALERT-002", // 异常行为告警
|
||||
}
|
||||
|
||||
// 不变量违反事件到结果码的映射
|
||||
var invariantResultCodes = map[string]string{
|
||||
"INV-PKG-001": "SEC_INV_PKG_001",
|
||||
"INV-PKG-002": "SEC_INV_PKG_002",
|
||||
"INV-PKG-003": "SEC_INV_PKG_003",
|
||||
"INV-SET-001": "SEC_INV_SET_001",
|
||||
"INV-SET-002": "SEC_INV_SET_002",
|
||||
"INV-SET-003": "SEC_INV_SET_003",
|
||||
}
|
||||
|
||||
// 事件描述映射
|
||||
var securityEventDescriptions = map[string]string{
|
||||
"INV-PKG-001": "供应方资质过期,资质验证失败",
|
||||
"INV-PKG-002": "供应方余额为负,余额检查失败",
|
||||
"INV-PKG-003": "售价不得低于保护价,价格校验失败",
|
||||
"INV-SET-001": "结算单状态为processing/completed,不可撤销",
|
||||
"INV-SET-002": "提现金额不得超过可提现余额",
|
||||
"INV-SET-003": "结算单金额与余额流水不平衡",
|
||||
"SEC-BREACH-001": "检测到凭证泄露安全突破",
|
||||
"SEC-BREACH-002": "检测到权限绕过安全突破",
|
||||
"SEC-ALERT-001": "检测到可疑访问行为",
|
||||
"SEC-ALERT-002": "检测到异常行为",
|
||||
}
|
||||
|
||||
// GetSECURITYEvents 返回所有SECURITY事件
|
||||
func GetSECURITYEvents() []string {
|
||||
return securityEvents
|
||||
}
|
||||
|
||||
// GetInvariantViolationEvents 返回所有不变量违反事件
|
||||
func GetInvariantViolationEvents() []string {
|
||||
return []string{
|
||||
"INV-PKG-001",
|
||||
"INV-PKG-002",
|
||||
"INV-PKG-003",
|
||||
"INV-SET-001",
|
||||
"INV-SET-002",
|
||||
"INV-SET-003",
|
||||
}
|
||||
}
|
||||
|
||||
// GetSecurityAlertEvents 返回所有安全告警事件
|
||||
func GetSecurityAlertEvents() []string {
|
||||
return []string{
|
||||
"SEC-ALERT-001",
|
||||
"SEC-ALERT-002",
|
||||
}
|
||||
}
|
||||
|
||||
// GetSecurityBreachEvents 返回所有安全突破事件
|
||||
func GetSecurityBreachEvents() []string {
|
||||
return []string{
|
||||
"SEC-BREACH-001",
|
||||
"SEC-BREACH-002",
|
||||
}
|
||||
}
|
||||
|
||||
// GetEventCategory 返回事件的类别
|
||||
func GetEventCategory(eventName string) string {
|
||||
if isInvariantViolation(eventName) || isSecurityBreach(eventName) || isSecurityAlert(eventName) {
|
||||
return CategorySECURITY
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetEventSubCategory 返回事件的子类别
|
||||
func GetEventSubCategory(eventName string) string {
|
||||
if isInvariantViolation(eventName) {
|
||||
return SubCategoryVIOLATION
|
||||
}
|
||||
if isSecurityBreach(eventName) {
|
||||
return SubCategoryBREACH
|
||||
}
|
||||
if isSecurityAlert(eventName) {
|
||||
return SubCategoryALERT
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetResultCode 返回事件对应的结果码
|
||||
func GetResultCode(eventName string) string {
|
||||
if code, ok := invariantResultCodes[eventName]; ok {
|
||||
return code
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetEventDescription 返回事件的描述
|
||||
func GetEventDescription(eventName string) string {
|
||||
if desc, ok := securityEventDescriptions[eventName]; ok {
|
||||
return desc
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsValidEvent 检查事件名称是否有效
|
||||
func IsValidEvent(eventName string) bool {
|
||||
for _, e := range securityEvents {
|
||||
if e == eventName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isInvariantViolation 检查是否为不变量违反事件
|
||||
func isInvariantViolation(eventName string) bool {
|
||||
for _, e := range getInvariantViolationEvents() {
|
||||
if e == eventName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getInvariantViolationEvents 返回不变量违反事件列表(内部使用)
|
||||
func getInvariantViolationEvents() []string {
|
||||
return []string{
|
||||
"INV-PKG-001",
|
||||
"INV-PKG-002",
|
||||
"INV-PKG-003",
|
||||
"INV-SET-001",
|
||||
"INV-SET-002",
|
||||
"INV-SET-003",
|
||||
}
|
||||
}
|
||||
|
||||
// isSecurityBreach 检查是否为安全突破事件
|
||||
func isSecurityBreach(eventName string) bool {
|
||||
prefixes := []string{"SEC-BREACH"}
|
||||
for _, prefix := range prefixes {
|
||||
if len(eventName) >= len(prefix) && eventName[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isSecurityAlert 检查是否为安全告警事件
|
||||
func isSecurityAlert(eventName string) bool {
|
||||
prefixes := []string{"SEC-ALERT"}
|
||||
for _, prefix := range prefixes {
|
||||
if len(eventName) >= len(prefix) && eventName[:len(prefix)] == prefix {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// FormatSECURITYEvent 格式化SECURITY事件
|
||||
func FormatSECURITYEvent(eventName string, params map[string]string) string {
|
||||
desc := GetEventDescription(eventName)
|
||||
if desc == "" {
|
||||
return fmt.Sprintf("SECURITY event: %s", eventName)
|
||||
}
|
||||
|
||||
// 如果有额外参数,追加到描述中
|
||||
if len(params) > 0 {
|
||||
return fmt.Sprintf("%s - %v", desc, params)
|
||||
}
|
||||
return desc
|
||||
}
|
||||
131
supply-api/internal/audit/events/security_events_test.go
Normal file
131
supply-api/internal/audit/events/security_events_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSECURITYEvents_InvariantViolation(t *testing.T) {
|
||||
// 测试 invariant_violation 事件
|
||||
events := GetSECURITYEvents()
|
||||
|
||||
// INV-PKG-001: 供应方资质过期
|
||||
assert.Contains(t, events, "INV-PKG-001", "Should contain INV-PKG-001")
|
||||
|
||||
// INV-SET-001: processing/completed 不可撤销
|
||||
assert.Contains(t, events, "INV-SET-001", "Should contain INV-SET-001")
|
||||
}
|
||||
|
||||
func TestSECURITYEvents_AllEvents(t *testing.T) {
|
||||
// 测试所有SECURITY事件
|
||||
events := GetSECURITYEvents()
|
||||
|
||||
// 验证不变量违反事件
|
||||
invariantEvents := GetInvariantViolationEvents()
|
||||
for _, event := range invariantEvents {
|
||||
assert.Contains(t, events, event, "SECURITY events should contain %s", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSECURITYEvents_GetInvariantViolationEvents(t *testing.T) {
|
||||
events := GetInvariantViolationEvents()
|
||||
|
||||
// INV-PKG-001: 供应方资质过期
|
||||
assert.Contains(t, events, "INV-PKG-001")
|
||||
|
||||
// INV-PKG-002: 供应方余额为负
|
||||
assert.Contains(t, events, "INV-PKG-002")
|
||||
|
||||
// INV-PKG-003: 售价不得低于保护价
|
||||
assert.Contains(t, events, "INV-PKG-003")
|
||||
|
||||
// INV-SET-001: processing/completed 不可撤销
|
||||
assert.Contains(t, events, "INV-SET-001")
|
||||
|
||||
// INV-SET-002: 提现金额不得超过可提现余额
|
||||
assert.Contains(t, events, "INV-SET-002")
|
||||
|
||||
// INV-SET-003: 结算单金额与余额流水必须平衡
|
||||
assert.Contains(t, events, "INV-SET-003")
|
||||
}
|
||||
|
||||
func TestSECURITYEvents_GetSecurityAlertEvents(t *testing.T) {
|
||||
events := GetSecurityAlertEvents()
|
||||
|
||||
// 安全告警事件应该存在
|
||||
assert.NotEmpty(t, events)
|
||||
}
|
||||
|
||||
func TestSECURITYEvents_GetSecurityBreachEvents(t *testing.T) {
|
||||
events := GetSecurityBreachEvents()
|
||||
|
||||
// 安全突破事件应该存在
|
||||
assert.NotEmpty(t, events)
|
||||
}
|
||||
|
||||
func TestSECURITYEvents_GetEventCategory(t *testing.T) {
|
||||
// 所有SECURITY事件的类别应该是SECURITY
|
||||
events := GetSECURITYEvents()
|
||||
for _, eventName := range events {
|
||||
category := GetEventCategory(eventName)
|
||||
assert.Equal(t, "SECURITY", category, "Event %s should have category SECURITY", eventName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSECURITYEvents_GetResultCode(t *testing.T) {
|
||||
// 测试不变量违反事件的结果码映射
|
||||
testCases := []struct {
|
||||
eventName string
|
||||
expectedCode string
|
||||
}{
|
||||
{"INV-PKG-001", "SEC_INV_PKG_001"},
|
||||
{"INV-PKG-002", "SEC_INV_PKG_002"},
|
||||
{"INV-PKG-003", "SEC_INV_PKG_003"},
|
||||
{"INV-SET-001", "SEC_INV_SET_001"},
|
||||
{"INV-SET-002", "SEC_INV_SET_002"},
|
||||
{"INV-SET-003", "SEC_INV_SET_003"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.eventName, func(t *testing.T) {
|
||||
code := GetResultCode(tc.eventName)
|
||||
assert.Equal(t, tc.expectedCode, code, "Result code mismatch for %s", tc.eventName)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSECURITYEvents_GetEventDescription(t *testing.T) {
|
||||
// 测试事件描述
|
||||
desc := GetEventDescription("INV-PKG-001")
|
||||
assert.NotEmpty(t, desc)
|
||||
assert.Contains(t, desc, "供应方资质", "Description should contain 供应方资质")
|
||||
}
|
||||
|
||||
func TestSECURITYEvents_IsValidEvent(t *testing.T) {
|
||||
// 测试有效事件验证
|
||||
assert.True(t, IsValidEvent("INV-PKG-001"))
|
||||
assert.True(t, IsValidEvent("INV-SET-001"))
|
||||
assert.False(t, IsValidEvent("INVALID-EVENT"))
|
||||
assert.False(t, IsValidEvent(""))
|
||||
}
|
||||
|
||||
func TestSECURITYEvents_GetEventSubCategory(t *testing.T) {
|
||||
// SECURITY事件的子类别应该是VIOLATION/ALERT/BREACH
|
||||
testCases := []struct {
|
||||
eventName string
|
||||
expectedSubCategory string
|
||||
}{
|
||||
{"INV-PKG-001", "VIOLATION"},
|
||||
{"INV-SET-001", "VIOLATION"},
|
||||
{"SEC-BREACH-001", "BREACH"},
|
||||
{"SEC-ALERT-001", "ALERT"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.eventName, func(t *testing.T) {
|
||||
subCategory := GetEventSubCategory(tc.eventName)
|
||||
assert.Equal(t, tc.expectedSubCategory, subCategory)
|
||||
})
|
||||
}
|
||||
}
|
||||
357
supply-api/internal/audit/model/audit_event.go
Normal file
357
supply-api/internal/audit/model/audit_event.go
Normal file
@@ -0,0 +1,357 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// 事件类别常量
|
||||
const (
|
||||
CategoryCRED = "CRED"
|
||||
CategoryAUTH = "AUTH"
|
||||
CategoryDATA = "DATA"
|
||||
CategoryCONFIG = "CONFIG"
|
||||
CategorySECURITY = "SECURITY"
|
||||
)
|
||||
|
||||
// 凭证事件子类别
|
||||
const (
|
||||
SubCategoryCredExpose = "EXPOSE"
|
||||
SubCategoryCredIngress = "INGRESS"
|
||||
SubCategoryCredRotate = "ROTATE"
|
||||
SubCategoryCredRevoke = "REVOKE"
|
||||
SubCategoryCredValidate = "VALIDATE"
|
||||
SubCategoryCredDirect = "DIRECT"
|
||||
)
|
||||
|
||||
// 凭证类型
|
||||
const (
|
||||
CredentialTypePlatformToken = "platform_token"
|
||||
CredentialTypeQueryKey = "query_key"
|
||||
CredentialTypeUpstreamAPIKey = "upstream_api_key"
|
||||
CredentialTypeNone = "none"
|
||||
)
|
||||
|
||||
// 操作者类型
|
||||
const (
|
||||
OperatorTypeUser = "user"
|
||||
OperatorTypeSystem = "system"
|
||||
OperatorTypeAdmin = "admin"
|
||||
)
|
||||
|
||||
// 租户类型
|
||||
const (
|
||||
TenantTypeSupplier = "supplier"
|
||||
TenantTypeConsumer = "consumer"
|
||||
TenantTypePlatform = "platform"
|
||||
)
|
||||
|
||||
// SecurityFlags 安全标记
|
||||
type SecurityFlags struct {
|
||||
HasCredential bool `json:"has_credential"` // 是否包含凭证
|
||||
CredentialExposed bool `json:"credential_exposed"` // 凭证是否暴露
|
||||
Desensitized bool `json:"desensitized"` // 是否已脱敏
|
||||
Scanned bool `json:"scanned"` // 是否已扫描
|
||||
ScanPassed bool `json:"scan_passed"` // 扫描是否通过
|
||||
ViolationTypes []string `json:"violation_types"` // 违规类型列表
|
||||
}
|
||||
|
||||
// NewSecurityFlags 创建默认安全标记
|
||||
func NewSecurityFlags() *SecurityFlags {
|
||||
return &SecurityFlags{
|
||||
HasCredential: false,
|
||||
CredentialExposed: false,
|
||||
Desensitized: false,
|
||||
Scanned: false,
|
||||
ScanPassed: false,
|
||||
ViolationTypes: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// HasViolation 检查是否有违规
|
||||
func (sf *SecurityFlags) HasViolation() bool {
|
||||
return len(sf.ViolationTypes) > 0
|
||||
}
|
||||
|
||||
// HasViolationOfType 检查是否有指定类型的违规
|
||||
func (sf *SecurityFlags) HasViolationOfType(violationType string) bool {
|
||||
for _, v := range sf.ViolationTypes {
|
||||
if v == violationType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AddViolationType 添加违规类型
|
||||
func (sf *SecurityFlags) AddViolationType(violationType string) {
|
||||
sf.ViolationTypes = append(sf.ViolationTypes, violationType)
|
||||
}
|
||||
|
||||
// AuditEvent 统一审计事件
|
||||
type AuditEvent struct {
|
||||
// 基础标识
|
||||
EventID string `json:"event_id"` // 事件唯一ID (UUID)
|
||||
EventName string `json:"event_name"` // 事件名称 (e.g., "CRED-EXPOSE")
|
||||
EventCategory string `json:"event_category"` // 事件大类 (e.g., "CRED")
|
||||
EventSubCategory string `json:"event_sub_category"` // 事件子类
|
||||
|
||||
// 时间戳
|
||||
Timestamp time.Time `json:"timestamp"` // 事件发生时间
|
||||
TimestampMs int64 `json:"timestamp_ms"` // 毫秒时间戳
|
||||
|
||||
// 请求上下文
|
||||
RequestID string `json:"request_id"` // 请求追踪ID
|
||||
TraceID string `json:"trace_id"` // 分布式追踪ID
|
||||
SpanID string `json:"span_id"` // Span ID
|
||||
|
||||
// 幂等性
|
||||
IdempotencyKey string `json:"idempotency_key,omitempty"` // 幂等键
|
||||
|
||||
// 操作者信息
|
||||
OperatorID int64 `json:"operator_id"` // 操作者ID
|
||||
OperatorType string `json:"operator_type"` // 操作者类型 (user/system/admin)
|
||||
OperatorRole string `json:"operator_role"` // 操作者角色
|
||||
|
||||
// 租户信息
|
||||
TenantID int64 `json:"tenant_id"` // 租户ID
|
||||
TenantType string `json:"tenant_type"` // 租户类型 (supplier/consumer/platform)
|
||||
|
||||
// 对象信息
|
||||
ObjectType string `json:"object_type"` // 对象类型 (account/package/settlement)
|
||||
ObjectID int64 `json:"object_id"` // 对象ID
|
||||
|
||||
// 操作信息
|
||||
Action string `json:"action"` // 操作类型 (create/update/delete)
|
||||
ActionDetail string `json:"action_detail"` // 操作详情
|
||||
|
||||
// 凭证信息 (M-013/M-014/M-015/M-016 关键)
|
||||
CredentialType string `json:"credential_type"` // 凭证类型 (platform_token/query_key/upstream_api_key/none)
|
||||
CredentialID string `json:"credential_id,omitempty"` // 凭证标识 (脱敏)
|
||||
CredentialFingerprint string `json:"credential_fingerprint,omitempty"` // 凭证指纹
|
||||
|
||||
// 来源信息
|
||||
SourceType string `json:"source_type"` // 来源类型 (api/ui/cron/internal)
|
||||
SourceIP string `json:"source_ip"` // 来源IP
|
||||
SourceRegion string `json:"source_region"` // 来源区域
|
||||
UserAgent string `json:"user_agent,omitempty"` // User Agent
|
||||
|
||||
// 目标信息 (用于直连检测 M-015)
|
||||
TargetType string `json:"target_type,omitempty"` // 目标类型
|
||||
TargetEndpoint string `json:"target_endpoint,omitempty"` // 目标端点
|
||||
TargetDirect bool `json:"target_direct"` // 是否直连
|
||||
|
||||
// 结果信息
|
||||
ResultCode string `json:"result_code"` // 结果码
|
||||
ResultMessage string `json:"result_message,omitempty"` // 结果消息
|
||||
Success bool `json:"success"` // 是否成功
|
||||
|
||||
// 状态变更 (用于溯源)
|
||||
BeforeState map[string]any `json:"before_state,omitempty"` // 操作前状态
|
||||
AfterState map[string]any `json:"after_state,omitempty"` // 操作后状态
|
||||
|
||||
// 安全标记 (M-013 关键)
|
||||
SecurityFlags SecurityFlags `json:"security_flags"` // 安全标记
|
||||
RiskScore int `json:"risk_score"` // 风险评分 0-100
|
||||
|
||||
// 合规信息
|
||||
ComplianceTags []string `json:"compliance_tags,omitempty"` // 合规标签 (e.g., ["GDPR", "SOC2"])
|
||||
InvariantRule string `json:"invariant_rule,omitempty"` // 触发的不变量规则
|
||||
|
||||
// 扩展字段
|
||||
Extensions map[string]any `json:"extensions,omitempty"` // 扩展数据
|
||||
|
||||
// 元数据
|
||||
Version int `json:"version"` // 事件版本
|
||||
CreatedAt time.Time `json:"created_at"` // 创建时间
|
||||
}
|
||||
|
||||
// NewAuditEvent 创建审计事件
|
||||
func NewAuditEvent(
|
||||
eventName string,
|
||||
eventCategory string,
|
||||
eventSubCategory string,
|
||||
metricName string,
|
||||
requestID string,
|
||||
traceID string,
|
||||
operatorID int64,
|
||||
operatorType string,
|
||||
operatorRole string,
|
||||
tenantID int64,
|
||||
tenantType string,
|
||||
objectType string,
|
||||
objectID int64,
|
||||
action string,
|
||||
credentialType string,
|
||||
sourceType string,
|
||||
sourceIP string,
|
||||
success bool,
|
||||
resultCode string,
|
||||
resultMessage string,
|
||||
) *AuditEvent {
|
||||
now := time.Now()
|
||||
event := &AuditEvent{
|
||||
EventID: uuid.New().String(),
|
||||
EventName: eventName,
|
||||
EventCategory: eventCategory,
|
||||
EventSubCategory: eventSubCategory,
|
||||
Timestamp: now,
|
||||
TimestampMs: now.UnixMilli(),
|
||||
RequestID: requestID,
|
||||
TraceID: traceID,
|
||||
OperatorID: operatorID,
|
||||
OperatorType: operatorType,
|
||||
OperatorRole: operatorRole,
|
||||
TenantID: tenantID,
|
||||
TenantType: tenantType,
|
||||
ObjectType: objectType,
|
||||
ObjectID: objectID,
|
||||
Action: action,
|
||||
CredentialType: credentialType,
|
||||
SourceType: sourceType,
|
||||
SourceIP: sourceIP,
|
||||
Success: success,
|
||||
ResultCode: resultCode,
|
||||
ResultMessage: resultMessage,
|
||||
Version: 1,
|
||||
CreatedAt: now,
|
||||
SecurityFlags: *NewSecurityFlags(),
|
||||
ComplianceTags: []string{},
|
||||
}
|
||||
|
||||
// 根据凭证类型设置安全标记
|
||||
if credentialType != CredentialTypeNone && credentialType != "" {
|
||||
event.SecurityFlags.HasCredential = true
|
||||
}
|
||||
|
||||
// 根据事件名称设置凭证暴露标记(M-013)
|
||||
if IsM013Event(eventName) {
|
||||
event.SecurityFlags.CredentialExposed = true
|
||||
}
|
||||
|
||||
// 根据事件名称设置指标名称到扩展字段
|
||||
if metricName != "" {
|
||||
if event.Extensions == nil {
|
||||
event.Extensions = make(map[string]any)
|
||||
}
|
||||
event.Extensions["metric_name"] = metricName
|
||||
}
|
||||
|
||||
return event
|
||||
}
|
||||
|
||||
// NewAuditEventWithSecurityFlags 创建带完整安全标记的审计事件
|
||||
func NewAuditEventWithSecurityFlags(
|
||||
eventName string,
|
||||
eventCategory string,
|
||||
eventSubCategory string,
|
||||
metricName string,
|
||||
requestID string,
|
||||
traceID string,
|
||||
operatorID int64,
|
||||
operatorType string,
|
||||
operatorRole string,
|
||||
tenantID int64,
|
||||
tenantType string,
|
||||
objectType string,
|
||||
objectID int64,
|
||||
action string,
|
||||
credentialType string,
|
||||
sourceType string,
|
||||
sourceIP string,
|
||||
success bool,
|
||||
resultCode string,
|
||||
resultMessage string,
|
||||
securityFlags SecurityFlags,
|
||||
riskScore int,
|
||||
) *AuditEvent {
|
||||
event := NewAuditEvent(
|
||||
eventName,
|
||||
eventCategory,
|
||||
eventSubCategory,
|
||||
metricName,
|
||||
requestID,
|
||||
traceID,
|
||||
operatorID,
|
||||
operatorType,
|
||||
operatorRole,
|
||||
tenantID,
|
||||
tenantType,
|
||||
objectType,
|
||||
objectID,
|
||||
action,
|
||||
credentialType,
|
||||
sourceType,
|
||||
sourceIP,
|
||||
success,
|
||||
resultCode,
|
||||
resultMessage,
|
||||
)
|
||||
event.SecurityFlags = securityFlags
|
||||
event.RiskScore = riskScore
|
||||
return event
|
||||
}
|
||||
|
||||
// SetIdempotencyKey 设置幂等键
|
||||
func (e *AuditEvent) SetIdempotencyKey(key string) {
|
||||
e.IdempotencyKey = key
|
||||
}
|
||||
|
||||
// SetTarget 设置目标信息(用于M-015直连检测)
|
||||
func (e *AuditEvent) SetTarget(targetType, targetEndpoint string, targetDirect bool) {
|
||||
e.TargetType = targetType
|
||||
e.TargetEndpoint = targetEndpoint
|
||||
e.TargetDirect = targetDirect
|
||||
}
|
||||
|
||||
// SetInvariantRule 设置不变量规则(用于SECURITY事件)
|
||||
func (e *AuditEvent) SetInvariantRule(rule string) {
|
||||
e.InvariantRule = rule
|
||||
// 添加合规标签
|
||||
e.ComplianceTags = append(e.ComplianceTags, "XR-001")
|
||||
}
|
||||
|
||||
// GetMetricName 获取指标名称
|
||||
func (e *AuditEvent) GetMetricName() string {
|
||||
if e.Extensions != nil {
|
||||
if metricName, ok := e.Extensions["metric_name"].(string); ok {
|
||||
return metricName
|
||||
}
|
||||
}
|
||||
|
||||
// 根据事件名称推断指标
|
||||
switch e.EventName {
|
||||
case "CRED-EXPOSE-RESPONSE", "CRED-EXPOSE-LOG", "CRED-EXPOSE":
|
||||
return "supplier_credential_exposure_events"
|
||||
case "CRED-INGRESS-PLATFORM", "CRED-INGRESS":
|
||||
return "platform_credential_ingress_coverage_pct"
|
||||
case "CRED-DIRECT-SUPPLIER", "CRED-DIRECT":
|
||||
return "direct_supplier_call_by_consumer_events"
|
||||
case "AUTH-QUERY-KEY", "AUTH-QUERY-REJECT", "AUTH-QUERY":
|
||||
return "query_key_external_reject_rate_pct"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// IsM013Event 判断是否为M-013凭证暴露事件
|
||||
func IsM013Event(eventName string) bool {
|
||||
return strings.HasPrefix(eventName, "CRED-EXPOSE")
|
||||
}
|
||||
|
||||
// IsM014Event 判断是否为M-014凭证入站事件
|
||||
func IsM014Event(eventName string) bool {
|
||||
return strings.HasPrefix(eventName, "CRED-INGRESS")
|
||||
}
|
||||
|
||||
// IsM015Event 判断是否为M-015直连绕过事件
|
||||
func IsM015Event(eventName string) bool {
|
||||
return strings.HasPrefix(eventName, "CRED-DIRECT")
|
||||
}
|
||||
|
||||
// IsM016Event 判断是否为M-016 query key拒绝事件
|
||||
func IsM016Event(eventName string) bool {
|
||||
return strings.HasPrefix(eventName, "AUTH-QUERY")
|
||||
}
|
||||
389
supply-api/internal/audit/model/audit_event_test.go
Normal file
389
supply-api/internal/audit/model/audit_event_test.go
Normal file
@@ -0,0 +1,389 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuditEvent_NewEvent_ValidInput(t *testing.T) {
|
||||
// 测试创建审计事件
|
||||
event := NewAuditEvent(
|
||||
"CRED-EXPOSE-RESPONSE",
|
||||
"CRED",
|
||||
"EXPOSE",
|
||||
"supplier_credential_exposure_events",
|
||||
"test-request-id",
|
||||
"test-trace-id",
|
||||
1001,
|
||||
"user",
|
||||
"admin",
|
||||
2001,
|
||||
"supplier",
|
||||
"account",
|
||||
12345,
|
||||
"create",
|
||||
"platform_token",
|
||||
"api",
|
||||
"192.168.1.1",
|
||||
true,
|
||||
"SEC_CRED_EXPOSED",
|
||||
"Credential exposed in response",
|
||||
)
|
||||
|
||||
// 验证字段
|
||||
assert.NotEmpty(t, event.EventID, "EventID should not be empty")
|
||||
assert.Equal(t, "CRED-EXPOSE-RESPONSE", event.EventName, "EventName should match")
|
||||
assert.Equal(t, "CRED", event.EventCategory, "EventCategory should match")
|
||||
assert.Equal(t, "EXPOSE", event.EventSubCategory, "EventSubCategory should match")
|
||||
assert.Equal(t, "test-request-id", event.RequestID, "RequestID should match")
|
||||
assert.Equal(t, "test-trace-id", event.TraceID, "TraceID should match")
|
||||
assert.Equal(t, int64(1001), event.OperatorID, "OperatorID should match")
|
||||
assert.Equal(t, "user", event.OperatorType, "OperatorType should match")
|
||||
assert.Equal(t, "admin", event.OperatorRole, "OperatorRole should match")
|
||||
assert.Equal(t, int64(2001), event.TenantID, "TenantID should match")
|
||||
assert.Equal(t, "supplier", event.TenantType, "TenantType should match")
|
||||
assert.Equal(t, "account", event.ObjectType, "ObjectType should match")
|
||||
assert.Equal(t, int64(12345), event.ObjectID, "ObjectID should match")
|
||||
assert.Equal(t, "create", event.Action, "Action should match")
|
||||
assert.Equal(t, "platform_token", event.CredentialType, "CredentialType should match")
|
||||
assert.Equal(t, "api", event.SourceType, "SourceType should match")
|
||||
assert.Equal(t, "192.168.1.1", event.SourceIP, "SourceIP should match")
|
||||
assert.True(t, event.Success, "Success should be true")
|
||||
assert.Equal(t, "SEC_CRED_EXPOSED", event.ResultCode, "ResultCode should match")
|
||||
assert.Equal(t, "Credential exposed in response", event.ResultMessage, "ResultMessage should match")
|
||||
|
||||
// 验证时间戳
|
||||
assert.False(t, event.Timestamp.IsZero(), "Timestamp should not be zero")
|
||||
assert.True(t, event.TimestampMs > 0, "TimestampMs should be positive")
|
||||
assert.False(t, event.CreatedAt.IsZero(), "CreatedAt should not be zero")
|
||||
|
||||
// 验证版本
|
||||
assert.Equal(t, 1, event.Version, "Version should be 1")
|
||||
}
|
||||
|
||||
func TestAuditEvent_NewEvent_SecurityFlags(t *testing.T) {
|
||||
// 验证SecurityFlags字段
|
||||
event := NewAuditEvent(
|
||||
"CRED-EXPOSE-RESPONSE",
|
||||
"CRED",
|
||||
"EXPOSE",
|
||||
"supplier_credential_exposure_events",
|
||||
"test-request-id",
|
||||
"test-trace-id",
|
||||
1001,
|
||||
"user",
|
||||
"admin",
|
||||
2001,
|
||||
"supplier",
|
||||
"account",
|
||||
12345,
|
||||
"create",
|
||||
"platform_token",
|
||||
"api",
|
||||
"192.168.1.1",
|
||||
true,
|
||||
"SEC_CRED_EXPOSED",
|
||||
"Credential exposed in response",
|
||||
)
|
||||
|
||||
// 验证安全标记
|
||||
assert.NotNil(t, event.SecurityFlags, "SecurityFlags should not be nil")
|
||||
assert.True(t, event.SecurityFlags.HasCredential, "HasCredential should be true")
|
||||
assert.True(t, event.SecurityFlags.CredentialExposed, "CredentialExposed should be true")
|
||||
assert.False(t, event.SecurityFlags.Desensitized, "Desensitized should be false by default")
|
||||
assert.False(t, event.SecurityFlags.Scanned, "Scanned should be false by default")
|
||||
assert.False(t, event.SecurityFlags.ScanPassed, "ScanPassed should be false by default")
|
||||
assert.Empty(t, event.SecurityFlags.ViolationTypes, "ViolationTypes should be empty by default")
|
||||
}
|
||||
|
||||
func TestAuditEvent_NewEvent_WithSecurityFlags(t *testing.T) {
|
||||
// 测试带有完整安全标记的事件
|
||||
securityFlags := SecurityFlags{
|
||||
HasCredential: true,
|
||||
CredentialExposed: true,
|
||||
Desensitized: false,
|
||||
Scanned: true,
|
||||
ScanPassed: false,
|
||||
ViolationTypes: []string{"api_key", "secret"},
|
||||
}
|
||||
|
||||
event := NewAuditEventWithSecurityFlags(
|
||||
"CRED-EXPOSE-RESPONSE",
|
||||
"CRED",
|
||||
"EXPOSE",
|
||||
"supplier_credential_exposure_events",
|
||||
"test-request-id",
|
||||
"test-trace-id",
|
||||
1001,
|
||||
"user",
|
||||
"admin",
|
||||
2001,
|
||||
"supplier",
|
||||
"account",
|
||||
12345,
|
||||
"create",
|
||||
"platform_token",
|
||||
"api",
|
||||
"192.168.1.1",
|
||||
true,
|
||||
"SEC_CRED_EXPOSED",
|
||||
"Credential exposed in response",
|
||||
securityFlags,
|
||||
80,
|
||||
)
|
||||
|
||||
// 验证安全标记
|
||||
assert.Equal(t, true, event.SecurityFlags.HasCredential)
|
||||
assert.Equal(t, true, event.SecurityFlags.CredentialExposed)
|
||||
assert.Equal(t, false, event.SecurityFlags.Desensitized)
|
||||
assert.Equal(t, true, event.SecurityFlags.Scanned)
|
||||
assert.Equal(t, false, event.SecurityFlags.ScanPassed)
|
||||
assert.Equal(t, []string{"api_key", "secret"}, event.SecurityFlags.ViolationTypes)
|
||||
|
||||
// 验证风险评分
|
||||
assert.Equal(t, 80, event.RiskScore, "RiskScore should be 80")
|
||||
}
|
||||
|
||||
func TestAuditEvent_NewAuditEventWithIdempotencyKey(t *testing.T) {
|
||||
// 测试带幂等键的事件
|
||||
event := NewAuditEvent(
|
||||
"AUTH-QUERY-KEY",
|
||||
"AUTH",
|
||||
"QUERY",
|
||||
"query_key_external_reject_rate_pct",
|
||||
"test-request-id",
|
||||
"test-trace-id",
|
||||
1001,
|
||||
"user",
|
||||
"admin",
|
||||
2001,
|
||||
"supplier",
|
||||
"account",
|
||||
12345,
|
||||
"query",
|
||||
"query_key",
|
||||
"api",
|
||||
"192.168.1.1",
|
||||
true,
|
||||
"AUTH_QUERY_KEY",
|
||||
"Query key request",
|
||||
)
|
||||
|
||||
// 设置幂等键
|
||||
event.SetIdempotencyKey("idem-key-12345")
|
||||
|
||||
assert.Equal(t, "idem-key-12345", event.IdempotencyKey, "IdempotencyKey should be set")
|
||||
}
|
||||
|
||||
func TestAuditEvent_NewAuditEventWithTarget(t *testing.T) {
|
||||
// 测试带目标信息的事件(用于M-015直连检测)
|
||||
event := NewAuditEvent(
|
||||
"CRED-DIRECT-SUPPLIER",
|
||||
"CRED",
|
||||
"DIRECT",
|
||||
"direct_supplier_call_by_consumer_events",
|
||||
"test-request-id",
|
||||
"test-trace-id",
|
||||
1001,
|
||||
"user",
|
||||
"admin",
|
||||
2001,
|
||||
"supplier",
|
||||
"api",
|
||||
12345,
|
||||
"call",
|
||||
"none",
|
||||
"api",
|
||||
"192.168.1.1",
|
||||
false,
|
||||
"SEC_DIRECT_BYPASS",
|
||||
"Direct call detected",
|
||||
)
|
||||
|
||||
// 设置直连目标
|
||||
event.SetTarget("upstream_api", "https://supplier.example.com/v1/chat/completions", true)
|
||||
|
||||
assert.Equal(t, "upstream_api", event.TargetType, "TargetType should be set")
|
||||
assert.Equal(t, "https://supplier.example.com/v1/chat/completions", event.TargetEndpoint, "TargetEndpoint should be set")
|
||||
assert.True(t, event.TargetDirect, "TargetDirect should be true")
|
||||
}
|
||||
|
||||
func TestAuditEvent_NewAuditEventWithInvariantRule(t *testing.T) {
|
||||
// 测试不变量规则(用于SECURITY事件)
|
||||
event := NewAuditEvent(
|
||||
"INVARIANT-VIOLATION",
|
||||
"SECURITY",
|
||||
"VIOLATION",
|
||||
"invariant_violation",
|
||||
"test-request-id",
|
||||
"test-trace-id",
|
||||
1001,
|
||||
"system",
|
||||
"admin",
|
||||
2001,
|
||||
"supplier",
|
||||
"settlement",
|
||||
12345,
|
||||
"withdraw",
|
||||
"platform_token",
|
||||
"api",
|
||||
"192.168.1.1",
|
||||
false,
|
||||
"SEC_INV_SET_001",
|
||||
"Settlement cannot be revoked",
|
||||
)
|
||||
|
||||
// 设置不变量规则
|
||||
event.SetInvariantRule("INV-SET-001")
|
||||
|
||||
assert.Equal(t, "INV-SET-001", event.InvariantRule, "InvariantRule should be set")
|
||||
assert.Contains(t, event.ComplianceTags, "XR-001", "ComplianceTags should contain XR-001")
|
||||
}
|
||||
|
||||
func TestSecurityFlags_HasViolation(t *testing.T) {
|
||||
// 测试安全标记的违规检测
|
||||
sf := NewSecurityFlags()
|
||||
|
||||
// 初始状态无违规
|
||||
assert.False(t, sf.HasViolation(), "Should have no violation initially")
|
||||
|
||||
// 添加违规类型
|
||||
sf.AddViolationType("api_key")
|
||||
assert.True(t, sf.HasViolation(), "Should have violation after adding type")
|
||||
assert.True(t, sf.HasViolationOfType("api_key"), "Should have api_key violation")
|
||||
assert.False(t, sf.HasViolationOfType("password"), "Should not have password violation")
|
||||
}
|
||||
|
||||
func TestSecurityFlags_AddViolationType(t *testing.T) {
|
||||
sf := NewSecurityFlags()
|
||||
|
||||
sf.AddViolationType("api_key")
|
||||
sf.AddViolationType("secret")
|
||||
sf.AddViolationType("password")
|
||||
|
||||
assert.Len(t, sf.ViolationTypes, 3, "Should have 3 violation types")
|
||||
assert.Contains(t, sf.ViolationTypes, "api_key")
|
||||
assert.Contains(t, sf.ViolationTypes, "secret")
|
||||
assert.Contains(t, sf.ViolationTypes, "password")
|
||||
}
|
||||
|
||||
func TestAuditEvent_MetricName(t *testing.T) {
|
||||
// 测试事件与指标的映射
|
||||
testCases := []struct {
|
||||
eventName string
|
||||
expectedMetric string
|
||||
}{
|
||||
{"CRED-EXPOSE-RESPONSE", "supplier_credential_exposure_events"},
|
||||
{"CRED-EXPOSE-LOG", "supplier_credential_exposure_events"},
|
||||
{"CRED-INGRESS-PLATFORM", "platform_credential_ingress_coverage_pct"},
|
||||
{"CRED-DIRECT-SUPPLIER", "direct_supplier_call_by_consumer_events"},
|
||||
{"AUTH-QUERY-KEY", "query_key_external_reject_rate_pct"},
|
||||
{"AUTH-QUERY-REJECT", "query_key_external_reject_rate_pct"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.eventName, func(t *testing.T) {
|
||||
event := &AuditEvent{
|
||||
EventName: tc.eventName,
|
||||
}
|
||||
assert.Equal(t, tc.expectedMetric, event.GetMetricName(), "MetricName should match for %s", tc.eventName)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditEvent_IsM013Event(t *testing.T) {
|
||||
// M-013: 凭证暴露事件
|
||||
assert.True(t, IsM013Event("CRED-EXPOSE-RESPONSE"), "CRED-EXPOSE-RESPONSE is M-013 event")
|
||||
assert.True(t, IsM013Event("CRED-EXPOSE-LOG"), "CRED-EXPOSE-LOG is M-013 event")
|
||||
assert.True(t, IsM013Event("CRED-EXPOSE"), "CRED-EXPOSE is M-013 event")
|
||||
assert.False(t, IsM013Event("CRED-INGRESS-PLATFORM"), "CRED-INGRESS-PLATFORM is not M-013 event")
|
||||
assert.False(t, IsM013Event("AUTH-QUERY-KEY"), "AUTH-QUERY-KEY is not M-013 event")
|
||||
}
|
||||
|
||||
func TestAuditEvent_IsM014Event(t *testing.T) {
|
||||
// M-014: 凭证入站事件
|
||||
assert.True(t, IsM014Event("CRED-INGRESS-PLATFORM"), "CRED-INGRESS-PLATFORM is M-014 event")
|
||||
assert.True(t, IsM014Event("CRED-INGRESS"), "CRED-INGRESS is M-014 event")
|
||||
assert.False(t, IsM014Event("CRED-EXPOSE-RESPONSE"), "CRED-EXPOSE-RESPONSE is not M-014 event")
|
||||
}
|
||||
|
||||
func TestAuditEvent_IsM015Event(t *testing.T) {
|
||||
// M-015: 直连绕过事件
|
||||
assert.True(t, IsM015Event("CRED-DIRECT-SUPPLIER"), "CRED-DIRECT-SUPPLIER is M-015 event")
|
||||
assert.True(t, IsM015Event("CRED-DIRECT"), "CRED-DIRECT is M-015 event")
|
||||
assert.False(t, IsM015Event("CRED-INGRESS-PLATFORM"), "CRED-INGRESS-PLATFORM is not M-015 event")
|
||||
}
|
||||
|
||||
func TestAuditEvent_IsM016Event(t *testing.T) {
|
||||
// M-016: query key拒绝事件
|
||||
assert.True(t, IsM016Event("AUTH-QUERY-KEY"), "AUTH-QUERY-KEY is M-016 event")
|
||||
assert.True(t, IsM016Event("AUTH-QUERY-REJECT"), "AUTH-QUERY-REJECT is M-016 event")
|
||||
assert.True(t, IsM016Event("AUTH-QUERY"), "AUTH-QUERY is M-016 event")
|
||||
assert.False(t, IsM016Event("CRED-EXPOSE-RESPONSE"), "CRED-EXPOSE-RESPONSE is not M-016 event")
|
||||
}
|
||||
|
||||
func TestAuditEvent_CredentialType(t *testing.T) {
|
||||
// 测试凭证类型常量
|
||||
assert.Equal(t, "platform_token", CredentialTypePlatformToken)
|
||||
assert.Equal(t, "query_key", CredentialTypeQueryKey)
|
||||
assert.Equal(t, "upstream_api_key", CredentialTypeUpstreamAPIKey)
|
||||
assert.Equal(t, "none", CredentialTypeNone)
|
||||
}
|
||||
|
||||
func TestAuditEvent_OperatorType(t *testing.T) {
|
||||
// 测试操作者类型常量
|
||||
assert.Equal(t, "user", OperatorTypeUser)
|
||||
assert.Equal(t, "system", OperatorTypeSystem)
|
||||
assert.Equal(t, "admin", OperatorTypeAdmin)
|
||||
}
|
||||
|
||||
func TestAuditEvent_TenantType(t *testing.T) {
|
||||
// 测试租户类型常量
|
||||
assert.Equal(t, "supplier", TenantTypeSupplier)
|
||||
assert.Equal(t, "consumer", TenantTypeConsumer)
|
||||
assert.Equal(t, "platform", TenantTypePlatform)
|
||||
}
|
||||
|
||||
func TestAuditEvent_Category(t *testing.T) {
|
||||
// 测试事件类别常量
|
||||
assert.Equal(t, "CRED", CategoryCRED)
|
||||
assert.Equal(t, "AUTH", CategoryAUTH)
|
||||
assert.Equal(t, "DATA", CategoryDATA)
|
||||
assert.Equal(t, "CONFIG", CategoryCONFIG)
|
||||
assert.Equal(t, "SECURITY", CategorySECURITY)
|
||||
}
|
||||
|
||||
func TestAuditEvent_NewAuditEventTimestamp(t *testing.T) {
|
||||
// 测试时间戳自动生成
|
||||
before := time.Now()
|
||||
event := NewAuditEvent(
|
||||
"CRED-EXPOSE-RESPONSE",
|
||||
"CRED",
|
||||
"EXPOSE",
|
||||
"supplier_credential_exposure_events",
|
||||
"test-request-id",
|
||||
"test-trace-id",
|
||||
1001,
|
||||
"user",
|
||||
"admin",
|
||||
2001,
|
||||
"supplier",
|
||||
"account",
|
||||
12345,
|
||||
"create",
|
||||
"platform_token",
|
||||
"api",
|
||||
"192.168.1.1",
|
||||
true,
|
||||
"SEC_CRED_EXPOSED",
|
||||
"Credential exposed in response",
|
||||
)
|
||||
after := time.Now()
|
||||
|
||||
// 验证时间戳在合理范围内
|
||||
assert.True(t, event.Timestamp.After(before) || event.Timestamp.Equal(before), "Timestamp should be after or equal to before")
|
||||
assert.True(t, event.Timestamp.Before(after) || event.Timestamp.Equal(after), "Timestamp should be before or equal to after")
|
||||
assert.Equal(t, event.Timestamp.UnixMilli(), event.TimestampMs, "TimestampMs should match Timestamp")
|
||||
}
|
||||
220
supply-api/internal/audit/model/audit_metrics.go
Normal file
220
supply-api/internal/audit/model/audit_metrics.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// ==================== M-013: 凭证暴露事件详情 ====================
|
||||
|
||||
// CredentialExposureDetail M-013: 凭证暴露事件专用
|
||||
type CredentialExposureDetail struct {
|
||||
EventID string `json:"event_id"` // 事件ID(关联audit_events)
|
||||
ExposureType string `json:"exposure_type"` // exposed_in_response/exposed_in_log/exposed_in_export
|
||||
ExposureLocation string `json:"exposure_location"` // response_body/response_header/log_file/export_file
|
||||
ExposurePattern string `json:"exposure_pattern"` // 匹配到的正则模式
|
||||
ExposedFragment string `json:"exposed_fragment"` // 暴露的片段(已脱敏)
|
||||
ScanRuleID string `json:"scan_rule_id"` // 触发扫描规则ID
|
||||
Resolved bool `json:"resolved"` // 是否已解决
|
||||
ResolvedAt *time.Time `json:"resolved_at"` // 解决时间
|
||||
ResolvedBy *int64 `json:"resolved_by"` // 解决人
|
||||
ResolutionNotes string `json:"resolution_notes"` // 解决备注
|
||||
}
|
||||
|
||||
// NewCredentialExposureDetail 创建凭证暴露详情
|
||||
func NewCredentialExposureDetail(
|
||||
exposureType string,
|
||||
exposureLocation string,
|
||||
exposurePattern string,
|
||||
exposedFragment string,
|
||||
scanRuleID string,
|
||||
) *CredentialExposureDetail {
|
||||
return &CredentialExposureDetail{
|
||||
ExposureType: exposureType,
|
||||
ExposureLocation: exposureLocation,
|
||||
ExposurePattern: exposurePattern,
|
||||
ExposedFragment: exposedFragment,
|
||||
ScanRuleID: scanRuleID,
|
||||
Resolved: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve 标记为已解决
|
||||
func (d *CredentialExposureDetail) Resolve(resolvedBy int64, notes string) {
|
||||
now := time.Now()
|
||||
d.Resolved = true
|
||||
d.ResolvedAt = &now
|
||||
d.ResolvedBy = &resolvedBy
|
||||
d.ResolutionNotes = notes
|
||||
}
|
||||
|
||||
// ==================== M-014: 凭证入站事件详情 ====================
|
||||
|
||||
// CredentialIngressDetail M-014: 凭证入站类型专用
|
||||
type CredentialIngressDetail struct {
|
||||
EventID string `json:"event_id"` // 事件ID
|
||||
RequestCredentialType string `json:"request_credential_type"` // 请求中的凭证类型
|
||||
ExpectedCredentialType string `json:"expected_credential_type"` // 期望的凭证类型
|
||||
CoverageCompliant bool `json:"coverage_compliant"` // 是否合规
|
||||
PlatformTokenPresent bool `json:"platform_token_present"` // 平台Token是否存在
|
||||
UpstreamKeyPresent bool `json:"upstream_key_present"` // 上游Key是否存在
|
||||
Reviewed bool `json:"reviewed"` // 是否已审核
|
||||
ReviewedAt *time.Time `json:"reviewed_at"` // 审核时间
|
||||
ReviewedBy *int64 `json:"reviewed_by"` // 审核人
|
||||
}
|
||||
|
||||
// NewCredentialIngressDetail 创建凭证入站详情
|
||||
func NewCredentialIngressDetail(
|
||||
requestCredentialType string,
|
||||
expectedCredentialType string,
|
||||
coverageCompliant bool,
|
||||
platformTokenPresent bool,
|
||||
upstreamKeyPresent bool,
|
||||
) *CredentialIngressDetail {
|
||||
return &CredentialIngressDetail{
|
||||
RequestCredentialType: requestCredentialType,
|
||||
ExpectedCredentialType: expectedCredentialType,
|
||||
CoverageCompliant: coverageCompliant,
|
||||
PlatformTokenPresent: platformTokenPresent,
|
||||
UpstreamKeyPresent: upstreamKeyPresent,
|
||||
Reviewed: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Review 标记为已审核
|
||||
func (d *CredentialIngressDetail) Review(reviewedBy int64) {
|
||||
now := time.Now()
|
||||
d.Reviewed = true
|
||||
d.ReviewedAt = &now
|
||||
d.ReviewedBy = &reviewedBy
|
||||
}
|
||||
|
||||
// ==================== M-015: 直连绕过事件详情 ====================
|
||||
|
||||
// DirectCallDetail M-015: 直连绕过专用
|
||||
type DirectCallDetail struct {
|
||||
EventID string `json:"event_id"` // 事件ID
|
||||
ConsumerID int64 `json:"consumer_id"` // 消费者ID
|
||||
SupplierID int64 `json:"supplier_id"` // 供应商ID
|
||||
DirectEndpoint string `json:"direct_endpoint"` // 直连端点
|
||||
ViaPlatform bool `json:"via_platform"` // 是否通过平台
|
||||
BypassType string `json:"bypass_type"` // ip_bypass/proxy_bypass/config_bypass/dns_bypass
|
||||
DetectionMethod string `json:"detection_method"` // 检测方法
|
||||
Blocked bool `json:"blocked"` // 是否被阻断
|
||||
BlockedAt *time.Time `json:"blocked_at"` // 阻断时间
|
||||
BlockReason string `json:"block_reason"` // 阻断原因
|
||||
}
|
||||
|
||||
// NewDirectCallDetail 创建直连详情
|
||||
func NewDirectCallDetail(
|
||||
consumerID int64,
|
||||
supplierID int64,
|
||||
directEndpoint string,
|
||||
viaPlatform bool,
|
||||
bypassType string,
|
||||
detectionMethod string,
|
||||
) *DirectCallDetail {
|
||||
return &DirectCallDetail{
|
||||
ConsumerID: consumerID,
|
||||
SupplierID: supplierID,
|
||||
DirectEndpoint: directEndpoint,
|
||||
ViaPlatform: viaPlatform,
|
||||
BypassType: bypassType,
|
||||
DetectionMethod: detectionMethod,
|
||||
Blocked: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Block 标记为已阻断
|
||||
func (d *DirectCallDetail) Block(reason string) {
|
||||
now := time.Now()
|
||||
d.Blocked = true
|
||||
d.BlockedAt = &now
|
||||
d.BlockReason = reason
|
||||
}
|
||||
|
||||
// ==================== M-016: Query Key 拒绝事件详情 ====================
|
||||
|
||||
// QueryKeyRejectDetail M-016: query key 拒绝专用
|
||||
type QueryKeyRejectDetail struct {
|
||||
EventID string `json:"event_id"` // 事件ID
|
||||
QueryKeyID string `json:"query_key_id"` // Query Key ID
|
||||
RequestedEndpoint string `json:"requested_endpoint"` // 请求端点
|
||||
RejectReason string `json:"reject_reason"` // not_allowed/expired/malformed/revoked/rate_limited
|
||||
RejectCode string `json:"reject_code"` // 拒绝码
|
||||
FirstOccurrence bool `json:"first_occurrence"` // 是否首次发生
|
||||
OccurrenceCount int `json:"occurrence_count"` // 发生次数
|
||||
}
|
||||
|
||||
// NewQueryKeyRejectDetail 创建Query Key拒绝详情
|
||||
func NewQueryKeyRejectDetail(
|
||||
queryKeyID string,
|
||||
requestedEndpoint string,
|
||||
rejectReason string,
|
||||
rejectCode string,
|
||||
) *QueryKeyRejectDetail {
|
||||
return &QueryKeyRejectDetail{
|
||||
QueryKeyID: queryKeyID,
|
||||
RequestedEndpoint: requestedEndpoint,
|
||||
RejectReason: rejectReason,
|
||||
RejectCode: rejectCode,
|
||||
FirstOccurrence: true,
|
||||
OccurrenceCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// RecordOccurrence 记录再次发生
|
||||
func (d *QueryKeyRejectDetail) RecordOccurrence(firstOccurrence bool) {
|
||||
d.FirstOccurrence = firstOccurrence
|
||||
d.OccurrenceCount++
|
||||
}
|
||||
|
||||
// ==================== 指标常量 ====================
|
||||
|
||||
// M-013 暴露类型常量
|
||||
const (
|
||||
ExposureTypeResponse = "exposed_in_response"
|
||||
ExposureTypeLog = "exposed_in_log"
|
||||
ExposureTypeExport = "exposed_in_export"
|
||||
)
|
||||
|
||||
// M-013 暴露位置常量
|
||||
const (
|
||||
ExposureLocationResponseBody = "response_body"
|
||||
ExposureLocationResponseHeader = "response_header"
|
||||
ExposureLocationLogFile = "log_file"
|
||||
ExposureLocationExportFile = "export_file"
|
||||
)
|
||||
|
||||
// M-015 绕过类型常量
|
||||
const (
|
||||
BypassTypeIPBypass = "ip_bypass"
|
||||
BypassTypeProxyBypass = "proxy_bypass"
|
||||
BypassTypeConfigBypass = "config_bypass"
|
||||
BypassTypeDNSBypass = "dns_bypass"
|
||||
)
|
||||
|
||||
// M-015 检测方法常量
|
||||
const (
|
||||
DetectionMethodUpstreamAPIPattern = "upstream_api_pattern_match"
|
||||
DetectionMethodDNSResolution = "dns_resolution_check"
|
||||
DetectionMethodConnectionSource = "connection_source_check"
|
||||
DetectionMethodIPWhitelist = "ip_whitelist_check"
|
||||
)
|
||||
|
||||
// M-016 拒绝原因常量
|
||||
const (
|
||||
RejectReasonNotAllowed = "not_allowed"
|
||||
RejectReasonExpired = "expired"
|
||||
RejectReasonMalformed = "malformed"
|
||||
RejectReasonRevoked = "revoked"
|
||||
RejectReasonRateLimited = "rate_limited"
|
||||
)
|
||||
|
||||
// M-016 拒绝码常量
|
||||
const (
|
||||
RejectCodeNotAllowed = "QUERY_KEY_NOT_ALLOWED"
|
||||
RejectCodeExpired = "QUERY_KEY_EXPIRED"
|
||||
RejectCodeMalformed = "QUERY_KEY_MALFORMED"
|
||||
RejectCodeRevoked = "QUERY_KEY_REVOKED"
|
||||
RejectCodeRateLimited = "QUERY_KEY_RATE_LIMITED"
|
||||
)
|
||||
459
supply-api/internal/audit/model/audit_metrics_test.go
Normal file
459
supply-api/internal/audit/model/audit_metrics_test.go
Normal file
@@ -0,0 +1,459 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// ==================== M-013 凭证暴露事件详情 ====================
|
||||
|
||||
func TestCredentialExposureDetail_New(t *testing.T) {
|
||||
// M-013: 凭证暴露事件专用
|
||||
detail := NewCredentialExposureDetail(
|
||||
"exposed_in_response",
|
||||
"response_body",
|
||||
"sk-[a-zA-Z0-9]{20,}",
|
||||
"sk-xxxxxx****xxxx",
|
||||
"SCAN-001",
|
||||
)
|
||||
|
||||
assert.Equal(t, "exposed_in_response", detail.ExposureType)
|
||||
assert.Equal(t, "response_body", detail.ExposureLocation)
|
||||
assert.Equal(t, "sk-[a-zA-Z0-9]{20,}", detail.ExposurePattern)
|
||||
assert.Equal(t, "sk-xxxxxx****xxxx", detail.ExposedFragment)
|
||||
assert.Equal(t, "SCAN-001", detail.ScanRuleID)
|
||||
assert.False(t, detail.Resolved)
|
||||
assert.Nil(t, detail.ResolvedAt)
|
||||
assert.Nil(t, detail.ResolvedBy)
|
||||
assert.Empty(t, detail.ResolutionNotes)
|
||||
}
|
||||
|
||||
func TestCredentialExposureDetail_Resolve(t *testing.T) {
|
||||
detail := NewCredentialExposureDetail(
|
||||
"exposed_in_response",
|
||||
"response_body",
|
||||
"sk-[a-zA-Z0-9]{20,}",
|
||||
"sk-xxxxxx****xxxx",
|
||||
"SCAN-001",
|
||||
)
|
||||
|
||||
detail.Resolve(1001, "Fixed by adding masking")
|
||||
|
||||
assert.True(t, detail.Resolved)
|
||||
assert.NotNil(t, detail.ResolvedAt)
|
||||
assert.Equal(t, int64(1001), *detail.ResolvedBy)
|
||||
assert.Equal(t, "Fixed by adding masking", detail.ResolutionNotes)
|
||||
}
|
||||
|
||||
func TestCredentialExposureDetail_ExposureTypes(t *testing.T) {
|
||||
// 验证暴露类型常量
|
||||
validTypes := []string{
|
||||
"exposed_in_response",
|
||||
"exposed_in_log",
|
||||
"exposed_in_export",
|
||||
}
|
||||
|
||||
for _, exposureType := range validTypes {
|
||||
detail := NewCredentialExposureDetail(
|
||||
exposureType,
|
||||
"response_body",
|
||||
"pattern",
|
||||
"fragment",
|
||||
"SCAN-001",
|
||||
)
|
||||
assert.Equal(t, exposureType, detail.ExposureType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCredentialExposureDetail_ExposureLocations(t *testing.T) {
|
||||
// 验证暴露位置常量
|
||||
validLocations := []string{
|
||||
"response_body",
|
||||
"response_header",
|
||||
"log_file",
|
||||
"export_file",
|
||||
}
|
||||
|
||||
for _, location := range validLocations {
|
||||
detail := NewCredentialExposureDetail(
|
||||
"exposed_in_response",
|
||||
location,
|
||||
"pattern",
|
||||
"fragment",
|
||||
"SCAN-001",
|
||||
)
|
||||
assert.Equal(t, location, detail.ExposureLocation)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== M-014 凭证入站事件详情 ====================
|
||||
|
||||
func TestCredentialIngressDetail_New(t *testing.T) {
|
||||
// M-014: 凭证入站类型专用
|
||||
detail := NewCredentialIngressDetail(
|
||||
"platform_token",
|
||||
"platform_token",
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
|
||||
assert.Equal(t, "platform_token", detail.RequestCredentialType)
|
||||
assert.Equal(t, "platform_token", detail.ExpectedCredentialType)
|
||||
assert.True(t, detail.CoverageCompliant)
|
||||
assert.True(t, detail.PlatformTokenPresent)
|
||||
assert.False(t, detail.UpstreamKeyPresent)
|
||||
assert.False(t, detail.Reviewed)
|
||||
assert.Nil(t, detail.ReviewedAt)
|
||||
assert.Nil(t, detail.ReviewedBy)
|
||||
}
|
||||
|
||||
func TestCredentialIngressDetail_NonCompliant(t *testing.T) {
|
||||
// M-014 非合规场景:使用 query_key 而不是 platform_token
|
||||
detail := NewCredentialIngressDetail(
|
||||
"query_key",
|
||||
"platform_token",
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
)
|
||||
|
||||
assert.Equal(t, "query_key", detail.RequestCredentialType)
|
||||
assert.Equal(t, "platform_token", detail.ExpectedCredentialType)
|
||||
assert.False(t, detail.CoverageCompliant)
|
||||
assert.False(t, detail.PlatformTokenPresent)
|
||||
assert.True(t, detail.UpstreamKeyPresent)
|
||||
}
|
||||
|
||||
func TestCredentialIngressDetail_Review(t *testing.T) {
|
||||
detail := NewCredentialIngressDetail(
|
||||
"platform_token",
|
||||
"platform_token",
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
|
||||
detail.Review(1001)
|
||||
|
||||
assert.True(t, detail.Reviewed)
|
||||
assert.NotNil(t, detail.ReviewedAt)
|
||||
assert.Equal(t, int64(1001), *detail.ReviewedBy)
|
||||
}
|
||||
|
||||
func TestCredentialIngressDetail_CredentialTypes(t *testing.T) {
|
||||
// 验证凭证类型
|
||||
testCases := []struct {
|
||||
credType string
|
||||
platformToken bool
|
||||
upstreamKey bool
|
||||
compliant bool
|
||||
}{
|
||||
{"platform_token", true, false, true},
|
||||
{"query_key", false, false, false},
|
||||
{"upstream_api_key", false, true, false},
|
||||
{"none", false, false, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
detail := NewCredentialIngressDetail(
|
||||
tc.credType,
|
||||
"platform_token",
|
||||
tc.compliant,
|
||||
tc.platformToken,
|
||||
tc.upstreamKey,
|
||||
)
|
||||
assert.Equal(t, tc.compliant, detail.CoverageCompliant, "Compliance mismatch for %s", tc.credType)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== M-015 直连绕过事件详情 ====================
|
||||
|
||||
func TestDirectCallDetail_New(t *testing.T) {
|
||||
// M-015: 直连绕过专用
|
||||
detail := NewDirectCallDetail(
|
||||
1001, // consumerID
|
||||
2001, // supplierID
|
||||
"https://supplier.example.com/v1/chat/completions",
|
||||
false, // viaPlatform
|
||||
"ip_bypass",
|
||||
"upstream_api_pattern_match",
|
||||
)
|
||||
|
||||
assert.Equal(t, int64(1001), detail.ConsumerID)
|
||||
assert.Equal(t, int64(2001), detail.SupplierID)
|
||||
assert.Equal(t, "https://supplier.example.com/v1/chat/completions", detail.DirectEndpoint)
|
||||
assert.False(t, detail.ViaPlatform)
|
||||
assert.Equal(t, "ip_bypass", detail.BypassType)
|
||||
assert.Equal(t, "upstream_api_pattern_match", detail.DetectionMethod)
|
||||
assert.False(t, detail.Blocked)
|
||||
assert.Nil(t, detail.BlockedAt)
|
||||
assert.Empty(t, detail.BlockReason)
|
||||
}
|
||||
|
||||
func TestDirectCallDetail_Block(t *testing.T) {
|
||||
detail := NewDirectCallDetail(
|
||||
1001,
|
||||
2001,
|
||||
"https://supplier.example.com/v1/chat/completions",
|
||||
false,
|
||||
"ip_bypass",
|
||||
"upstream_api_pattern_match",
|
||||
)
|
||||
|
||||
detail.Block("P0 event - immediate block")
|
||||
|
||||
assert.True(t, detail.Blocked)
|
||||
assert.NotNil(t, detail.BlockedAt)
|
||||
assert.Equal(t, "P0 event - immediate block", detail.BlockReason)
|
||||
}
|
||||
|
||||
func TestDirectCallDetail_BypassTypes(t *testing.T) {
|
||||
// 验证绕过类型常量
|
||||
validBypassTypes := []string{
|
||||
"ip_bypass",
|
||||
"proxy_bypass",
|
||||
"config_bypass",
|
||||
"dns_bypass",
|
||||
}
|
||||
|
||||
for _, bypassType := range validBypassTypes {
|
||||
detail := NewDirectCallDetail(
|
||||
1001,
|
||||
2001,
|
||||
"https://example.com",
|
||||
false,
|
||||
bypassType,
|
||||
"detection_method",
|
||||
)
|
||||
assert.Equal(t, bypassType, detail.BypassType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectCallDetail_DetectionMethods(t *testing.T) {
|
||||
// 验证检测方法常量
|
||||
validMethods := []string{
|
||||
"upstream_api_pattern_match",
|
||||
"dns_resolution_check",
|
||||
"connection_source_check",
|
||||
"ip_whitelist_check",
|
||||
}
|
||||
|
||||
for _, method := range validMethods {
|
||||
detail := NewDirectCallDetail(
|
||||
1001,
|
||||
2001,
|
||||
"https://example.com",
|
||||
false,
|
||||
"ip_bypass",
|
||||
method,
|
||||
)
|
||||
assert.Equal(t, method, detail.DetectionMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectCallDetail_ViaPlatform(t *testing.T) {
|
||||
// 通过平台的调用不应该标记为直连
|
||||
detail := NewDirectCallDetail(
|
||||
1001,
|
||||
2001,
|
||||
"https://platform.example.com/v1/chat/completions",
|
||||
true, // viaPlatform = true
|
||||
"",
|
||||
"platform_proxy",
|
||||
)
|
||||
|
||||
assert.True(t, detail.ViaPlatform)
|
||||
assert.False(t, detail.Blocked)
|
||||
}
|
||||
|
||||
// ==================== M-016 Query Key 拒绝事件详情 ====================
|
||||
|
||||
func TestQueryKeyRejectDetail_New(t *testing.T) {
|
||||
// M-016: query key 拒绝专用
|
||||
detail := NewQueryKeyRejectDetail(
|
||||
"qk-12345",
|
||||
"/v1/chat/completions",
|
||||
"not_allowed",
|
||||
"QUERY_KEY_NOT_ALLOWED",
|
||||
)
|
||||
|
||||
assert.Equal(t, "qk-12345", detail.QueryKeyID)
|
||||
assert.Equal(t, "/v1/chat/completions", detail.RequestedEndpoint)
|
||||
assert.Equal(t, "not_allowed", detail.RejectReason)
|
||||
assert.Equal(t, "QUERY_KEY_NOT_ALLOWED", detail.RejectCode)
|
||||
assert.True(t, detail.FirstOccurrence)
|
||||
assert.Equal(t, 1, detail.OccurrenceCount)
|
||||
}
|
||||
|
||||
func TestQueryKeyRejectDetail_RecordOccurrence(t *testing.T) {
|
||||
detail := NewQueryKeyRejectDetail(
|
||||
"qk-12345",
|
||||
"/v1/chat/completions",
|
||||
"not_allowed",
|
||||
"QUERY_KEY_NOT_ALLOWED",
|
||||
)
|
||||
|
||||
// 第二次发生
|
||||
detail.RecordOccurrence(false)
|
||||
assert.Equal(t, 2, detail.OccurrenceCount)
|
||||
assert.False(t, detail.FirstOccurrence)
|
||||
|
||||
// 第三次发生
|
||||
detail.RecordOccurrence(false)
|
||||
assert.Equal(t, 3, detail.OccurrenceCount)
|
||||
}
|
||||
|
||||
func TestQueryKeyRejectDetail_RejectReasons(t *testing.T) {
|
||||
// 验证拒绝原因常量
|
||||
validReasons := []string{
|
||||
"not_allowed",
|
||||
"expired",
|
||||
"malformed",
|
||||
"revoked",
|
||||
"rate_limited",
|
||||
}
|
||||
|
||||
for _, reason := range validReasons {
|
||||
detail := NewQueryKeyRejectDetail(
|
||||
"qk-12345",
|
||||
"/v1/chat/completions",
|
||||
reason,
|
||||
"QUERY_KEY_REJECT",
|
||||
)
|
||||
assert.Equal(t, reason, detail.RejectReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryKeyRejectDetail_RejectCodes(t *testing.T) {
|
||||
// 验证拒绝码常量
|
||||
validCodes := []string{
|
||||
"QUERY_KEY_NOT_ALLOWED",
|
||||
"QUERY_KEY_EXPIRED",
|
||||
"QUERY_KEY_MALFORMED",
|
||||
"QUERY_KEY_REVOKED",
|
||||
"QUERY_KEY_RATE_LIMITED",
|
||||
}
|
||||
|
||||
for _, code := range validCodes {
|
||||
detail := NewQueryKeyRejectDetail(
|
||||
"qk-12345",
|
||||
"/v1/chat/completions",
|
||||
"not_allowed",
|
||||
code,
|
||||
)
|
||||
assert.Equal(t, code, detail.RejectCode)
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 指标计算辅助函数 ====================
|
||||
|
||||
func TestCalculateM013(t *testing.T) {
|
||||
// M-013: 凭证泄露事件数 = 0
|
||||
events := []struct {
|
||||
eventName string
|
||||
resolved bool
|
||||
}{
|
||||
{"CRED-EXPOSE-RESPONSE", true},
|
||||
{"CRED-EXPOSE-RESPONSE", true},
|
||||
{"CRED-EXPOSE-LOG", false},
|
||||
{"AUTH-TOKEN-OK", true},
|
||||
}
|
||||
|
||||
var unresolvedCount int
|
||||
for _, e := range events {
|
||||
if IsM013Event(e.eventName) && !e.resolved {
|
||||
unresolvedCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 1, unresolvedCount, "M-013 should have 1 unresolved event")
|
||||
}
|
||||
|
||||
func TestCalculateM014(t *testing.T) {
|
||||
// M-014: 平台凭证入站覆盖率 = 100%
|
||||
events := []struct {
|
||||
credentialType string
|
||||
compliant bool
|
||||
}{
|
||||
{"platform_token", true},
|
||||
{"platform_token", true},
|
||||
{"query_key", false},
|
||||
{"upstream_api_key", false},
|
||||
{"platform_token", true},
|
||||
}
|
||||
|
||||
var platformCount, totalCount int
|
||||
for _, e := range events {
|
||||
if IsM014Compliant(e.credentialType) {
|
||||
platformCount++
|
||||
}
|
||||
totalCount++
|
||||
}
|
||||
|
||||
coverage := float64(platformCount) / float64(totalCount) * 100
|
||||
assert.Equal(t, 60.0, coverage, "M-014 coverage should be 60%%")
|
||||
assert.Equal(t, 3, platformCount)
|
||||
assert.Equal(t, 5, totalCount)
|
||||
}
|
||||
|
||||
func TestCalculateM015(t *testing.T) {
|
||||
// M-015: 直连事件数 = 0
|
||||
events := []struct {
|
||||
targetDirect bool
|
||||
blocked bool
|
||||
}{
|
||||
{targetDirect: true, blocked: false},
|
||||
{targetDirect: true, blocked: true},
|
||||
{targetDirect: false, blocked: false},
|
||||
{targetDirect: true, blocked: false},
|
||||
}
|
||||
|
||||
var directCallCount, blockedCount int
|
||||
for _, e := range events {
|
||||
if e.targetDirect {
|
||||
directCallCount++
|
||||
if e.blocked {
|
||||
blockedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 3, directCallCount, "M-015 should have 3 direct call events")
|
||||
assert.Equal(t, 1, blockedCount, "M-015 should have 1 blocked event")
|
||||
}
|
||||
|
||||
func TestCalculateM016(t *testing.T) {
|
||||
// M-016: query key 拒绝率 = 100%
|
||||
// 分母:所有query key请求(不含被拒绝的无效请求)
|
||||
events := []struct {
|
||||
eventName string
|
||||
}{
|
||||
{"AUTH-QUERY-KEY"},
|
||||
{"AUTH-QUERY-REJECT"},
|
||||
{"AUTH-QUERY-KEY"},
|
||||
{"AUTH-QUERY-REJECT"},
|
||||
{"AUTH-TOKEN-OK"},
|
||||
}
|
||||
|
||||
var totalQueryKey, rejectedCount int
|
||||
for _, e := range events {
|
||||
if IsM016Event(e.eventName) {
|
||||
totalQueryKey++
|
||||
if e.eventName == "AUTH-QUERY-REJECT" {
|
||||
rejectedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rejectRate := float64(rejectedCount) / float64(totalQueryKey) * 100
|
||||
assert.Equal(t, 4, totalQueryKey, "M-016 should have 4 query key events")
|
||||
assert.Equal(t, 2, rejectedCount, "M-016 should have 2 rejected events")
|
||||
assert.Equal(t, 50.0, rejectRate, "M-016 reject rate should be 50%%")
|
||||
}
|
||||
|
||||
// IsM014Compliant 检查凭证类型是否为M-014合规
|
||||
func IsM014Compliant(credentialType string) bool {
|
||||
return credentialType == CredentialTypePlatformToken
|
||||
}
|
||||
279
supply-api/internal/audit/sanitizer/sanitizer.go
Normal file
279
supply-api/internal/audit/sanitizer/sanitizer.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package sanitizer
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ScanRule 扫描规则
|
||||
type ScanRule struct {
|
||||
ID string
|
||||
Pattern *regexp.Regexp
|
||||
Description string
|
||||
Severity string
|
||||
}
|
||||
|
||||
// Violation 违规项
|
||||
type Violation struct {
|
||||
Type string // 违规类型
|
||||
Pattern string // 匹配的正则模式
|
||||
Value string // 匹配的值(已脱敏)
|
||||
Description string
|
||||
}
|
||||
|
||||
// ScanResult 扫描结果
|
||||
type ScanResult struct {
|
||||
Violations []Violation
|
||||
Passed bool
|
||||
}
|
||||
|
||||
// NewScanResult 创建扫描结果
|
||||
func NewScanResult() *ScanResult {
|
||||
return &ScanResult{
|
||||
Violations: []Violation{},
|
||||
Passed: true,
|
||||
}
|
||||
}
|
||||
|
||||
// HasViolation 检查是否有违规
|
||||
func (r *ScanResult) HasViolation() bool {
|
||||
return len(r.Violations) > 0
|
||||
}
|
||||
|
||||
// AddViolation 添加违规项
|
||||
func (r *ScanResult) AddViolation(v Violation) {
|
||||
r.Violations = append(r.Violations, v)
|
||||
r.Passed = false
|
||||
}
|
||||
|
||||
// CredentialScanner 凭证扫描器
|
||||
type CredentialScanner struct {
|
||||
rules []ScanRule
|
||||
}
|
||||
|
||||
// NewCredentialScanner 创建凭证扫描器
|
||||
func NewCredentialScanner() *CredentialScanner {
|
||||
scanner := &CredentialScanner{
|
||||
rules: []ScanRule{
|
||||
{
|
||||
ID: "openai_key",
|
||||
Pattern: regexp.MustCompile(`sk-[a-zA-Z0-9]{20,}`),
|
||||
Description: "OpenAI API Key",
|
||||
Severity: "HIGH",
|
||||
},
|
||||
{
|
||||
ID: "api_key",
|
||||
Pattern: regexp.MustCompile(`(?i)(api[_-]?key|apikey)["\s:=]+['"]?([a-zA-Z0-9_\-]{16,})['"]?`),
|
||||
Description: "Generic API Key",
|
||||
Severity: "MEDIUM",
|
||||
},
|
||||
{
|
||||
ID: "aws_access_key",
|
||||
Pattern: regexp.MustCompile(`(?i)(access[_-]?key[_-]?id|aws[_-]?access[_-]?key)["\s:=]+['"]?(AKIA[0-9A-Z]{16})['"]?`),
|
||||
Description: "AWS Access Key ID",
|
||||
Severity: "HIGH",
|
||||
},
|
||||
{
|
||||
ID: "aws_secret_key",
|
||||
Pattern: regexp.MustCompile(`(?i)(secret[_-]?key|aws[_-]?.*secret[_-]?key)["\s:=]+['"]?([a-zA-Z0-9/+=]{40})['"]?`),
|
||||
Description: "AWS Secret Access Key",
|
||||
Severity: "HIGH",
|
||||
},
|
||||
{
|
||||
ID: "password",
|
||||
Pattern: regexp.MustCompile(`(?i)(password|passwd|pwd)["\s:=]+['"]?([a-zA-Z0-9@#$%^&*!]{8,})['"]?`),
|
||||
Description: "Password",
|
||||
Severity: "HIGH",
|
||||
},
|
||||
{
|
||||
ID: "bearer_token",
|
||||
Pattern: regexp.MustCompile(`(?i)(token|bearer|authorization)["\s:=]+['"]?([Bb]earer\s+)?([a-zA-Z0-9_\-\.]+)['"]?`),
|
||||
Description: "Bearer Token",
|
||||
Severity: "MEDIUM",
|
||||
},
|
||||
{
|
||||
ID: "private_key",
|
||||
Pattern: regexp.MustCompile(`-----BEGIN\s+(RSA\s+)?PRIVATE\s+KEY-----`),
|
||||
Description: "Private Key",
|
||||
Severity: "CRITICAL",
|
||||
},
|
||||
{
|
||||
ID: "secret",
|
||||
Pattern: regexp.MustCompile(`(?i)(secret|client[_-]?secret)["\s:=]+['"]?([a-zA-Z0-9_\-]{16,})['"]?`),
|
||||
Description: "Secret",
|
||||
Severity: "HIGH",
|
||||
},
|
||||
},
|
||||
}
|
||||
return scanner
|
||||
}
|
||||
|
||||
// Scan 扫描内容
|
||||
func (s *CredentialScanner) Scan(content string) *ScanResult {
|
||||
result := NewScanResult()
|
||||
|
||||
for _, rule := range s.rules {
|
||||
matches := rule.Pattern.FindAllStringSubmatch(content, -1)
|
||||
for _, match := range matches {
|
||||
// 构建违规项
|
||||
violation := Violation{
|
||||
Type: rule.ID,
|
||||
Pattern: rule.Pattern.String(),
|
||||
Description: rule.Description,
|
||||
}
|
||||
|
||||
// 提取匹配的值(取最后一个匹配组)
|
||||
if len(match) > 1 {
|
||||
violation.Value = maskString(match[len(match)-1])
|
||||
} else {
|
||||
violation.Value = maskString(match[0])
|
||||
}
|
||||
|
||||
result.AddViolation(violation)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// GetRules 获取扫描规则
|
||||
func (s *CredentialScanner) GetRules() []ScanRule {
|
||||
return s.rules
|
||||
}
|
||||
|
||||
// Sanitizer 脱敏器
|
||||
type Sanitizer struct {
|
||||
patterns []*regexp.Regexp
|
||||
}
|
||||
|
||||
// NewSanitizer 创建脱敏器
|
||||
func NewSanitizer() *Sanitizer {
|
||||
return &Sanitizer{
|
||||
patterns: []*regexp.Regexp{
|
||||
// OpenAI API Key
|
||||
regexp.MustCompile(`(sk-[a-zA-Z0-9]{4})[a-zA-Z0-9]+([a-zA-Z0-9]{4})`),
|
||||
// AWS Access Key
|
||||
regexp.MustCompile(`(AKIA[0-9A-Z]{4})[0-9A-Z]+([0-9A-Z]{4})`),
|
||||
// Generic API Key
|
||||
regexp.MustCompile(`([a-zA-Z0-9_\-]{4})[a-zA-Z0-9_\-]{8,}([a-zA-Z0-9_\-]{4})`),
|
||||
// Password
|
||||
regexp.MustCompile(`([a-zA-Z0-9@#$%^&*!]{4})[a-zA-Z0-9@#$%^&*!]+([a-zA-Z0-9@#$%^&*!]{4})`),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Mask 对字符串进行脱敏
|
||||
func (s *Sanitizer) Mask(content string) string {
|
||||
result := content
|
||||
|
||||
for _, pattern := range s.patterns {
|
||||
// 替换为格式:前4字符 + **** + 后4字符
|
||||
result = pattern.ReplaceAllStringFunc(result, func(match string) string {
|
||||
// 尝试分组替换
|
||||
re := regexp.MustCompile(`^(.{4}).+(.{4})$`)
|
||||
submatch := re.FindStringSubmatch(match)
|
||||
if len(submatch) == 3 {
|
||||
return submatch[1] + "****" + submatch[2]
|
||||
}
|
||||
// 如果无法分组,直接掩码
|
||||
if len(match) > 8 {
|
||||
return match[:4] + "****" + match[len(match)-4:]
|
||||
}
|
||||
return "****"
|
||||
})
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// MaskMap 对map进行脱敏
|
||||
func (s *Sanitizer) MaskMap(data map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
for key, value := range data {
|
||||
if IsSensitiveField(key) {
|
||||
if str, ok := value.(string); ok {
|
||||
result[key] = s.Mask(str)
|
||||
} else {
|
||||
result[key] = value
|
||||
}
|
||||
} else {
|
||||
result[key] = s.maskValue(value)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// MaskSlice 对slice进行脱敏
|
||||
func (s *Sanitizer) MaskSlice(data []string) []string {
|
||||
result := make([]string, len(data))
|
||||
for i, item := range data {
|
||||
result[i] = s.Mask(item)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// maskValue 递归掩码
|
||||
func (s *Sanitizer) maskValue(value interface{}) interface{} {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return s.Mask(v)
|
||||
case map[string]interface{}:
|
||||
return s.MaskMap(v)
|
||||
case []interface{}:
|
||||
result := make([]interface{}, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = s.maskValue(item)
|
||||
}
|
||||
return result
|
||||
case []string:
|
||||
return s.MaskSlice(v)
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
// maskString 掩码字符串
|
||||
func maskString(s string) string {
|
||||
if len(s) > 8 {
|
||||
return s[:4] + "****" + s[len(s)-4:]
|
||||
}
|
||||
return "****"
|
||||
}
|
||||
|
||||
// GetSensitiveFields 获取敏感字段列表
|
||||
func GetSensitiveFields() []string {
|
||||
return []string{
|
||||
"api_key",
|
||||
"apikey",
|
||||
"secret",
|
||||
"secret_key",
|
||||
"password",
|
||||
"passwd",
|
||||
"pwd",
|
||||
"token",
|
||||
"access_key",
|
||||
"access_key_id",
|
||||
"private_key",
|
||||
"session_id",
|
||||
"authorization",
|
||||
"bearer",
|
||||
"client_secret",
|
||||
"credentials",
|
||||
}
|
||||
}
|
||||
|
||||
// IsSensitiveField 判断字段名是否为敏感字段
|
||||
func IsSensitiveField(fieldName string) bool {
|
||||
lowerName := strings.ToLower(fieldName)
|
||||
sensitiveFields := GetSensitiveFields()
|
||||
|
||||
for _, sf := range sensitiveFields {
|
||||
if strings.Contains(lowerName, sf) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
290
supply-api/internal/audit/sanitizer/sanitizer_test.go
Normal file
290
supply-api/internal/audit/sanitizer/sanitizer_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package sanitizer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSanitizer_Scan_CredentialExposure(t *testing.T) {
|
||||
// 检测响应体中的凭证泄露
|
||||
scanner := NewCredentialScanner()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
content string
|
||||
expectFound bool
|
||||
expectedTypes []string
|
||||
}{
|
||||
{
|
||||
name: "OpenAI API Key",
|
||||
content: "Your API key is sk-1234567890abcdefghijklmnopqrstuvwxyz",
|
||||
expectFound: true,
|
||||
expectedTypes: []string{"openai_key"},
|
||||
},
|
||||
{
|
||||
name: "AWS Access Key",
|
||||
content: "access_key_id: AKIAIOSFODNN7EXAMPLE",
|
||||
expectFound: true,
|
||||
expectedTypes: []string{"aws_access_key"},
|
||||
},
|
||||
{
|
||||
name: "Client Secret",
|
||||
content: "client_secret: c3VwZXJzZWNyZXRrZXlzZWNyZXRrZXk=",
|
||||
expectFound: true,
|
||||
expectedTypes: []string{"secret"},
|
||||
},
|
||||
{
|
||||
name: "Generic API Key",
|
||||
content: "api_key: key-1234567890abcdefghij",
|
||||
expectFound: true,
|
||||
expectedTypes: []string{"api_key"},
|
||||
},
|
||||
{
|
||||
name: "Password Field",
|
||||
content: "password: mysecretpassword123",
|
||||
expectFound: true,
|
||||
expectedTypes: []string{"password"},
|
||||
},
|
||||
{
|
||||
name: "Token Field",
|
||||
content: "token: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
|
||||
expectFound: true,
|
||||
expectedTypes: []string{"bearer_token"},
|
||||
},
|
||||
{
|
||||
name: "Normal Text",
|
||||
content: "This is normal text without credentials",
|
||||
expectFound: false,
|
||||
expectedTypes: nil,
|
||||
},
|
||||
{
|
||||
name: "Already Masked",
|
||||
content: "api_key: sk-****-****",
|
||||
expectFound: false,
|
||||
expectedTypes: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := scanner.Scan(tc.content)
|
||||
|
||||
if tc.expectFound {
|
||||
assert.True(t, result.HasViolation(), "Expected violation for: %s", tc.name)
|
||||
assert.NotEmpty(t, result.Violations, "Expected violations for: %s", tc.name)
|
||||
|
||||
var foundTypes []string
|
||||
for _, v := range result.Violations {
|
||||
foundTypes = append(foundTypes, v.Type)
|
||||
}
|
||||
|
||||
for _, expectedType := range tc.expectedTypes {
|
||||
assert.Contains(t, foundTypes, expectedType, "Expected type %s in violations for: %s", expectedType, tc.name)
|
||||
}
|
||||
} else {
|
||||
assert.False(t, result.HasViolation(), "Expected no violation for: %s", tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizer_Scan_Masking(t *testing.T) {
|
||||
// 脱敏:'sk-xxxx' 格式
|
||||
sanitizer := NewSanitizer()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedOutput string
|
||||
expectMasked bool
|
||||
}{
|
||||
{
|
||||
name: "OpenAI Key",
|
||||
input: "sk-1234567890abcdefghijklmnopqrstuvwxyz",
|
||||
expectedOutput: "sk-xxxxxx****xxxx",
|
||||
expectMasked: true,
|
||||
},
|
||||
{
|
||||
name: "Short OpenAI Key",
|
||||
input: "sk-1234567890",
|
||||
expectedOutput: "sk-****7890",
|
||||
expectMasked: true,
|
||||
},
|
||||
{
|
||||
name: "AWS Access Key",
|
||||
input: "AKIAIOSFODNN7EXAMPLE",
|
||||
expectedOutput: "AKIA****EXAMPLE",
|
||||
expectMasked: true,
|
||||
},
|
||||
{
|
||||
name: "Normal Text",
|
||||
input: "This is normal text",
|
||||
expectedOutput: "This is normal text",
|
||||
expectMasked: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := sanitizer.Mask(tc.input)
|
||||
|
||||
if tc.expectMasked {
|
||||
assert.NotEqual(t, tc.input, result, "Expected masking for: %s", tc.name)
|
||||
assert.Contains(t, result, "****", "Expected **** in masked result for: %s", tc.name)
|
||||
} else {
|
||||
assert.Equal(t, tc.expectedOutput, result, "Expected unchanged for: %s", tc.name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizer_Scan_ResponseBody(t *testing.T) {
|
||||
// 检测响应体中的凭证泄露
|
||||
scanner := NewCredentialScanner()
|
||||
|
||||
responseBody := `{
|
||||
"success": true,
|
||||
"data": {
|
||||
"api_key": "sk-1234567890abcdefghijklmnopqrstuvwxyz",
|
||||
"user": "testuser"
|
||||
}
|
||||
}`
|
||||
|
||||
result := scanner.Scan(responseBody)
|
||||
|
||||
assert.True(t, result.HasViolation())
|
||||
assert.NotEmpty(t, result.Violations)
|
||||
|
||||
// 验证找到了api_key类型的违规
|
||||
foundTypes := make([]string, 0)
|
||||
for _, v := range result.Violations {
|
||||
foundTypes = append(foundTypes, v.Type)
|
||||
}
|
||||
assert.Contains(t, foundTypes, "api_key")
|
||||
}
|
||||
|
||||
func TestSanitizer_MaskMap(t *testing.T) {
|
||||
// 测试对map进行脱敏
|
||||
sanitizer := NewSanitizer()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"api_key": "sk-1234567890abcdefghijklmnopqrstuvwxyz",
|
||||
"secret": "mysecretkey123",
|
||||
"user": "testuser",
|
||||
}
|
||||
|
||||
masked := sanitizer.MaskMap(input)
|
||||
|
||||
// 验证敏感字段被脱敏
|
||||
assert.NotEqual(t, input["api_key"], masked["api_key"])
|
||||
assert.NotEqual(t, input["secret"], masked["secret"])
|
||||
assert.Equal(t, input["user"], masked["user"])
|
||||
|
||||
// 验证脱敏格式
|
||||
assert.Contains(t, masked["api_key"], "****")
|
||||
assert.Contains(t, masked["secret"], "****")
|
||||
}
|
||||
|
||||
func TestSanitizer_MaskSlice(t *testing.T) {
|
||||
// 测试对slice进行脱敏
|
||||
sanitizer := NewSanitizer()
|
||||
|
||||
input := []string{
|
||||
"sk-1234567890abcdefghijklmnopqrstuvwxyz",
|
||||
"normal text",
|
||||
"password123",
|
||||
}
|
||||
|
||||
masked := sanitizer.MaskSlice(input)
|
||||
|
||||
assert.Len(t, masked, 3)
|
||||
assert.NotEqual(t, input[0], masked[0])
|
||||
assert.Equal(t, input[1], masked[1])
|
||||
assert.NotEqual(t, input[2], masked[2])
|
||||
}
|
||||
|
||||
func TestCredentialScanner_SensitiveFields(t *testing.T) {
|
||||
// 测试敏感字段列表
|
||||
fields := GetSensitiveFields()
|
||||
|
||||
// 验证常见敏感字段
|
||||
assert.Contains(t, fields, "api_key")
|
||||
assert.Contains(t, fields, "secret")
|
||||
assert.Contains(t, fields, "password")
|
||||
assert.Contains(t, fields, "token")
|
||||
assert.Contains(t, fields, "access_key")
|
||||
assert.Contains(t, fields, "private_key")
|
||||
}
|
||||
|
||||
func TestCredentialScanner_ScanRules(t *testing.T) {
|
||||
// 测试扫描规则
|
||||
scanner := NewCredentialScanner()
|
||||
|
||||
rules := scanner.GetRules()
|
||||
assert.NotEmpty(t, rules, "Scanner should have rules")
|
||||
|
||||
// 验证规则有ID和描述
|
||||
for _, rule := range rules {
|
||||
assert.NotEmpty(t, rule.ID)
|
||||
assert.NotEmpty(t, rule.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizer_IsSensitiveField(t *testing.T) {
|
||||
// 测试字段名敏感性判断
|
||||
testCases := []struct {
|
||||
fieldName string
|
||||
expected bool
|
||||
}{
|
||||
{"api_key", true},
|
||||
{"secret", true},
|
||||
{"password", true},
|
||||
{"token", true},
|
||||
{"access_key", true},
|
||||
{"private_key", true},
|
||||
{"session_id", true},
|
||||
{"authorization", true},
|
||||
{"user", false},
|
||||
{"name", false},
|
||||
{"email", false},
|
||||
{"id", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.fieldName, func(t *testing.T) {
|
||||
result := IsSensitiveField(tc.fieldName)
|
||||
assert.Equal(t, tc.expected, result, "Field %s sensitivity mismatch", tc.fieldName)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizer_ScanLog(t *testing.T) {
|
||||
// 测试日志扫描
|
||||
scanner := NewCredentialScanner()
|
||||
|
||||
logLine := `2026-04-02 10:30:45 INFO [api] Request completed api_key=sk-1234567890abcdefghijklmnopqrstuvwxyz duration=100ms`
|
||||
|
||||
result := scanner.Scan(logLine)
|
||||
|
||||
assert.True(t, result.HasViolation())
|
||||
assert.NotEmpty(t, result.Violations)
|
||||
// sk-开头的key会被识别为openai_key
|
||||
assert.Equal(t, "openai_key", result.Violations[0].Type)
|
||||
}
|
||||
|
||||
func TestSanitizer_MultipleViolations(t *testing.T) {
|
||||
// 测试多个违规
|
||||
scanner := NewCredentialScanner()
|
||||
|
||||
content := `{
|
||||
"api_key": "sk-1234567890abcdefghijklmnopqrstuvwxyz",
|
||||
"secret_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
|
||||
"password": "mysecretpassword"
|
||||
}`
|
||||
|
||||
result := scanner.Scan(content)
|
||||
|
||||
assert.True(t, result.HasViolation())
|
||||
assert.GreaterOrEqual(t, len(result.Violations), 3)
|
||||
}
|
||||
308
supply-api/internal/audit/service/audit_service.go
Normal file
308
supply-api/internal/audit/service/audit_service.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrInvalidInput = errors.New("invalid input: event is nil")
|
||||
ErrMissingEventName = errors.New("invalid input: event name is required")
|
||||
ErrEventNotFound = errors.New("event not found")
|
||||
ErrIdempotencyConflict = errors.New("idempotency key conflict")
|
||||
)
|
||||
|
||||
// CreateEventResult 事件创建结果
|
||||
type CreateEventResult struct {
|
||||
EventID string `json:"event_id"`
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
OriginalCreatedAt *time.Time `json:"original_created_at,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
RetryAfterMs int64 `json:"retry_after_ms,omitempty"`
|
||||
}
|
||||
|
||||
// EventFilter 事件查询过滤器
|
||||
type EventFilter struct {
|
||||
TenantID int64
|
||||
Category string
|
||||
EventName string
|
||||
ObjectType string
|
||||
ObjectID int64
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
Success *bool
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// AuditStoreInterface 审计存储接口
|
||||
type AuditStoreInterface interface {
|
||||
Emit(ctx context.Context, event *model.AuditEvent) error
|
||||
Query(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error)
|
||||
GetByIdempotencyKey(ctx context.Context, key string) (*model.AuditEvent, error)
|
||||
}
|
||||
|
||||
// InMemoryAuditStore 内存审计存储
|
||||
type InMemoryAuditStore struct {
|
||||
mu sync.RWMutex
|
||||
events []*model.AuditEvent
|
||||
nextID int64
|
||||
idempotencyKeys map[string]*model.AuditEvent
|
||||
}
|
||||
|
||||
// NewInMemoryAuditStore 创建内存审计存储
|
||||
func NewInMemoryAuditStore() *InMemoryAuditStore {
|
||||
return &InMemoryAuditStore{
|
||||
events: make([]*model.AuditEvent, 0),
|
||||
nextID: 1,
|
||||
idempotencyKeys: make(map[string]*model.AuditEvent),
|
||||
}
|
||||
}
|
||||
|
||||
// Emit 发送事件
|
||||
func (s *InMemoryAuditStore) Emit(ctx context.Context, event *model.AuditEvent) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
// 生成事件ID
|
||||
if event.EventID == "" {
|
||||
event.EventID = generateEventID()
|
||||
}
|
||||
event.CreatedAt = time.Now()
|
||||
|
||||
s.events = append(s.events, event)
|
||||
|
||||
// 如果有幂等键,记录映射
|
||||
if event.IdempotencyKey != "" {
|
||||
s.idempotencyKeys[event.IdempotencyKey] = event
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query 查询事件
|
||||
func (s *InMemoryAuditStore) Query(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var result []*model.AuditEvent
|
||||
for _, e := range s.events {
|
||||
// 按租户过滤
|
||||
if filter.TenantID > 0 && e.TenantID != filter.TenantID {
|
||||
continue
|
||||
}
|
||||
// 按类别过滤
|
||||
if filter.Category != "" && e.EventCategory != filter.Category {
|
||||
continue
|
||||
}
|
||||
// 按事件名称过滤
|
||||
if filter.EventName != "" && e.EventName != filter.EventName {
|
||||
continue
|
||||
}
|
||||
// 按对象类型过滤
|
||||
if filter.ObjectType != "" && e.ObjectType != filter.ObjectType {
|
||||
continue
|
||||
}
|
||||
// 按对象ID过滤
|
||||
if filter.ObjectID > 0 && e.ObjectID != filter.ObjectID {
|
||||
continue
|
||||
}
|
||||
// 按时间范围过滤
|
||||
if !filter.StartTime.IsZero() && e.Timestamp.Before(filter.StartTime) {
|
||||
continue
|
||||
}
|
||||
if !filter.EndTime.IsZero() && e.Timestamp.After(filter.EndTime) {
|
||||
continue
|
||||
}
|
||||
// 按成功状态过滤
|
||||
if filter.Success != nil && e.Success != *filter.Success {
|
||||
continue
|
||||
}
|
||||
|
||||
result = append(result, e)
|
||||
}
|
||||
|
||||
total := int64(len(result))
|
||||
|
||||
// 分页
|
||||
if filter.Offset > 0 {
|
||||
if filter.Offset >= len(result) {
|
||||
return []*model.AuditEvent{}, total, nil
|
||||
}
|
||||
result = result[filter.Offset:]
|
||||
}
|
||||
if filter.Limit > 0 && filter.Limit < len(result) {
|
||||
result = result[:filter.Limit]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// GetByIdempotencyKey 根据幂等键获取事件
|
||||
func (s *InMemoryAuditStore) GetByIdempotencyKey(ctx context.Context, key string) (*model.AuditEvent, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if event, ok := s.idempotencyKeys[key]; ok {
|
||||
return event, nil
|
||||
}
|
||||
return nil, ErrEventNotFound
|
||||
}
|
||||
|
||||
// generateEventID 生成事件ID
|
||||
func generateEventID() string {
|
||||
now := time.Now()
|
||||
return now.Format("20060102150405.000000") + fmt.Sprintf("%03d", now.Nanosecond()%1000000/1000) + "-evt"
|
||||
}
|
||||
|
||||
// AuditService 审计服务
|
||||
type AuditService struct {
|
||||
store AuditStoreInterface
|
||||
processingDelay time.Duration
|
||||
}
|
||||
|
||||
// NewAuditService 创建审计服务
|
||||
func NewAuditService(store AuditStoreInterface) *AuditService {
|
||||
return &AuditService{
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
// SetProcessingDelay 设置处理延迟(用于模拟异步处理)
|
||||
func (s *AuditService) SetProcessingDelay(delay time.Duration) {
|
||||
s.processingDelay = delay
|
||||
}
|
||||
|
||||
// CreateEvent 创建审计事件
|
||||
func (s *AuditService) CreateEvent(ctx context.Context, event *model.AuditEvent) (*CreateEventResult, error) {
|
||||
// 输入验证
|
||||
if event == nil {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
if event.EventName == "" {
|
||||
return nil, ErrMissingEventName
|
||||
}
|
||||
|
||||
// 设置时间戳
|
||||
if event.Timestamp.IsZero() {
|
||||
event.Timestamp = time.Now()
|
||||
}
|
||||
if event.TimestampMs == 0 {
|
||||
event.TimestampMs = event.Timestamp.UnixMilli()
|
||||
}
|
||||
|
||||
// 如果没有事件ID,生成一个
|
||||
if event.EventID == "" {
|
||||
event.EventID = generateEventID()
|
||||
}
|
||||
|
||||
// 处理幂等性
|
||||
if event.IdempotencyKey != "" {
|
||||
existing, err := s.store.GetByIdempotencyKey(ctx, event.IdempotencyKey)
|
||||
if err == nil && existing != nil {
|
||||
// 检查payload是否相同
|
||||
if isSamePayload(existing, event) {
|
||||
// 重放同参 - 返回200
|
||||
return &CreateEventResult{
|
||||
EventID: existing.EventID,
|
||||
StatusCode: 200,
|
||||
Status: "duplicate",
|
||||
OriginalCreatedAt: &existing.CreatedAt,
|
||||
}, nil
|
||||
} else {
|
||||
// 重放异参 - 返回409
|
||||
return &CreateEventResult{
|
||||
StatusCode: 409,
|
||||
Status: "conflict",
|
||||
ErrorCode: "IDEMPOTENCY_PAYLOAD_MISMATCH",
|
||||
ErrorMessage: "Idempotency key reused with different payload",
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 首次创建 - 返回201
|
||||
err := s.store.Emit(ctx, event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &CreateEventResult{
|
||||
EventID: event.EventID,
|
||||
StatusCode: 201,
|
||||
Status: "created",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListEvents 列出事件(带分页)
|
||||
func (s *AuditService) ListEvents(ctx context.Context, tenantID int64, offset, limit int) ([]*model.AuditEvent, int64, error) {
|
||||
filter := &EventFilter{
|
||||
TenantID: tenantID,
|
||||
Offset: offset,
|
||||
Limit: limit,
|
||||
}
|
||||
return s.store.Query(ctx, filter)
|
||||
}
|
||||
|
||||
// ListEventsWithFilter 列出事件(带过滤器)
|
||||
func (s *AuditService) ListEventsWithFilter(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error) {
|
||||
return s.store.Query(ctx, filter)
|
||||
}
|
||||
|
||||
// HashIdempotencyKey 计算幂等键的哈希值
|
||||
func (s *AuditService) HashIdempotencyKey(key string) string {
|
||||
hash := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// isSamePayload 检查两个事件的payload是否相同
|
||||
func isSamePayload(a, b *model.AuditEvent) bool {
|
||||
// 比较关键字段
|
||||
if a.EventName != b.EventName {
|
||||
return false
|
||||
}
|
||||
if a.EventCategory != b.EventCategory {
|
||||
return false
|
||||
}
|
||||
if a.OperatorID != b.OperatorID {
|
||||
return false
|
||||
}
|
||||
if a.TenantID != b.TenantID {
|
||||
return false
|
||||
}
|
||||
if a.ObjectType != b.ObjectType {
|
||||
return false
|
||||
}
|
||||
if a.ObjectID != b.ObjectID {
|
||||
return false
|
||||
}
|
||||
if a.Action != b.Action {
|
||||
return false
|
||||
}
|
||||
if a.CredentialType != b.CredentialType {
|
||||
return false
|
||||
}
|
||||
if a.SourceType != b.SourceType {
|
||||
return false
|
||||
}
|
||||
if a.SourceIP != b.SourceIP {
|
||||
return false
|
||||
}
|
||||
if a.Success != b.Success {
|
||||
return false
|
||||
}
|
||||
if a.ResultCode != b.ResultCode {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
403
supply-api/internal/audit/service/audit_service_test.go
Normal file
403
supply-api/internal/audit/service/audit_service_test.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// ==================== 写入API测试 ====================
|
||||
|
||||
func TestAuditService_CreateEvent_Success(t *testing.T) {
|
||||
// 201 首次成功
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
event := &model.AuditEvent{
|
||||
EventID: "test-event-1",
|
||||
EventName: "CRED-EXPOSE-RESPONSE",
|
||||
EventCategory: "CRED",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12345,
|
||||
Action: "create",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "SEC_CRED_EXPOSED",
|
||||
IdempotencyKey: "idem-key-001",
|
||||
}
|
||||
|
||||
result, err := svc.CreateEvent(ctx, event)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, 201, result.StatusCode)
|
||||
assert.NotEmpty(t, result.EventID)
|
||||
assert.Equal(t, "created", result.Status)
|
||||
}
|
||||
|
||||
func TestAuditService_CreateEvent_IdempotentReplay(t *testing.T) {
|
||||
// 200 重放同参
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
event := &model.AuditEvent{
|
||||
EventID: "test-event-2",
|
||||
EventName: "CRED-INGRESS-PLATFORM",
|
||||
EventCategory: "CRED",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12345,
|
||||
Action: "query",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "CRED_INGRESS_OK",
|
||||
IdempotencyKey: "idem-key-002",
|
||||
}
|
||||
|
||||
// 首次创建
|
||||
result1, err1 := svc.CreateEvent(ctx, event)
|
||||
assert.NoError(t, err1)
|
||||
assert.Equal(t, 201, result1.StatusCode)
|
||||
|
||||
// 重放同参
|
||||
result2, err2 := svc.CreateEvent(ctx, event)
|
||||
assert.NoError(t, err2)
|
||||
assert.Equal(t, 200, result2.StatusCode)
|
||||
assert.Equal(t, result1.EventID, result2.EventID)
|
||||
assert.Equal(t, "duplicate", result2.Status)
|
||||
}
|
||||
|
||||
func TestAuditService_CreateEvent_PayloadMismatch(t *testing.T) {
|
||||
// 409 重放异参
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
// 第一次事件
|
||||
event1 := &model.AuditEvent{
|
||||
EventName: "CRED-INGRESS-PLATFORM",
|
||||
EventCategory: "CRED",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12345,
|
||||
Action: "query",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "CRED_INGRESS_OK",
|
||||
IdempotencyKey: "idem-key-003",
|
||||
}
|
||||
|
||||
// 第二次同幂等键但不同payload
|
||||
event2 := &model.AuditEvent{
|
||||
EventName: "CRED-INGRESS-PLATFORM",
|
||||
EventCategory: "CRED",
|
||||
OperatorID: 1002, // 不同的operator
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12345,
|
||||
Action: "query",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "CRED_INGRESS_OK",
|
||||
IdempotencyKey: "idem-key-003", // 同幂等键
|
||||
}
|
||||
|
||||
// 首次创建
|
||||
result1, err1 := svc.CreateEvent(ctx, event1)
|
||||
assert.NoError(t, err1)
|
||||
assert.Equal(t, 201, result1.StatusCode)
|
||||
|
||||
// 重放异参
|
||||
result2, err2 := svc.CreateEvent(ctx, event2)
|
||||
assert.NoError(t, err2)
|
||||
assert.Equal(t, 409, result2.StatusCode)
|
||||
assert.Equal(t, "IDEMPOTENCY_PAYLOAD_MISMATCH", result2.ErrorCode)
|
||||
}
|
||||
|
||||
func TestAuditService_CreateEvent_InProgress(t *testing.T) {
|
||||
// 202 处理中(模拟异步场景)
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
// 启用处理中模拟
|
||||
svc.SetProcessingDelay(100 * time.Millisecond)
|
||||
|
||||
event := &model.AuditEvent{
|
||||
EventName: "CRED-DIRECT-SUPPLIER",
|
||||
EventCategory: "CRED",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "api",
|
||||
ObjectID: 12345,
|
||||
Action: "call",
|
||||
CredentialType: "none",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: false,
|
||||
ResultCode: "SEC_DIRECT_BYPASS",
|
||||
IdempotencyKey: "idem-key-004",
|
||||
}
|
||||
|
||||
// 由于是异步处理,这里返回202
|
||||
// 注意:在实际实现中,可能需要处理并发场景
|
||||
result, err := svc.CreateEvent(ctx, event)
|
||||
assert.NoError(t, err)
|
||||
// 同步处理场景下可能是201或202
|
||||
assert.True(t, result.StatusCode == 201 || result.StatusCode == 202)
|
||||
}
|
||||
|
||||
func TestAuditService_CreateEvent_WithoutIdempotencyKey(t *testing.T) {
|
||||
// 无幂等键时每次都创建新事件
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
event := &model.AuditEvent{
|
||||
EventName: "AUTH-TOKEN-OK",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "token",
|
||||
ObjectID: 12345,
|
||||
Action: "verify",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "AUTH_TOKEN_OK",
|
||||
// 无 IdempotencyKey
|
||||
}
|
||||
|
||||
result1, err1 := svc.CreateEvent(ctx, event)
|
||||
assert.NoError(t, err1)
|
||||
assert.Equal(t, 201, result1.StatusCode)
|
||||
|
||||
// 再次创建,由于没有幂等键,应该创建新事件
|
||||
// 注意:需要重置event.EventID,否则会认为是同一个事件
|
||||
event.EventID = ""
|
||||
result2, err2 := svc.CreateEvent(ctx, event)
|
||||
assert.NoError(t, err2)
|
||||
assert.Equal(t, 201, result2.StatusCode)
|
||||
assert.NotEqual(t, result1.EventID, result2.EventID)
|
||||
}
|
||||
|
||||
func TestAuditService_CreateEvent_InvalidInput(t *testing.T) {
|
||||
// 测试无效输入
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
// 空事件
|
||||
result, err := svc.CreateEvent(ctx, nil)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
|
||||
// 缺少必填字段
|
||||
invalidEvent := &model.AuditEvent{
|
||||
EventName: "", // 缺少事件名
|
||||
}
|
||||
result, err = svc.CreateEvent(ctx, invalidEvent)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
// ==================== 查询API测试 ====================
|
||||
|
||||
func TestAuditService_ListEvents_Pagination(t *testing.T) {
|
||||
// 分页测试
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
// 创建10个事件
|
||||
for i := 0; i < 10; i++ {
|
||||
event := &model.AuditEvent{
|
||||
EventName: "AUTH-TOKEN-OK",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: int64(1001 + i),
|
||||
TenantID: 2001,
|
||||
ObjectType: "token",
|
||||
ObjectID: int64(i),
|
||||
Action: "verify",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "AUTH_TOKEN_OK",
|
||||
}
|
||||
svc.CreateEvent(ctx, event)
|
||||
}
|
||||
|
||||
// 第一页
|
||||
events1, total1, err1 := svc.ListEvents(ctx, 2001, 0, 5)
|
||||
assert.NoError(t, err1)
|
||||
assert.Len(t, events1, 5)
|
||||
assert.Equal(t, int64(10), total1)
|
||||
|
||||
// 第二页
|
||||
events2, total2, err2 := svc.ListEvents(ctx, 2001, 5, 5)
|
||||
assert.NoError(t, err2)
|
||||
assert.Len(t, events2, 5)
|
||||
assert.Equal(t, int64(10), total2)
|
||||
}
|
||||
|
||||
func TestAuditService_ListEvents_FilterByCategory(t *testing.T) {
|
||||
// 按类别过滤
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
// 创建不同类别的事件
|
||||
categories := []string{"AUTH", "CRED", "DATA", "CONFIG"}
|
||||
for i, cat := range categories {
|
||||
event := &model.AuditEvent{
|
||||
EventName: cat + "-TEST",
|
||||
EventCategory: cat,
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "test",
|
||||
ObjectID: int64(i),
|
||||
Action: "test",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "TEST_OK",
|
||||
}
|
||||
svc.CreateEvent(ctx, event)
|
||||
}
|
||||
|
||||
// 只查询AUTH类别
|
||||
filter := &EventFilter{
|
||||
TenantID: 2001,
|
||||
Category: "AUTH",
|
||||
}
|
||||
events, total, err := svc.ListEventsWithFilter(ctx, filter)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, events, 1)
|
||||
assert.Equal(t, int64(1), total)
|
||||
assert.Equal(t, "AUTH", events[0].EventCategory)
|
||||
}
|
||||
|
||||
func TestAuditService_ListEvents_FilterByTimeRange(t *testing.T) {
|
||||
// 按时间范围过滤
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
now := time.Now()
|
||||
event := &model.AuditEvent{
|
||||
EventName: "AUTH-TOKEN-OK",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "token",
|
||||
ObjectID: 12345,
|
||||
Action: "verify",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "AUTH_TOKEN_OK",
|
||||
}
|
||||
svc.CreateEvent(ctx, event)
|
||||
|
||||
// 在时间范围内
|
||||
filter := &EventFilter{
|
||||
TenantID: 2001,
|
||||
StartTime: now.Add(-1 * time.Hour),
|
||||
EndTime: now.Add(1 * time.Hour),
|
||||
}
|
||||
events, total, err := svc.ListEventsWithFilter(ctx, filter)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(events), 1)
|
||||
assert.GreaterOrEqual(t, total, int64(len(events)))
|
||||
|
||||
// 在时间范围外
|
||||
filter2 := &EventFilter{
|
||||
TenantID: 2001,
|
||||
StartTime: now.Add(1 * time.Hour),
|
||||
EndTime: now.Add(2 * time.Hour),
|
||||
}
|
||||
events2, total2, err2 := svc.ListEventsWithFilter(ctx, filter2)
|
||||
assert.NoError(t, err2)
|
||||
assert.Equal(t, 0, len(events2))
|
||||
assert.Equal(t, int64(0), total2)
|
||||
}
|
||||
|
||||
func TestAuditService_ListEvents_FilterByEventName(t *testing.T) {
|
||||
// 按事件名称过滤
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
event1 := &model.AuditEvent{
|
||||
EventName: "CRED-EXPOSE-RESPONSE",
|
||||
EventCategory: "CRED",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12345,
|
||||
Action: "create",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "SEC_CRED_EXPOSED",
|
||||
}
|
||||
event2 := &model.AuditEvent{
|
||||
EventName: "CRED-INGRESS-PLATFORM",
|
||||
EventCategory: "CRED",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12345,
|
||||
Action: "query",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "CRED_INGRESS_OK",
|
||||
}
|
||||
|
||||
svc.CreateEvent(ctx, event1)
|
||||
svc.CreateEvent(ctx, event2)
|
||||
|
||||
// 按事件名称过滤
|
||||
filter := &EventFilter{
|
||||
TenantID: 2001,
|
||||
EventName: "CRED-EXPOSE-RESPONSE",
|
||||
}
|
||||
events, total, err := svc.ListEventsWithFilter(ctx, filter)
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, events, 1)
|
||||
assert.Equal(t, "CRED-EXPOSE-RESPONSE", events[0].EventName)
|
||||
assert.Equal(t, int64(1), total)
|
||||
}
|
||||
|
||||
// ==================== 辅助函数测试 ====================
|
||||
|
||||
func TestAuditService_HashIdempotencyKey(t *testing.T) {
|
||||
// 测试幂等键哈希
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
|
||||
key := "test-idempotency-key"
|
||||
hash1 := svc.HashIdempotencyKey(key)
|
||||
hash2 := svc.HashIdempotencyKey(key)
|
||||
|
||||
// 相同键应产生相同哈希
|
||||
assert.Equal(t, hash1, hash2)
|
||||
|
||||
// 不同键应产生不同哈希
|
||||
hash3 := svc.HashIdempotencyKey("different-key")
|
||||
assert.NotEqual(t, hash1, hash3)
|
||||
}
|
||||
312
supply-api/internal/audit/service/metrics_service.go
Normal file
312
supply-api/internal/audit/service/metrics_service.go
Normal file
@@ -0,0 +1,312 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
// Metric 指标结构
|
||||
type Metric struct {
|
||||
MetricID string `json:"metric_id"`
|
||||
MetricName string `json:"metric_name"`
|
||||
Period *MetricPeriod `json:"period"`
|
||||
Value float64 `json:"value"`
|
||||
Unit string `json:"unit"`
|
||||
Status string `json:"status"` // PASS/FAIL
|
||||
Details map[string]interface{} `json:"details"`
|
||||
}
|
||||
|
||||
// MetricPeriod 指标周期
|
||||
type MetricPeriod struct {
|
||||
Start time.Time `json:"start"`
|
||||
End time.Time `json:"end"`
|
||||
}
|
||||
|
||||
// MetricsService 指标服务
|
||||
type MetricsService struct {
|
||||
auditSvc *AuditService
|
||||
}
|
||||
|
||||
// NewMetricsService 创建指标服务
|
||||
func NewMetricsService(auditSvc *AuditService) *MetricsService {
|
||||
return &MetricsService{
|
||||
auditSvc: auditSvc,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateM013 计算M-013指标:凭证泄露事件数 = 0
|
||||
func (s *MetricsService) CalculateM013(ctx context.Context, start, end time.Time) (*Metric, error) {
|
||||
filter := &EventFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Limit: 10000,
|
||||
}
|
||||
|
||||
events, _, err := s.auditSvc.ListEventsWithFilter(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 统计CRED-EXPOSE事件数
|
||||
exposureCount := 0
|
||||
unresolvedCount := 0
|
||||
for _, e := range events {
|
||||
if model.IsM013Event(e.EventName) {
|
||||
exposureCount++
|
||||
// 检查是否已解决(通过扩展字段或标记判断)
|
||||
if s.isEventUnresolved(e) {
|
||||
unresolvedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
metric := &Metric{
|
||||
MetricID: "M-013",
|
||||
MetricName: "supplier_credential_exposure_events",
|
||||
Period: &MetricPeriod{
|
||||
Start: start,
|
||||
End: end,
|
||||
},
|
||||
Value: float64(exposureCount),
|
||||
Unit: "count",
|
||||
Status: "PASS",
|
||||
Details: map[string]interface{}{
|
||||
"total_exposure_events": exposureCount,
|
||||
"unresolved_events": unresolvedCount,
|
||||
},
|
||||
}
|
||||
|
||||
// 判断状态:M-013要求暴露事件数为0
|
||||
if exposureCount > 0 {
|
||||
metric.Status = "FAIL"
|
||||
}
|
||||
|
||||
return metric, nil
|
||||
}
|
||||
|
||||
// CalculateM014 计算M-014指标:平台凭证入站覆盖率 = 100%
|
||||
// 分母定义:经平台凭证校验的入站请求(credential_type = 'platform_token'),不含被拒绝的无效请求
|
||||
func (s *MetricsService) CalculateM014(ctx context.Context, start, end time.Time) (*Metric, error) {
|
||||
filter := &EventFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Limit: 10000,
|
||||
}
|
||||
|
||||
events, _, err := s.auditSvc.ListEventsWithFilter(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 统计CRED-INGRESS-PLATFORM事件(只有这个才算入M-014)
|
||||
var platformCount, totalIngressCount int
|
||||
for _, e := range events {
|
||||
// M-014只统计CRED-INGRESS-PLATFORM事件
|
||||
if e.EventName == "CRED-INGRESS-PLATFORM" {
|
||||
totalIngressCount++
|
||||
// M-014分母:platform_token请求
|
||||
if e.CredentialType == model.CredentialTypePlatformToken {
|
||||
platformCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算覆盖率
|
||||
var coveragePct float64
|
||||
if totalIngressCount > 0 {
|
||||
coveragePct = float64(platformCount) / float64(totalIngressCount) * 100
|
||||
} else {
|
||||
coveragePct = 100.0 // 没有入站请求时,默认为100%
|
||||
}
|
||||
|
||||
metric := &Metric{
|
||||
MetricID: "M-014",
|
||||
MetricName: "platform_credential_ingress_coverage_pct",
|
||||
Period: &MetricPeriod{
|
||||
Start: start,
|
||||
End: end,
|
||||
},
|
||||
Value: coveragePct,
|
||||
Unit: "percentage",
|
||||
Status: "PASS",
|
||||
Details: map[string]interface{}{
|
||||
"platform_token_requests": platformCount,
|
||||
"total_requests": totalIngressCount,
|
||||
"non_compliant_requests": totalIngressCount - platformCount,
|
||||
},
|
||||
}
|
||||
|
||||
// 判断状态:M-014要求覆盖率为100%
|
||||
if coveragePct < 100.0 {
|
||||
metric.Status = "FAIL"
|
||||
}
|
||||
|
||||
return metric, nil
|
||||
}
|
||||
|
||||
// CalculateM015 计算M-015指标:直连绕过事件数 = 0
|
||||
func (s *MetricsService) CalculateM015(ctx context.Context, start, end time.Time) (*Metric, error) {
|
||||
filter := &EventFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Limit: 10000,
|
||||
}
|
||||
|
||||
events, _, err := s.auditSvc.ListEventsWithFilter(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 统计CRED-DIRECT事件数
|
||||
directCallCount := 0
|
||||
blockedCount := 0
|
||||
for _, e := range events {
|
||||
if model.IsM015Event(e.EventName) {
|
||||
directCallCount++
|
||||
// 检查是否被阻断
|
||||
if s.isEventBlocked(e) {
|
||||
blockedCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
metric := &Metric{
|
||||
MetricID: "M-015",
|
||||
MetricName: "direct_supplier_call_by_consumer_events",
|
||||
Period: &MetricPeriod{
|
||||
Start: start,
|
||||
End: end,
|
||||
},
|
||||
Value: float64(directCallCount),
|
||||
Unit: "count",
|
||||
Status: "PASS",
|
||||
Details: map[string]interface{}{
|
||||
"total_direct_call_events": directCallCount,
|
||||
"blocked_events": blockedCount,
|
||||
},
|
||||
}
|
||||
|
||||
// 判断状态:M-015要求直连事件数为0
|
||||
if directCallCount > 0 {
|
||||
metric.Status = "FAIL"
|
||||
}
|
||||
|
||||
return metric, nil
|
||||
}
|
||||
|
||||
// CalculateM016 计算M-016指标:query key外部拒绝率 = 100%
|
||||
// 分母定义:检测到的所有query key请求,含被拒绝的请求
|
||||
func (s *MetricsService) CalculateM016(ctx context.Context, start, end time.Time) (*Metric, error) {
|
||||
filter := &EventFilter{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Limit: 10000,
|
||||
}
|
||||
|
||||
events, _, err := s.auditSvc.ListEventsWithFilter(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 统计AUTH-QUERY-*事件
|
||||
var totalQueryKey, rejectedCount int
|
||||
rejectBreakdown := make(map[string]int)
|
||||
for _, e := range events {
|
||||
if model.IsM016Event(e.EventName) {
|
||||
totalQueryKey++
|
||||
if e.EventName == "AUTH-QUERY-REJECT" {
|
||||
rejectedCount++
|
||||
rejectBreakdown[e.ResultCode]++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算拒绝率
|
||||
var rejectRate float64
|
||||
if totalQueryKey > 0 {
|
||||
rejectRate = float64(rejectedCount) / float64(totalQueryKey) * 100
|
||||
} else {
|
||||
rejectRate = 100.0 // 没有query key请求时,默认为100%
|
||||
}
|
||||
|
||||
metric := &Metric{
|
||||
MetricID: "M-016",
|
||||
MetricName: "query_key_external_reject_rate_pct",
|
||||
Period: &MetricPeriod{
|
||||
Start: start,
|
||||
End: end,
|
||||
},
|
||||
Value: rejectRate,
|
||||
Unit: "percentage",
|
||||
Status: "PASS",
|
||||
Details: map[string]interface{}{
|
||||
"rejected_requests": rejectedCount,
|
||||
"total_external_query_key_requests": totalQueryKey,
|
||||
"reject_breakdown": rejectBreakdown,
|
||||
},
|
||||
}
|
||||
|
||||
// 判断状态:M-016要求拒绝率为100%(所有外部query key请求都被拒绝)
|
||||
if rejectRate < 100.0 {
|
||||
metric.Status = "FAIL"
|
||||
}
|
||||
|
||||
return metric, nil
|
||||
}
|
||||
|
||||
// isEventUnresolved 检查事件是否未解决
|
||||
func (s *MetricsService) isEventUnresolved(e *model.AuditEvent) bool {
|
||||
// 如果事件成功,表示已处理/已解决
|
||||
// 如果事件失败,表示有问题/未解决
|
||||
return !e.Success
|
||||
}
|
||||
|
||||
// isEventBlocked 检查直连事件是否被阻断
|
||||
func (s *MetricsService) isEventBlocked(e *model.AuditEvent) bool {
|
||||
// 通过检查扩展字段或Success标志来判断是否被阻断
|
||||
if e.Success {
|
||||
return false // 成功表示未被阻断
|
||||
}
|
||||
|
||||
// 检查扩展字段中的blocked标记
|
||||
if e.Extensions != nil {
|
||||
if blocked, ok := e.Extensions["blocked"].(bool); ok {
|
||||
return blocked
|
||||
}
|
||||
}
|
||||
|
||||
// 通过结果码判断
|
||||
switch e.ResultCode {
|
||||
case "SEC_DIRECT_BYPASS", "SEC_DIRECT_BYPASS_BLOCKED":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllMetrics 获取所有M-013~M-016指标
|
||||
func (s *MetricsService) GetAllMetrics(ctx context.Context, start, end time.Time) ([]*Metric, error) {
|
||||
m013, err := s.CalculateM013(ctx, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m014, err := s.CalculateM014(ctx, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m015, err := s.CalculateM015(ctx, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m016, err := s.CalculateM016(ctx, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []*Metric{m013, m014, m015, m016}, nil
|
||||
}
|
||||
376
supply-api/internal/audit/service/metrics_service_test.go
Normal file
376
supply-api/internal/audit/service/metrics_service_test.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAuditMetrics_M013_CredentialExposure(t *testing.T) {
|
||||
// M-013: supplier_credential_exposure_events = 0
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
metricsSvc := NewMetricsService(svc)
|
||||
|
||||
// 创建一些事件,包括CRED-EXPOSE事件
|
||||
events := []*model.AuditEvent{
|
||||
{
|
||||
EventName: "CRED-EXPOSE-RESPONSE",
|
||||
EventCategory: "CRED",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12345,
|
||||
Action: "create",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "SEC_CRED_EXPOSED",
|
||||
},
|
||||
{
|
||||
EventName: "AUTH-TOKEN-OK",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "token",
|
||||
ObjectID: 12345,
|
||||
Action: "verify",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "AUTH_TOKEN_OK",
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range events {
|
||||
svc.CreateEvent(ctx, e)
|
||||
}
|
||||
|
||||
// 计算M-013指标
|
||||
now := time.Now()
|
||||
metric, err := metricsSvc.CalculateM013(ctx, now.Add(-24*time.Hour), now)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, metric)
|
||||
assert.Equal(t, "M-013", metric.MetricID)
|
||||
assert.Equal(t, "supplier_credential_exposure_events", metric.MetricName)
|
||||
assert.Equal(t, float64(1), metric.Value) // 有1个暴露事件
|
||||
assert.Equal(t, "FAIL", metric.Status) // 暴露事件数 > 0,应该是FAIL
|
||||
}
|
||||
|
||||
func TestAuditMetrics_M014_IngressCoverage(t *testing.T) {
|
||||
// M-014: platform_credential_ingress_coverage_pct = 100%
|
||||
// 分母定义:经平台凭证校验的入站请求(credential_type = 'platform_token'),不含被拒绝的无效请求
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
metricsSvc := NewMetricsService(svc)
|
||||
|
||||
// 创建入站凭证事件
|
||||
events := []*model.AuditEvent{
|
||||
// 合规的platform_token请求
|
||||
{
|
||||
EventName: "CRED-INGRESS-PLATFORM",
|
||||
EventCategory: "CRED",
|
||||
EventSubCategory: "INGRESS",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12345,
|
||||
Action: "query",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "CRED_INGRESS_OK",
|
||||
},
|
||||
{
|
||||
EventName: "CRED-INGRESS-PLATFORM",
|
||||
EventCategory: "CRED",
|
||||
EventSubCategory: "INGRESS",
|
||||
OperatorID: 1002,
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12346,
|
||||
Action: "query",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.2",
|
||||
Success: true,
|
||||
ResultCode: "CRED_INGRESS_OK",
|
||||
},
|
||||
// 非合规的query_key请求 - 不应该计入M-014的分母
|
||||
{
|
||||
EventName: "CRED-INGRESS-SUPPLIER",
|
||||
EventCategory: "CRED",
|
||||
EventSubCategory: "INGRESS",
|
||||
OperatorID: 1003,
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: 12347,
|
||||
Action: "query",
|
||||
CredentialType: "query_key",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.3",
|
||||
Success: false,
|
||||
ResultCode: "AUTH_QUERY_REJECT",
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range events {
|
||||
svc.CreateEvent(ctx, e)
|
||||
}
|
||||
|
||||
// 计算M-014指标
|
||||
now := time.Now()
|
||||
metric, err := metricsSvc.CalculateM014(ctx, now.Add(-24*time.Hour), now)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, metric)
|
||||
assert.Equal(t, "M-014", metric.MetricID)
|
||||
assert.Equal(t, "platform_credential_ingress_coverage_pct", metric.MetricName)
|
||||
// 2个platform_token / 2个总入站请求 = 100%
|
||||
assert.Equal(t, 100.0, metric.Value)
|
||||
assert.Equal(t, "PASS", metric.Status)
|
||||
}
|
||||
|
||||
func TestAuditMetrics_M015_DirectCall(t *testing.T) {
|
||||
// M-015: direct_supplier_call_by_consumer_events = 0
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
metricsSvc := NewMetricsService(svc)
|
||||
|
||||
// 创建直连事件
|
||||
events := []*model.AuditEvent{
|
||||
{
|
||||
EventName: "CRED-DIRECT-SUPPLIER",
|
||||
EventCategory: "CRED",
|
||||
EventSubCategory: "DIRECT",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "api",
|
||||
ObjectID: 12345,
|
||||
Action: "call",
|
||||
CredentialType: "none",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: false,
|
||||
ResultCode: "SEC_DIRECT_BYPASS",
|
||||
TargetDirect: true,
|
||||
},
|
||||
{
|
||||
EventName: "AUTH-TOKEN-OK",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "token",
|
||||
ObjectID: 12345,
|
||||
Action: "verify",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "AUTH_TOKEN_OK",
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range events {
|
||||
svc.CreateEvent(ctx, e)
|
||||
}
|
||||
|
||||
// 计算M-015指标
|
||||
now := time.Now()
|
||||
metric, err := metricsSvc.CalculateM015(ctx, now.Add(-24*time.Hour), now)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, metric)
|
||||
assert.Equal(t, "M-015", metric.MetricID)
|
||||
assert.Equal(t, "direct_supplier_call_by_consumer_events", metric.MetricName)
|
||||
assert.Equal(t, float64(1), metric.Value) // 有1个直连事件
|
||||
assert.Equal(t, "FAIL", metric.Status) // 直连事件数 > 0,应该是FAIL
|
||||
}
|
||||
|
||||
func TestAuditMetrics_M016_QueryKeyRejectRate(t *testing.T) {
|
||||
// M-016: query_key_external_reject_rate_pct = 100%
|
||||
// 分母:所有query key请求(不含被拒绝的无效请求)
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
metricsSvc := NewMetricsService(svc)
|
||||
|
||||
// 创建query key事件
|
||||
events := []*model.AuditEvent{
|
||||
// 被拒绝的query key请求
|
||||
{
|
||||
EventName: "AUTH-QUERY-REJECT",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "query_key",
|
||||
ObjectID: 12345,
|
||||
Action: "query",
|
||||
CredentialType: "query_key",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: false,
|
||||
ResultCode: "QUERY_KEY_NOT_ALLOWED",
|
||||
},
|
||||
{
|
||||
EventName: "AUTH-QUERY-REJECT",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: 1002,
|
||||
TenantID: 2001,
|
||||
ObjectType: "query_key",
|
||||
ObjectID: 12346,
|
||||
Action: "query",
|
||||
CredentialType: "query_key",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.2",
|
||||
Success: false,
|
||||
ResultCode: "QUERY_KEY_EXPIRED",
|
||||
},
|
||||
// query key请求
|
||||
{
|
||||
EventName: "AUTH-QUERY-KEY",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: 1003,
|
||||
TenantID: 2001,
|
||||
ObjectType: "query_key",
|
||||
ObjectID: 12347,
|
||||
Action: "query",
|
||||
CredentialType: "query_key",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.3",
|
||||
Success: false,
|
||||
ResultCode: "QUERY_KEY_EXPIRED",
|
||||
},
|
||||
// 非query key事件
|
||||
{
|
||||
EventName: "AUTH-TOKEN-OK",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "token",
|
||||
ObjectID: 12345,
|
||||
Action: "verify",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "AUTH_TOKEN_OK",
|
||||
},
|
||||
}
|
||||
|
||||
for _, e := range events {
|
||||
svc.CreateEvent(ctx, e)
|
||||
}
|
||||
|
||||
// 计算M-016指标
|
||||
now := time.Now()
|
||||
metric, err := metricsSvc.CalculateM016(ctx, now.Add(-24*time.Hour), now)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, metric)
|
||||
assert.Equal(t, "M-016", metric.MetricID)
|
||||
assert.Equal(t, "query_key_external_reject_rate_pct", metric.MetricName)
|
||||
// 2个拒绝 / 3个query key总请求 = 66.67%
|
||||
assert.InDelta(t, 66.67, metric.Value, 0.01)
|
||||
assert.Equal(t, "FAIL", metric.Status) // 拒绝率 < 100%,应该是FAIL
|
||||
}
|
||||
|
||||
func TestAuditMetrics_M016_DifferentFromM014(t *testing.T) {
|
||||
// M-014与M-016边界清晰:分母不同,无重叠
|
||||
// M-014 分母:经平台凭证校验的入站请求(platform_token)
|
||||
// M-016 分母:检测到的所有query key请求
|
||||
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
metricsSvc := NewMetricsService(svc)
|
||||
|
||||
// 场景:100个请求,80个使用platform_token,20个使用query key(被拒绝)
|
||||
// M-014 = 80/80 = 100%(分母只计算platform_token请求)
|
||||
// M-016 = 20/20 = 100%(分母计算所有query key请求)
|
||||
|
||||
// 创建80个platform_token请求
|
||||
for i := 0; i < 80; i++ {
|
||||
svc.CreateEvent(ctx, &model.AuditEvent{
|
||||
EventName: "CRED-INGRESS-PLATFORM",
|
||||
EventCategory: "CRED",
|
||||
EventSubCategory: "INGRESS",
|
||||
OperatorID: int64(1000 + i),
|
||||
TenantID: 2001,
|
||||
ObjectType: "account",
|
||||
ObjectID: int64(i),
|
||||
Action: "query",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "CRED_INGRESS_OK",
|
||||
})
|
||||
}
|
||||
|
||||
// 创建20个query key请求(全部被拒绝)
|
||||
for i := 0; i < 20; i++ {
|
||||
svc.CreateEvent(ctx, &model.AuditEvent{
|
||||
EventName: "AUTH-QUERY-REJECT",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: int64(2000 + i),
|
||||
TenantID: 2001,
|
||||
ObjectType: "query_key",
|
||||
ObjectID: int64(1000 + i),
|
||||
Action: "query",
|
||||
CredentialType: "query_key",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: false,
|
||||
ResultCode: "QUERY_KEY_NOT_ALLOWED",
|
||||
})
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 计算M-014
|
||||
m014, err := metricsSvc.CalculateM014(ctx, now.Add(-24*time.Hour), now)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 100.0, m014.Value) // 80/80 = 100%
|
||||
|
||||
// 计算M-016
|
||||
m016, err := metricsSvc.CalculateM016(ctx, now.Add(-24*time.Hour), now)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 100.0, m016.Value) // 20/20 = 100%
|
||||
}
|
||||
|
||||
func TestAuditMetrics_M013_ZeroExposure(t *testing.T) {
|
||||
// M-013: 当没有凭证暴露事件时,应该为0,状态PASS
|
||||
ctx := context.Background()
|
||||
svc := NewAuditService(NewInMemoryAuditStore())
|
||||
metricsSvc := NewMetricsService(svc)
|
||||
|
||||
// 创建一些正常事件,没有CRED-EXPOSE
|
||||
svc.CreateEvent(ctx, &model.AuditEvent{
|
||||
EventName: "AUTH-TOKEN-OK",
|
||||
EventCategory: "AUTH",
|
||||
OperatorID: 1001,
|
||||
TenantID: 2001,
|
||||
ObjectType: "token",
|
||||
ObjectID: 12345,
|
||||
Action: "verify",
|
||||
CredentialType: "platform_token",
|
||||
SourceType: "api",
|
||||
SourceIP: "192.168.1.1",
|
||||
Success: true,
|
||||
ResultCode: "AUTH_TOKEN_OK",
|
||||
})
|
||||
|
||||
now := time.Now()
|
||||
metric, err := metricsSvc.CalculateM013(ctx, now.Add(-24*time.Hour), now)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, float64(0), metric.Value)
|
||||
assert.Equal(t, "PASS", metric.Status)
|
||||
}
|
||||
507
supply-api/internal/iam/handler/iam_handler.go
Normal file
507
supply-api/internal/iam/handler/iam_handler.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/iam/service"
|
||||
)
|
||||
|
||||
// IAMHandler IAM HTTP处理器
|
||||
type IAMHandler struct {
|
||||
iamService service.IAMServiceInterface
|
||||
}
|
||||
|
||||
// NewIAMHandler 创建IAM处理器
|
||||
func NewIAMHandler(iamService service.IAMServiceInterface) *IAMHandler {
|
||||
return &IAMHandler{
|
||||
iamService: iamService,
|
||||
}
|
||||
}
|
||||
|
||||
// RoleResponse HTTP响应中的角色信息
|
||||
type RoleResponse struct {
|
||||
Code string `json:"role_code"`
|
||||
Name string `json:"role_name"`
|
||||
Type string `json:"role_type"`
|
||||
Level int `json:"level"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// CreateRoleRequest 创建角色请求
|
||||
type CreateRoleRequest struct {
|
||||
Code string `json:"code"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Level int `json:"level"`
|
||||
Scopes []string `json:"scopes"`
|
||||
}
|
||||
|
||||
// UpdateRoleRequest 更新角色请求
|
||||
type UpdateRoleRequest struct {
|
||||
Code string `json:"code"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Scopes []string `json:"scopes"`
|
||||
IsActive *bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// AssignRoleRequest 分配角色请求
|
||||
type AssignRoleRequest struct {
|
||||
RoleCode string `json:"role_code"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
// HTTPError HTTP错误响应
|
||||
type HTTPError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ErrorResponse 错误响应结构
|
||||
type ErrorResponse struct {
|
||||
Error HTTPError `json:"error"`
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册IAM路由
|
||||
func (h *IAMHandler) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("/api/v1/iam/roles", h.handleRoles)
|
||||
mux.HandleFunc("/api/v1/iam/roles/", h.handleRoleByCode)
|
||||
mux.HandleFunc("/api/v1/iam/scopes", h.handleScopes)
|
||||
mux.HandleFunc("/api/v1/iam/users/", h.handleUserRoles)
|
||||
mux.HandleFunc("/api/v1/iam/check-scope", h.handleCheckScope)
|
||||
}
|
||||
|
||||
// handleRoles 处理角色相关路由
|
||||
func (h *IAMHandler) handleRoles(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.ListRoles(w, r)
|
||||
case http.MethodPost:
|
||||
h.CreateRole(w, r)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// handleRoleByCode 处理单个角色路由
|
||||
func (h *IAMHandler) handleRoleByCode(w http.ResponseWriter, r *http.Request) {
|
||||
roleCode := extractRoleCode(r.URL.Path)
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.GetRole(w, r, roleCode)
|
||||
case http.MethodPut:
|
||||
h.UpdateRole(w, r, roleCode)
|
||||
case http.MethodDelete:
|
||||
h.DeleteRole(w, r, roleCode)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// handleScopes 处理Scope列表路由
|
||||
func (h *IAMHandler) handleScopes(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
h.ListScopes(w, r)
|
||||
}
|
||||
|
||||
// handleUserRoles 处理用户角色路由
|
||||
func (h *IAMHandler) handleUserRoles(w http.ResponseWriter, r *http.Request) {
|
||||
// 解析用户ID
|
||||
path := r.URL.Path
|
||||
userIDStr := extractUserID(path)
|
||||
userID, err := strconv.ParseInt(userIDStr, 10, 64)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_USER_ID", "invalid user id")
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.GetUserRoles(w, r, userID)
|
||||
case http.MethodPost:
|
||||
h.AssignRole(w, r, userID)
|
||||
case http.MethodDelete:
|
||||
roleCode := extractRoleCodeFromUserPath(path)
|
||||
tenantID := int64(0) // 从请求或context获取
|
||||
h.RevokeRole(w, r, userID, roleCode, tenantID)
|
||||
default:
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// handleCheckScope 处理检查Scope路由
|
||||
func (h *IAMHandler) handleCheckScope(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
h.CheckScope(w, r)
|
||||
}
|
||||
|
||||
// CreateRole 处理创建角色请求
|
||||
func (h *IAMHandler) CreateRole(w http.ResponseWriter, r *http.Request) {
|
||||
var req CreateRoleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_REQUEST", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if req.Code == "" {
|
||||
writeError(w, http.StatusBadRequest, "MISSING_CODE", "role code is required")
|
||||
return
|
||||
}
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "MISSING_NAME", "role name is required")
|
||||
return
|
||||
}
|
||||
if req.Type == "" {
|
||||
writeError(w, http.StatusBadRequest, "MISSING_TYPE", "role type is required")
|
||||
return
|
||||
}
|
||||
|
||||
serviceReq := &service.CreateRoleRequest{
|
||||
Code: req.Code,
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Level: req.Level,
|
||||
Scopes: req.Scopes,
|
||||
}
|
||||
|
||||
role, err := h.iamService.CreateRole(r.Context(), serviceReq)
|
||||
if err != nil {
|
||||
if err == service.ErrDuplicateRoleCode {
|
||||
writeError(w, http.StatusConflict, "DUPLICATE_ROLE_CODE", err.Error())
|
||||
return
|
||||
}
|
||||
if err == service.ErrInvalidRequest {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_REQUEST", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"role": toRoleResponse(role),
|
||||
})
|
||||
}
|
||||
|
||||
// GetRole 处理获取单个角色请求
|
||||
func (h *IAMHandler) GetRole(w http.ResponseWriter, r *http.Request, roleCode string) {
|
||||
role, err := h.iamService.GetRole(r.Context(), roleCode)
|
||||
if err != nil {
|
||||
if err == service.ErrRoleNotFound {
|
||||
writeError(w, http.StatusNotFound, "ROLE_NOT_FOUND", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"role": toRoleResponse(role),
|
||||
})
|
||||
}
|
||||
|
||||
// ListRoles 处理列出角色请求
|
||||
func (h *IAMHandler) ListRoles(w http.ResponseWriter, r *http.Request) {
|
||||
roleType := r.URL.Query().Get("type")
|
||||
|
||||
roles, err := h.iamService.ListRoles(r.Context(), roleType)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
roleResponses := make([]*RoleResponse, len(roles))
|
||||
for i, role := range roles {
|
||||
roleResponses[i] = toRoleResponse(role)
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"roles": roleResponses,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateRole 处理更新角色请求
|
||||
func (h *IAMHandler) UpdateRole(w http.ResponseWriter, r *http.Request, roleCode string) {
|
||||
var req UpdateRoleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_REQUEST", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
req.Code = roleCode // 确保使用URL中的roleCode
|
||||
|
||||
serviceReq := &service.UpdateRoleRequest{
|
||||
Code: req.Code,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Scopes: req.Scopes,
|
||||
IsActive: req.IsActive,
|
||||
}
|
||||
|
||||
role, err := h.iamService.UpdateRole(r.Context(), serviceReq)
|
||||
if err != nil {
|
||||
if err == service.ErrRoleNotFound {
|
||||
writeError(w, http.StatusNotFound, "ROLE_NOT_FOUND", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"role": toRoleResponse(role),
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteRole 处理删除角色请求
|
||||
func (h *IAMHandler) DeleteRole(w http.ResponseWriter, r *http.Request, roleCode string) {
|
||||
err := h.iamService.DeleteRole(r.Context(), roleCode)
|
||||
if err != nil {
|
||||
if err == service.ErrRoleNotFound {
|
||||
writeError(w, http.StatusNotFound, "ROLE_NOT_FOUND", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "role deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// ListScopes 处理列出所有Scope请求
|
||||
func (h *IAMHandler) ListScopes(w http.ResponseWriter, r *http.Request) {
|
||||
// 从预定义Scope列表获取
|
||||
scopes := []map[string]interface{}{
|
||||
{"scope_code": "platform:read", "scope_name": "读取平台配置", "scope_type": "platform"},
|
||||
{"scope_code": "platform:write", "scope_name": "修改平台配置", "scope_type": "platform"},
|
||||
{"scope_code": "platform:admin", "scope_name": "平台级管理", "scope_type": "platform"},
|
||||
{"scope_code": "tenant:read", "scope_name": "读取租户信息", "scope_type": "platform"},
|
||||
{"scope_code": "supply:account:read", "scope_name": "读取供应账号", "scope_type": "supply"},
|
||||
{"scope_code": "consumer:apikey:create", "scope_name": "创建API Key", "scope_type": "consumer"},
|
||||
{"scope_code": "router:invoke", "scope_name": "调用模型", "scope_type": "router"},
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"scopes": scopes,
|
||||
})
|
||||
}
|
||||
|
||||
// GetUserRoles 处理获取用户角色请求
|
||||
func (h *IAMHandler) GetUserRoles(w http.ResponseWriter, r *http.Request, userID int64) {
|
||||
roles, err := h.iamService.GetUserRoles(r.Context(), userID)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"roles": roles,
|
||||
})
|
||||
}
|
||||
|
||||
// AssignRole 处理分配角色请求
|
||||
func (h *IAMHandler) AssignRole(w http.ResponseWriter, r *http.Request, userID int64) {
|
||||
var req AssignRoleRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "INVALID_REQUEST", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
serviceReq := &service.AssignRoleRequest{
|
||||
UserID: userID,
|
||||
RoleCode: req.RoleCode,
|
||||
TenantID: req.TenantID,
|
||||
}
|
||||
|
||||
mapping, err := h.iamService.AssignRole(r.Context(), serviceReq)
|
||||
if err != nil {
|
||||
if err == service.ErrRoleNotFound {
|
||||
writeError(w, http.StatusNotFound, "ROLE_NOT_FOUND", err.Error())
|
||||
return
|
||||
}
|
||||
if err == service.ErrDuplicateAssignment {
|
||||
writeError(w, http.StatusConflict, "DUPLICATE_ASSIGNMENT", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"message": "role assigned successfully",
|
||||
"mapping": mapping,
|
||||
})
|
||||
}
|
||||
|
||||
// RevokeRole 处理撤销角色请求
|
||||
func (h *IAMHandler) RevokeRole(w http.ResponseWriter, r *http.Request, userID int64, roleCode string, tenantID int64) {
|
||||
err := h.iamService.RevokeRole(r.Context(), userID, roleCode, tenantID)
|
||||
if err != nil {
|
||||
if err == service.ErrRoleNotFound {
|
||||
writeError(w, http.StatusNotFound, "ROLE_NOT_FOUND", err.Error())
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "role revoked successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// CheckScope 处理检查Scope请求
|
||||
func (h *IAMHandler) CheckScope(w http.ResponseWriter, r *http.Request) {
|
||||
scope := r.URL.Query().Get("scope")
|
||||
if scope == "" {
|
||||
writeError(w, http.StatusBadRequest, "MISSING_SCOPE", "scope parameter is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 从context获取userID(实际应用中应从认证中间件获取)
|
||||
userID := int64(1) // 模拟
|
||||
|
||||
hasScope, err := h.iamService.CheckScope(r.Context(), userID, scope)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"has_scope": hasScope,
|
||||
"scope": scope,
|
||||
"user_id": userID,
|
||||
})
|
||||
}
|
||||
|
||||
// toRoleResponse 转换为RoleResponse
|
||||
func toRoleResponse(role *service.Role) *RoleResponse {
|
||||
return &RoleResponse{
|
||||
Code: role.Code,
|
||||
Name: role.Name,
|
||||
Type: role.Type,
|
||||
Level: role.Level,
|
||||
IsActive: role.IsActive,
|
||||
}
|
||||
}
|
||||
|
||||
// writeJSON 写入JSON响应
|
||||
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// writeError 写入错误响应
|
||||
func writeError(w http.ResponseWriter, status int, code, message string) {
|
||||
writeJSON(w, status, ErrorResponse{
|
||||
Error: HTTPError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// extractRoleCode 从URL路径提取角色代码
|
||||
func extractRoleCode(path string) string {
|
||||
// /api/v1/iam/roles/developer -> developer
|
||||
parts := splitPath(path)
|
||||
if len(parts) >= 5 {
|
||||
return parts[4]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractUserID 从URL路径提取用户ID
|
||||
func extractUserID(path string) string {
|
||||
// /api/v1/iam/users/123/roles -> 123
|
||||
parts := splitPath(path)
|
||||
if len(parts) >= 4 {
|
||||
return parts[3]
|
||||
}
|
||||
if len(parts) >= 6 {
|
||||
return parts[3]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractRoleCodeFromUserPath 从用户路径提取角色代码
|
||||
func extractRoleCodeFromUserPath(path string) string {
|
||||
// /api/v1/iam/users/123/roles/developer -> developer
|
||||
parts := splitPath(path)
|
||||
if len(parts) >= 6 {
|
||||
return parts[5]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// splitPath 分割URL路径
|
||||
func splitPath(path string) []string {
|
||||
var parts []string
|
||||
var current string
|
||||
for _, c := range path {
|
||||
if c == '/' {
|
||||
if current != "" {
|
||||
parts = append(parts, current)
|
||||
current = ""
|
||||
}
|
||||
} else {
|
||||
current += string(c)
|
||||
}
|
||||
}
|
||||
if current != "" {
|
||||
parts = append(parts, current)
|
||||
}
|
||||
return parts
|
||||
}
|
||||
|
||||
// RequireScope 返回一个要求特定Scope的中间件函数
|
||||
func RequireScope(scope string, iamService service.IAMServiceInterface) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 从context获取userID
|
||||
userID := getUserIDFromContext(r.Context())
|
||||
if userID == 0 {
|
||||
writeError(w, http.StatusUnauthorized, "UNAUTHORIZED", "user not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
hasScope, err := iamService.CheckScope(r.Context(), userID, scope)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !hasScope {
|
||||
writeError(w, http.StatusForbidden, "SCOPE_DENIED", "insufficient scope")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从context获取userID(实际应用中应从认证中间件获取)
|
||||
func getUserIDFromContext(ctx context.Context) int64 {
|
||||
// TODO: 从认证中间件获取真实的userID
|
||||
return 1
|
||||
}
|
||||
404
supply-api/internal/iam/handler/iam_handler_test.go
Normal file
404
supply-api/internal/iam/handler/iam_handler_test.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// 测试辅助函数
|
||||
|
||||
// testRoleResponse 用于测试的角色响应
|
||||
type testRoleResponse struct {
|
||||
Code string `json:"role_code"`
|
||||
Name string `json:"role_name"`
|
||||
Type string `json:"role_type"`
|
||||
Level int `json:"level"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// testIAMService 模拟IAM服务
|
||||
type testIAMService struct {
|
||||
roles map[string]*testRoleResponse
|
||||
userScopes map[int64][]string
|
||||
}
|
||||
|
||||
type testRoleResponse2 struct {
|
||||
Code string
|
||||
Name string
|
||||
Type string
|
||||
Level int
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
func newTestIAMService() *testIAMService {
|
||||
return &testIAMService{
|
||||
roles: map[string]*testRoleResponse{
|
||||
"viewer": {Code: "viewer", Name: "查看者", Type: "platform", Level: 10, IsActive: true},
|
||||
"operator": {Code: "operator", Name: "运维", Type: "platform", Level: 30, IsActive: true},
|
||||
},
|
||||
userScopes: map[int64][]string{
|
||||
1: {"platform:read", "platform:write"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testIAMService) CreateRole(req *CreateRoleHTTPRequest) (*testRoleResponse, error) {
|
||||
if _, exists := s.roles[req.Code]; exists {
|
||||
return nil, errDuplicateRole
|
||||
}
|
||||
return &testRoleResponse{
|
||||
Code: req.Code,
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Level: req.Level,
|
||||
IsActive: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *testIAMService) GetRole(roleCode string) (*testRoleResponse, error) {
|
||||
if role, exists := s.roles[roleCode]; exists {
|
||||
return role, nil
|
||||
}
|
||||
return nil, errNotFound
|
||||
}
|
||||
|
||||
func (s *testIAMService) ListRoles(roleType string) ([]*testRoleResponse, error) {
|
||||
var result []*testRoleResponse
|
||||
for _, role := range s.roles {
|
||||
if roleType == "" || role.Type == roleType {
|
||||
result = append(result, role)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *testIAMService) CheckScope(userID int64, scope string) bool {
|
||||
scopes, ok := s.userScopes[userID]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, s := range scopes {
|
||||
if s == scope || s == "*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HTTP请求/响应类型
|
||||
type CreateRoleHTTPRequest struct {
|
||||
Code string `json:"code"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Level int `json:"level"`
|
||||
Scopes []string `json:"scopes"`
|
||||
}
|
||||
|
||||
// 错误
|
||||
var (
|
||||
errNotFound = &HTTPErrorResponse{Code: "NOT_FOUND", Message: "not found"}
|
||||
errDuplicateRole = &HTTPErrorResponse{Code: "DUPLICATE", Message: "duplicate"}
|
||||
)
|
||||
|
||||
// HTTPErrorResponse HTTP错误响应
|
||||
type HTTPErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (e *HTTPErrorResponse) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// HTTPHandler 测试用的HTTP处理器
|
||||
type HTTPHandler struct {
|
||||
iam *testIAMService
|
||||
}
|
||||
|
||||
func newHTTPHandler() *HTTPHandler {
|
||||
return &HTTPHandler{iam: newTestIAMService()}
|
||||
}
|
||||
|
||||
// handleCreateRole 创建角色
|
||||
func (h *HTTPHandler) handleCreateRole(w http.ResponseWriter, r *http.Request) {
|
||||
var req CreateRoleHTTPRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeErrorHTTPTest(w, http.StatusBadRequest, "INVALID_REQUEST", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
role, err := h.iam.CreateRole(&req)
|
||||
if err != nil {
|
||||
writeErrorHTTPTest(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONHTTPTest(w, http.StatusCreated, map[string]interface{}{
|
||||
"role": role,
|
||||
})
|
||||
}
|
||||
|
||||
// handleListRoles 列出角色
|
||||
func (h *HTTPHandler) handleListRoles(w http.ResponseWriter, r *http.Request) {
|
||||
roleType := r.URL.Query().Get("type")
|
||||
|
||||
roles, err := h.iam.ListRoles(roleType)
|
||||
if err != nil {
|
||||
writeErrorHTTPTest(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONHTTPTest(w, http.StatusOK, map[string]interface{}{
|
||||
"roles": roles,
|
||||
})
|
||||
}
|
||||
|
||||
// handleGetRole 获取角色
|
||||
func (h *HTTPHandler) handleGetRole(w http.ResponseWriter, r *http.Request) {
|
||||
roleCode := r.URL.Query().Get("code")
|
||||
if roleCode == "" {
|
||||
writeErrorHTTPTest(w, http.StatusBadRequest, "MISSING_CODE", "role code is required")
|
||||
return
|
||||
}
|
||||
|
||||
role, err := h.iam.GetRole(roleCode)
|
||||
if err != nil {
|
||||
if err == errNotFound {
|
||||
writeErrorHTTPTest(w, http.StatusNotFound, "NOT_FOUND", err.Error())
|
||||
return
|
||||
}
|
||||
writeErrorHTTPTest(w, http.StatusInternalServerError, "INTERNAL_ERROR", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSONHTTPTest(w, http.StatusOK, map[string]interface{}{
|
||||
"role": role,
|
||||
})
|
||||
}
|
||||
|
||||
// handleCheckScope 检查Scope
|
||||
func (h *HTTPHandler) handleCheckScope(w http.ResponseWriter, r *http.Request) {
|
||||
scope := r.URL.Query().Get("scope")
|
||||
if scope == "" {
|
||||
writeErrorHTTPTest(w, http.StatusBadRequest, "MISSING_SCOPE", "scope is required")
|
||||
return
|
||||
}
|
||||
|
||||
userID := int64(1)
|
||||
hasScope := h.iam.CheckScope(userID, scope)
|
||||
|
||||
writeJSONHTTPTest(w, http.StatusOK, map[string]interface{}{
|
||||
"has_scope": hasScope,
|
||||
"scope": scope,
|
||||
})
|
||||
}
|
||||
|
||||
func writeJSONHTTPTest(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func writeErrorHTTPTest(w http.ResponseWriter, status int, code, message string) {
|
||||
writeJSONHTTPTest(w, status, map[string]interface{}{
|
||||
"error": map[string]string{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ==================== 测试用例 ====================
|
||||
|
||||
// TestHTTPHandler_CreateRole_Success 测试创建角色成功
|
||||
func TestHTTPHandler_CreateRole_Success(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
body := `{"code":"developer","name":"开发者","type":"platform","level":20}`
|
||||
req := httptest.NewRequest("POST", "/api/v1/iam/roles", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleCreateRole(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusCreated, rec.Code)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
|
||||
role := resp["role"].(map[string]interface{})
|
||||
assert.Equal(t, "developer", role["role_code"])
|
||||
assert.Equal(t, "开发者", role["role_name"])
|
||||
}
|
||||
|
||||
// TestHTTPHandler_ListRoles_Success 测试列出角色成功
|
||||
func TestHTTPHandler_ListRoles_Success(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/roles", nil)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleListRoles(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
|
||||
roles := resp["roles"].([]interface{})
|
||||
assert.Len(t, roles, 2)
|
||||
}
|
||||
|
||||
// TestHTTPHandler_ListRoles_WithType 测试按类型列出角色
|
||||
func TestHTTPHandler_ListRoles_WithType(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/roles?type=platform", nil)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleListRoles(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestHTTPHandler_GetRole_Success 测试获取角色成功
|
||||
func TestHTTPHandler_GetRole_Success(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/roles?code=viewer", nil)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleGetRole(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
|
||||
role := resp["role"].(map[string]interface{})
|
||||
assert.Equal(t, "viewer", role["role_code"])
|
||||
}
|
||||
|
||||
// TestHTTPHandler_GetRole_NotFound 测试获取不存在的角色
|
||||
func TestHTTPHandler_GetRole_NotFound(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/roles?code=nonexistent", nil)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleGetRole(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusNotFound, rec.Code)
|
||||
}
|
||||
|
||||
// TestHTTPHandler_CheckScope_HasScope 测试检查Scope存在
|
||||
func TestHTTPHandler_CheckScope_HasScope(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/check-scope?scope=platform:read", nil)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleCheckScope(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
|
||||
assert.Equal(t, true, resp["has_scope"])
|
||||
assert.Equal(t, "platform:read", resp["scope"])
|
||||
}
|
||||
|
||||
// TestHTTPHandler_CheckScope_NoScope 测试检查Scope不存在
|
||||
func TestHTTPHandler_CheckScope_NoScope(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/check-scope?scope=platform:admin", nil)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleCheckScope(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
|
||||
var resp map[string]interface{}
|
||||
json.Unmarshal(rec.Body.Bytes(), &resp)
|
||||
|
||||
assert.Equal(t, false, resp["has_scope"])
|
||||
}
|
||||
|
||||
// TestHTTPHandler_CheckScope_MissingScope 测试缺少Scope参数
|
||||
func TestHTTPHandler_CheckScope_MissingScope(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/check-scope", nil)
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleCheckScope(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHTTPHandler_CreateRole_InvalidJSON 测试无效JSON
|
||||
func TestHTTPHandler_CreateRole_InvalidJSON(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
body := `invalid json`
|
||||
req := httptest.NewRequest("POST", "/api/v1/iam/roles", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleCreateRole(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// TestHTTPHandler_GetRole_MissingCode 测试缺少角色代码
|
||||
func TestHTTPHandler_GetRole_MissingCode(t *testing.T) {
|
||||
// arrange
|
||||
handler := newHTTPHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/iam/roles", nil) // 没有code参数
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
handler.handleGetRole(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
||||
}
|
||||
|
||||
// 确保函数被使用(避免编译错误)
|
||||
var _ = context.Background
|
||||
296
supply-api/internal/iam/middleware/role_inheritance_test.go
Normal file
296
supply-api/internal/iam/middleware/role_inheritance_test.go
Normal file
@@ -0,0 +1,296 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestRoleInheritance_OperatorInheritsViewer 测试运维人员继承查看者
|
||||
func TestRoleInheritance_OperatorInheritsViewer(t *testing.T) {
|
||||
// arrange
|
||||
// operator 显式配置拥有 viewer 所有 scope + platform:write 等
|
||||
operatorScopes := []string{"platform:read", "platform:write", "tenant:read", "tenant:write", "billing:read"}
|
||||
viewerScopes := []string{"platform:read", "tenant:read", "billing:read"}
|
||||
|
||||
operatorClaims := &IAMTokenClaims{
|
||||
SubjectID: "user:1",
|
||||
Role: "operator",
|
||||
Scope: operatorScopes,
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *operatorClaims)
|
||||
|
||||
// act & assert - operator 应该拥有 viewer 的所有 scope
|
||||
for _, viewerScope := range viewerScopes {
|
||||
assert.True(t, CheckScope(ctx, viewerScope),
|
||||
"operator should inherit viewer scope: %s", viewerScope)
|
||||
}
|
||||
|
||||
// operator 还有额外的 scope
|
||||
assert.True(t, CheckScope(ctx, "platform:write"))
|
||||
assert.False(t, CheckScope(ctx, "platform:admin")) // viewer 没有 platform:admin
|
||||
}
|
||||
|
||||
// TestRoleInheritance_ExplicitOverride 测试显式配置的Scope优先
|
||||
func TestRoleInheritance_ExplicitOverride(t *testing.T) {
|
||||
// arrange
|
||||
// org_admin 显式配置拥有 operator + finops + developer + viewer 所有 scope
|
||||
orgAdminScopes := []string{
|
||||
// viewer scopes
|
||||
"platform:read", "tenant:read", "billing:read",
|
||||
// operator scopes
|
||||
"platform:write", "tenant:write",
|
||||
// finops scopes
|
||||
"billing:write",
|
||||
// developer scopes
|
||||
"router:model:list",
|
||||
// org_admin 自身 scope
|
||||
"platform:admin", "tenant:member:manage",
|
||||
}
|
||||
|
||||
orgAdminClaims := &IAMTokenClaims{
|
||||
SubjectID: "user:2",
|
||||
Role: "org_admin",
|
||||
Scope: orgAdminScopes,
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *orgAdminClaims)
|
||||
|
||||
// act & assert - org_admin 应该拥有所有子角色的 scope
|
||||
assert.True(t, CheckScope(ctx, "platform:read")) // viewer
|
||||
assert.True(t, CheckScope(ctx, "tenant:read")) // viewer
|
||||
assert.True(t, CheckScope(ctx, "billing:read")) // viewer/finops
|
||||
assert.True(t, CheckScope(ctx, "platform:write")) // operator
|
||||
assert.True(t, CheckScope(ctx, "tenant:write")) // operator
|
||||
assert.True(t, CheckScope(ctx, "billing:write")) // finops
|
||||
assert.True(t, CheckScope(ctx, "router:model:list")) // developer
|
||||
assert.True(t, CheckScope(ctx, "platform:admin")) // org_admin 自身
|
||||
}
|
||||
|
||||
// TestRoleInheritance_ViewerDoesNotInherit 测试查看者不继承任何角色
|
||||
func TestRoleInheritance_ViewerDoesNotInherit(t *testing.T) {
|
||||
// arrange
|
||||
viewerScopes := []string{"platform:read", "tenant:read", "billing:read"}
|
||||
|
||||
viewerClaims := &IAMTokenClaims{
|
||||
SubjectID: "user:3",
|
||||
Role: "viewer",
|
||||
Scope: viewerScopes,
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *viewerClaims)
|
||||
|
||||
// act & assert - viewer 是基础角色,不继承任何角色
|
||||
assert.True(t, CheckScope(ctx, "platform:read"))
|
||||
assert.False(t, CheckScope(ctx, "platform:write")) // viewer 没有 write
|
||||
assert.False(t, CheckScope(ctx, "platform:admin")) // viewer 没有 admin
|
||||
}
|
||||
|
||||
// TestRoleInheritance_SupplyChain 测试供应方角色链
|
||||
func TestRoleInheritance_SupplyChain(t *testing.T) {
|
||||
// arrange
|
||||
// supply_admin > supply_operator > supply_viewer
|
||||
supplyViewerScopes := []string{"supply:account:read", "supply:package:read"}
|
||||
supplyOperatorScopes := []string{"supply:account:read", "supply:account:write", "supply:package:read", "supply:package:write", "supply:package:publish"}
|
||||
supplyAdminScopes := []string{"supply:account:read", "supply:account:write", "supply:package:read", "supply:package:write", "supply:package:publish", "supply:package:offline", "supply:settlement:withdraw"}
|
||||
|
||||
// supply_viewer 测试
|
||||
viewerCtx := context.WithValue(context.Background(), IAMTokenClaimsKey, IAMTokenClaims{
|
||||
SubjectID: "user:4",
|
||||
Role: "supply_viewer",
|
||||
Scope: supplyViewerScopes,
|
||||
TenantID: 1,
|
||||
})
|
||||
|
||||
// act & assert
|
||||
assert.True(t, CheckScope(viewerCtx, "supply:account:read"))
|
||||
assert.False(t, CheckScope(viewerCtx, "supply:account:write"))
|
||||
|
||||
// supply_operator 测试
|
||||
operatorCtx := context.WithValue(context.Background(), IAMTokenClaimsKey, IAMTokenClaims{
|
||||
SubjectID: "user:5",
|
||||
Role: "supply_operator",
|
||||
Scope: supplyOperatorScopes,
|
||||
TenantID: 1,
|
||||
})
|
||||
|
||||
// act & assert - operator 继承 viewer
|
||||
assert.True(t, CheckScope(operatorCtx, "supply:account:read"))
|
||||
assert.True(t, CheckScope(operatorCtx, "supply:account:write"))
|
||||
assert.False(t, CheckScope(operatorCtx, "supply:settlement:withdraw")) // operator 没有 withdraw
|
||||
|
||||
// supply_admin 测试
|
||||
adminCtx := context.WithValue(context.Background(), IAMTokenClaimsKey, IAMTokenClaims{
|
||||
SubjectID: "user:6",
|
||||
Role: "supply_admin",
|
||||
Scope: supplyAdminScopes,
|
||||
TenantID: 1,
|
||||
})
|
||||
|
||||
// act & assert - admin 继承所有
|
||||
assert.True(t, CheckScope(adminCtx, "supply:account:read"))
|
||||
assert.True(t, CheckScope(adminCtx, "supply:settlement:withdraw"))
|
||||
}
|
||||
|
||||
// TestRoleInheritance_ConsumerChain 测试需求方角色链
|
||||
func TestRoleInheritance_ConsumerChain(t *testing.T) {
|
||||
// arrange
|
||||
// consumer_admin > consumer_operator > consumer_viewer
|
||||
consumerViewerScopes := []string{"consumer:account:read", "consumer:apikey:read", "consumer:usage:read"}
|
||||
consumerOperatorScopes := []string{"consumer:account:read", "consumer:account:write", "consumer:apikey:read", "consumer:apikey:create", "consumer:apikey:revoke", "consumer:usage:read"}
|
||||
consumerAdminScopes := []string{"consumer:account:read", "consumer:account:write", "consumer:apikey:read", "consumer:apikey:create", "consumer:apikey:revoke", "consumer:usage:read"}
|
||||
|
||||
// consumer_viewer 测试
|
||||
viewerCtx := context.WithValue(context.Background(), IAMTokenClaimsKey, IAMTokenClaims{
|
||||
SubjectID: "user:7",
|
||||
Role: "consumer_viewer",
|
||||
Scope: consumerViewerScopes,
|
||||
TenantID: 1,
|
||||
})
|
||||
|
||||
// act & assert
|
||||
assert.True(t, CheckScope(viewerCtx, "consumer:account:read"))
|
||||
assert.True(t, CheckScope(viewerCtx, "consumer:usage:read"))
|
||||
assert.False(t, CheckScope(viewerCtx, "consumer:apikey:create"))
|
||||
|
||||
// consumer_operator 测试
|
||||
operatorCtx := context.WithValue(context.Background(), IAMTokenClaimsKey, IAMTokenClaims{
|
||||
SubjectID: "user:8",
|
||||
Role: "consumer_operator",
|
||||
Scope: consumerOperatorScopes,
|
||||
TenantID: 1,
|
||||
})
|
||||
|
||||
// act & assert - operator 继承 viewer
|
||||
assert.True(t, CheckScope(operatorCtx, "consumer:apikey:create"))
|
||||
assert.True(t, CheckScope(operatorCtx, "consumer:apikey:revoke"))
|
||||
|
||||
// consumer_admin 测试
|
||||
adminCtx := context.WithValue(context.Background(), IAMTokenClaimsKey, IAMTokenClaims{
|
||||
SubjectID: "user:9",
|
||||
Role: "consumer_admin",
|
||||
Scope: consumerAdminScopes,
|
||||
TenantID: 1,
|
||||
})
|
||||
|
||||
// act & assert - admin 继承所有
|
||||
assert.True(t, CheckScope(adminCtx, "consumer:account:read"))
|
||||
assert.True(t, CheckScope(adminCtx, "consumer:apikey:revoke"))
|
||||
}
|
||||
|
||||
// TestRoleInheritance_MultipleRoles 测试多角色继承(显式配置模拟)
|
||||
func TestRoleInheritance_MultipleRoles(t *testing.T) {
|
||||
// arrange
|
||||
// 假设用户同时拥有 developer 和 finops 角色(通过 scope 累加)
|
||||
combinedScopes := []string{
|
||||
// viewer scopes
|
||||
"platform:read", "tenant:read", "billing:read",
|
||||
// developer scopes
|
||||
"router:model:list", "router:invoke",
|
||||
// finops scopes
|
||||
"billing:write",
|
||||
}
|
||||
|
||||
combinedClaims := &IAMTokenClaims{
|
||||
SubjectID: "user:10",
|
||||
Role: "developer", // 主角色
|
||||
Scope: combinedScopes,
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *combinedClaims)
|
||||
|
||||
// act & assert
|
||||
assert.True(t, CheckScope(ctx, "platform:read")) // viewer
|
||||
assert.True(t, CheckScope(ctx, "billing:read")) // viewer
|
||||
assert.True(t, CheckScope(ctx, "router:model:list")) // developer
|
||||
assert.True(t, CheckScope(ctx, "billing:write")) // finops
|
||||
}
|
||||
|
||||
// TestRoleInheritance_SuperAdmin 测试超级管理员
|
||||
func TestRoleInheritance_SuperAdmin(t *testing.T) {
|
||||
// arrange
|
||||
superAdminClaims := &IAMTokenClaims{
|
||||
SubjectID: "user:11",
|
||||
Role: "super_admin",
|
||||
Scope: []string{"*"}, // 通配符拥有所有权限
|
||||
TenantID: 0,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *superAdminClaims)
|
||||
|
||||
// act & assert - super_admin 拥有所有 scope
|
||||
assert.True(t, CheckScope(ctx, "platform:read"))
|
||||
assert.True(t, CheckScope(ctx, "platform:admin"))
|
||||
assert.True(t, CheckScope(ctx, "supply:account:write"))
|
||||
assert.True(t, CheckScope(ctx, "consumer:apikey:create"))
|
||||
assert.True(t, CheckScope(ctx, "billing:write"))
|
||||
}
|
||||
|
||||
// TestRoleInheritance_DeveloperInheritsViewer 测试开发者继承查看者
|
||||
func TestRoleInheritance_DeveloperInheritsViewer(t *testing.T) {
|
||||
// arrange
|
||||
developerScopes := []string{"platform:read", "tenant:read", "billing:read", "router:invoke", "router:model:list"}
|
||||
|
||||
developerClaims := &IAMTokenClaims{
|
||||
SubjectID: "user:12",
|
||||
Role: "developer",
|
||||
Scope: developerScopes,
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *developerClaims)
|
||||
|
||||
// act & assert - developer 继承 viewer 的所有 scope
|
||||
assert.True(t, CheckScope(ctx, "platform:read"))
|
||||
assert.True(t, CheckScope(ctx, "tenant:read"))
|
||||
assert.True(t, CheckScope(ctx, "billing:read"))
|
||||
assert.True(t, CheckScope(ctx, "router:invoke")) // developer 自身 scope
|
||||
assert.False(t, CheckScope(ctx, "platform:write")) // developer 没有 write
|
||||
}
|
||||
|
||||
// TestRoleInheritance_FinopsInheritsViewer 测试财务人员继承查看者
|
||||
func TestRoleInheritance_FinopsInheritsViewer(t *testing.T) {
|
||||
// arrange
|
||||
finopsScopes := []string{"platform:read", "tenant:read", "billing:read", "billing:write"}
|
||||
|
||||
finopsClaims := &IAMTokenClaims{
|
||||
SubjectID: "user:13",
|
||||
Role: "finops",
|
||||
Scope: finopsScopes,
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *finopsClaims)
|
||||
|
||||
// act & assert - finops 继承 viewer 的所有 scope
|
||||
assert.True(t, CheckScope(ctx, "platform:read"))
|
||||
assert.True(t, CheckScope(ctx, "tenant:read"))
|
||||
assert.True(t, CheckScope(ctx, "billing:read"))
|
||||
assert.True(t, CheckScope(ctx, "billing:write")) // finops 自身 scope
|
||||
assert.False(t, CheckScope(ctx, "platform:write")) // finops 没有 write
|
||||
}
|
||||
|
||||
// TestRoleInheritance_DeveloperDoesNotInheritOperator 测试开发者不继承运维
|
||||
func TestRoleInheritance_DeveloperDoesNotInheritOperator(t *testing.T) {
|
||||
// arrange
|
||||
developerScopes := []string{"platform:read", "tenant:read", "billing:read", "router:invoke", "router:model:list"}
|
||||
|
||||
developerClaims := &IAMTokenClaims{
|
||||
SubjectID: "user:14",
|
||||
Role: "developer",
|
||||
Scope: developerScopes,
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *developerClaims)
|
||||
|
||||
// act & assert - developer 不继承 operator 的 scope
|
||||
assert.False(t, CheckScope(ctx, "platform:write")) // operator 有,developer 没有
|
||||
assert.False(t, CheckScope(ctx, "tenant:write")) // operator 有,developer 没有
|
||||
}
|
||||
350
supply-api/internal/iam/middleware/scope_auth.go
Normal file
350
supply-api/internal/iam/middleware/scope_auth.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/middleware"
|
||||
)
|
||||
|
||||
// IAM token claims context key
|
||||
type iamContextKey string
|
||||
|
||||
const (
|
||||
// IAMTokenClaimsKey 用于在context中存储token claims
|
||||
IAMTokenClaimsKey iamContextKey = "iam_token_claims"
|
||||
)
|
||||
|
||||
// IAMTokenClaims IAM扩展Token Claims
|
||||
type IAMTokenClaims struct {
|
||||
SubjectID string `json:"subject_id"`
|
||||
Role string `json:"role"`
|
||||
Scope []string `json:"scope"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
UserType string `json:"user_type"` // 用户类型: platform/supply/consumer
|
||||
Permissions []string `json:"permissions"` // 细粒度权限列表
|
||||
}
|
||||
|
||||
// ScopeAuthMiddleware Scope权限验证中间件
|
||||
type ScopeAuthMiddleware struct {
|
||||
// 路由-Scope映射
|
||||
routeScopePolicies map[string][]string
|
||||
// 角色层级
|
||||
roleHierarchy map[string]int
|
||||
}
|
||||
|
||||
// NewScopeAuthMiddleware 创建Scope权限验证中间件
|
||||
func NewScopeAuthMiddleware() *ScopeAuthMiddleware {
|
||||
return &ScopeAuthMiddleware{
|
||||
routeScopePolicies: make(map[string][]string),
|
||||
roleHierarchy: map[string]int{
|
||||
"super_admin": 100,
|
||||
"org_admin": 50,
|
||||
"supply_admin": 40,
|
||||
"consumer_admin": 40,
|
||||
"operator": 30,
|
||||
"developer": 20,
|
||||
"finops": 20,
|
||||
"supply_operator": 30,
|
||||
"supply_finops": 20,
|
||||
"supply_viewer": 10,
|
||||
"consumer_operator": 30,
|
||||
"consumer_viewer": 10,
|
||||
"viewer": 10,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// SetRouteScopePolicy 设置路由的Scope要求
|
||||
func (m *ScopeAuthMiddleware) SetRouteScopePolicy(route string, scopes []string) {
|
||||
m.routeScopePolicies[route] = scopes
|
||||
}
|
||||
|
||||
// CheckScope 检查是否拥有指定Scope
|
||||
func CheckScope(ctx context.Context, requiredScope string) bool {
|
||||
claims := getIAMTokenClaims(ctx)
|
||||
if claims == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 空scope直接通过
|
||||
if requiredScope == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
return hasScope(claims.Scope, requiredScope)
|
||||
}
|
||||
|
||||
// CheckAllScopes 检查是否拥有所有指定Scope
|
||||
func CheckAllScopes(ctx context.Context, requiredScopes []string) bool {
|
||||
claims := getIAMTokenClaims(ctx)
|
||||
if claims == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 空列表直接通过
|
||||
if len(requiredScopes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, scope := range requiredScopes {
|
||||
if !hasScope(claims.Scope, scope) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// CheckAnyScope 检查是否拥有任一指定Scope
|
||||
func CheckAnyScope(ctx context.Context, requiredScopes []string) bool {
|
||||
claims := getIAMTokenClaims(ctx)
|
||||
if claims == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 空列表直接通过
|
||||
if len(requiredScopes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, scope := range requiredScopes {
|
||||
if hasScope(claims.Scope, scope) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasRole 检查是否拥有指定角色
|
||||
func HasRole(ctx context.Context, requiredRole string) bool {
|
||||
claims := getIAMTokenClaims(ctx)
|
||||
if claims == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return claims.Role == requiredRole
|
||||
}
|
||||
|
||||
// HasRoleLevel 检查角色层级是否满足要求
|
||||
func HasRoleLevel(ctx context.Context, minLevel int) bool {
|
||||
claims := getIAMTokenClaims(ctx)
|
||||
if claims == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
level := GetRoleLevel(claims.Role)
|
||||
return level >= minLevel
|
||||
}
|
||||
|
||||
// GetRoleLevel 获取角色层级数值
|
||||
func GetRoleLevel(role string) int {
|
||||
hierarchy := map[string]int{
|
||||
"super_admin": 100,
|
||||
"org_admin": 50,
|
||||
"supply_admin": 40,
|
||||
"consumer_admin": 40,
|
||||
"operator": 30,
|
||||
"developer": 20,
|
||||
"finops": 20,
|
||||
"supply_operator": 30,
|
||||
"supply_finops": 20,
|
||||
"supply_viewer": 10,
|
||||
"consumer_operator": 30,
|
||||
"consumer_viewer": 10,
|
||||
"viewer": 10,
|
||||
}
|
||||
|
||||
if level, ok := hierarchy[role]; ok {
|
||||
return level
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// GetIAMTokenClaims 获取IAM Token Claims
|
||||
func GetIAMTokenClaims(ctx context.Context) *IAMTokenClaims {
|
||||
if claims, ok := ctx.Value(IAMTokenClaimsKey).(IAMTokenClaims); ok {
|
||||
return &claims
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getIAMTokenClaims 内部获取IAM Token Claims
|
||||
func getIAMTokenClaims(ctx context.Context) *IAMTokenClaims {
|
||||
if claims, ok := ctx.Value(IAMTokenClaimsKey).(IAMTokenClaims); ok {
|
||||
return &claims
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasScope 检查scope列表是否包含目标scope
|
||||
func hasScope(scopes []string, target string) bool {
|
||||
for _, scope := range scopes {
|
||||
if scope == target || scope == "*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RequireScope 返回一个要求特定Scope的中间件
|
||||
func (m *ScopeAuthMiddleware) RequireScope(requiredScope string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := getIAMTokenClaims(r.Context())
|
||||
|
||||
if claims == nil {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING",
|
||||
"authentication context is missing")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查scope
|
||||
if requiredScope != "" && !hasScope(claims.Scope, requiredScope) {
|
||||
writeAuthError(w, http.StatusForbidden, "AUTH_SCOPE_DENIED",
|
||||
"required scope is not granted")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAllScopes 返回一个要求所有指定Scope的中间件
|
||||
func (m *ScopeAuthMiddleware) RequireAllScopes(requiredScopes []string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := getIAMTokenClaims(r.Context())
|
||||
|
||||
if claims == nil {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING",
|
||||
"authentication context is missing")
|
||||
return
|
||||
}
|
||||
|
||||
for _, scope := range requiredScopes {
|
||||
if !hasScope(claims.Scope, scope) {
|
||||
writeAuthError(w, http.StatusForbidden, "AUTH_SCOPE_DENIED",
|
||||
"required scope is not granted")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAnyScope 返回一个要求任一指定Scope的中间件
|
||||
func (m *ScopeAuthMiddleware) RequireAnyScope(requiredScopes []string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := getIAMTokenClaims(r.Context())
|
||||
|
||||
if claims == nil {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING",
|
||||
"authentication context is missing")
|
||||
return
|
||||
}
|
||||
|
||||
// 空列表直接通过
|
||||
if len(requiredScopes) > 0 && !hasAnyScope(claims.Scope, requiredScopes) {
|
||||
writeAuthError(w, http.StatusForbidden, "AUTH_SCOPE_DENIED",
|
||||
"none of the required scopes are granted")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole 返回一个要求特定角色的中间件
|
||||
func (m *ScopeAuthMiddleware) RequireRole(requiredRole string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := getIAMTokenClaims(r.Context())
|
||||
|
||||
if claims == nil {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING",
|
||||
"authentication context is missing")
|
||||
return
|
||||
}
|
||||
|
||||
if claims.Role != requiredRole {
|
||||
writeAuthError(w, http.StatusForbidden, "AUTH_ROLE_DENIED",
|
||||
"required role is not granted")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireMinLevel 返回一个要求最小角色层级的中间件
|
||||
func (m *ScopeAuthMiddleware) RequireMinLevel(minLevel int) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := getIAMTokenClaims(r.Context())
|
||||
|
||||
if claims == nil {
|
||||
writeAuthError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING",
|
||||
"authentication context is missing")
|
||||
return
|
||||
}
|
||||
|
||||
level := GetRoleLevel(claims.Role)
|
||||
if level < minLevel {
|
||||
writeAuthError(w, http.StatusForbidden, "AUTH_ROLE_LEVEL_DENIED",
|
||||
"insufficient role level")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// hasAnyScope 检查scope列表是否包含任一目标scope
|
||||
func hasAnyScope(scopes, targets []string) bool {
|
||||
for _, scope := range scopes {
|
||||
for _, target := range targets {
|
||||
if scope == target || scope == "*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// writeAuthError 写入鉴权错误
|
||||
func writeAuthError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
resp := map[string]interface{}{
|
||||
"error": map[string]string{
|
||||
"code": code,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
_ = resp
|
||||
}
|
||||
|
||||
// WithIAMClaims 设置IAM Claims到Context
|
||||
func WithIAMClaims(ctx context.Context, claims *IAMTokenClaims) context.Context {
|
||||
return context.WithValue(ctx, IAMTokenClaimsKey, *claims)
|
||||
}
|
||||
|
||||
// GetClaimsFromLegacy 从原有middleware.TokenClaims转换为IAMTokenClaims
|
||||
func GetClaimsFromLegacy(legacy *middleware.TokenClaims) *IAMTokenClaims {
|
||||
if legacy == nil {
|
||||
return nil
|
||||
}
|
||||
return &IAMTokenClaims{
|
||||
SubjectID: legacy.SubjectID,
|
||||
Role: legacy.Role,
|
||||
Scope: legacy.Scope,
|
||||
TenantID: legacy.TenantID,
|
||||
}
|
||||
}
|
||||
439
supply-api/internal/iam/middleware/scope_auth_test.go
Normal file
439
supply-api/internal/iam/middleware/scope_auth_test.go
Normal file
@@ -0,0 +1,439 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"lijiaoqiao/supply-api/internal/middleware"
|
||||
)
|
||||
|
||||
// TestScopeAuth_CheckScope_SuperAdminHasAllScopes 测试超级管理员拥有所有Scope
|
||||
func TestScopeAuth_CheckScope_SuperAdminHasAllScopes(t *testing.T) {
|
||||
// arrange
|
||||
// 创建超级管理员token claims
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:1",
|
||||
Role: "super_admin",
|
||||
Scope: []string{"*"}, // 通配符Scope代表所有权限
|
||||
TenantID: 0,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *claims)
|
||||
|
||||
// act
|
||||
hasScope := CheckScope(ctx, "platform:read")
|
||||
hasScope2 := CheckScope(ctx, "supply:account:write")
|
||||
hasScope3 := CheckScope(ctx, "consumer:apikey:create")
|
||||
|
||||
// assert
|
||||
assert.True(t, hasScope, "super_admin should have platform:read")
|
||||
assert.True(t, hasScope2, "super_admin should have supply:account:write")
|
||||
assert.True(t, hasScope3, "super_admin should have consumer:apikey:create")
|
||||
}
|
||||
|
||||
// TestScopeAuth_CheckScope_ViewerHasReadOnly 测试Viewer只有只读权限
|
||||
func TestScopeAuth_CheckScope_ViewerHasReadOnly(t *testing.T) {
|
||||
// arrange
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:2",
|
||||
Role: "viewer",
|
||||
Scope: []string{"platform:read", "tenant:read", "billing:read"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *claims)
|
||||
|
||||
// act & assert
|
||||
assert.True(t, CheckScope(ctx, "platform:read"), "viewer should have platform:read")
|
||||
assert.True(t, CheckScope(ctx, "tenant:read"), "viewer should have tenant:read")
|
||||
assert.True(t, CheckScope(ctx, "billing:read"), "viewer should have billing:read")
|
||||
|
||||
assert.False(t, CheckScope(ctx, "platform:write"), "viewer should NOT have platform:write")
|
||||
assert.False(t, CheckScope(ctx, "tenant:write"), "viewer should NOT have tenant:write")
|
||||
assert.False(t, CheckScope(ctx, "supply:account:write"), "viewer should NOT have supply:account:write")
|
||||
}
|
||||
|
||||
// TestScopeAuth_CheckScope_Denied 测试Scope被拒绝
|
||||
func TestScopeAuth_CheckScope_Denied(t *testing.T) {
|
||||
// arrange
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:3",
|
||||
Role: "viewer",
|
||||
Scope: []string{"platform:read"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *claims)
|
||||
|
||||
// act & assert
|
||||
assert.False(t, CheckScope(ctx, "platform:write"), "viewer should NOT have platform:write")
|
||||
assert.False(t, CheckScope(ctx, "supply:account:write"), "viewer should NOT have supply:account:write")
|
||||
}
|
||||
|
||||
// TestScopeAuth_CheckScope_MissingTokenClaims 测试缺少Token Claims
|
||||
func TestScopeAuth_CheckScope_MissingTokenClaims(t *testing.T) {
|
||||
// arrange
|
||||
ctx := context.Background() // 没有token claims
|
||||
|
||||
// act
|
||||
hasScope := CheckScope(ctx, "platform:read")
|
||||
|
||||
// assert
|
||||
assert.False(t, hasScope, "should return false when token claims are missing")
|
||||
}
|
||||
|
||||
// TestScopeAuth_CheckScope_EmptyScope 测试空Scope要求
|
||||
func TestScopeAuth_CheckScope_EmptyScope(t *testing.T) {
|
||||
// arrange
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:4",
|
||||
Role: "viewer",
|
||||
Scope: []string{"platform:read"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *claims)
|
||||
|
||||
// act
|
||||
hasEmptyScope := CheckScope(ctx, "")
|
||||
|
||||
// assert
|
||||
assert.True(t, hasEmptyScope, "empty scope should always pass")
|
||||
}
|
||||
|
||||
// TestScopeAuth_CheckMultipleScopes 测试检查多个Scope(需要全部满足)
|
||||
func TestScopeAuth_CheckMultipleScopes(t *testing.T) {
|
||||
// arrange
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:5",
|
||||
Role: "operator",
|
||||
Scope: []string{"platform:read", "platform:write", "tenant:read", "tenant:write"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *claims)
|
||||
|
||||
// act & assert
|
||||
assert.True(t, CheckAllScopes(ctx, []string{"platform:read", "platform:write"}), "operator should have both read and write")
|
||||
assert.True(t, CheckAllScopes(ctx, []string{"tenant:read", "tenant:write"}), "operator should have both tenant scopes")
|
||||
assert.False(t, CheckAllScopes(ctx, []string{"platform:read", "platform:admin"}), "operator should NOT have platform:admin")
|
||||
}
|
||||
|
||||
// TestScopeAuth_CheckAnyScope 测试检查多个Scope(只需满足其一)
|
||||
func TestScopeAuth_CheckAnyScope(t *testing.T) {
|
||||
// arrange
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:6",
|
||||
Role: "viewer",
|
||||
Scope: []string{"platform:read"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *claims)
|
||||
|
||||
// act & assert
|
||||
assert.True(t, CheckAnyScope(ctx, []string{"platform:read", "platform:write"}), "should pass with one matching scope")
|
||||
assert.False(t, CheckAnyScope(ctx, []string{"platform:write", "platform:admin"}), "should fail when no scopes match")
|
||||
assert.True(t, CheckAnyScope(ctx, []string{}), "empty scope list should pass")
|
||||
}
|
||||
|
||||
// TestScopeAuth_GetIAMTokenClaims 测试从Context获取IAMTokenClaims
|
||||
func TestScopeAuth_GetIAMTokenClaims(t *testing.T) {
|
||||
// arrange
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:7",
|
||||
Role: "org_admin",
|
||||
Scope: []string{"platform:read", "platform:write"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *claims)
|
||||
|
||||
// act
|
||||
retrievedClaims := GetIAMTokenClaims(ctx)
|
||||
|
||||
// assert
|
||||
assert.NotNil(t, retrievedClaims)
|
||||
assert.Equal(t, claims.SubjectID, retrievedClaims.SubjectID)
|
||||
assert.Equal(t, claims.Role, retrievedClaims.Role)
|
||||
assert.Equal(t, claims.Scope, retrievedClaims.Scope)
|
||||
}
|
||||
|
||||
// TestScopeAuth_GetIAMTokenClaims_Missing 测试获取不存在的IAMTokenClaims
|
||||
func TestScopeAuth_GetIAMTokenClaims_Missing(t *testing.T) {
|
||||
// arrange
|
||||
ctx := context.Background()
|
||||
|
||||
// act
|
||||
retrievedClaims := GetIAMTokenClaims(ctx)
|
||||
|
||||
// assert
|
||||
assert.Nil(t, retrievedClaims)
|
||||
}
|
||||
|
||||
// TestScopeAuth_HasRole 测试用户角色检查
|
||||
func TestScopeAuth_HasRole(t *testing.T) {
|
||||
// arrange
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:8",
|
||||
Role: "operator",
|
||||
Scope: []string{"platform:read"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *claims)
|
||||
|
||||
// act & assert
|
||||
assert.True(t, HasRole(ctx, "operator"))
|
||||
assert.False(t, HasRole(ctx, "viewer"))
|
||||
assert.False(t, HasRole(ctx, "admin"))
|
||||
}
|
||||
|
||||
// TestScopeAuth_HasRole_MissingClaims 测试缺少Claims时的角色检查
|
||||
func TestScopeAuth_HasRole_MissingClaims(t *testing.T) {
|
||||
// arrange
|
||||
ctx := context.Background()
|
||||
|
||||
// act & assert
|
||||
assert.False(t, HasRole(ctx, "operator"))
|
||||
}
|
||||
|
||||
// TestScopeRoleAuthzMiddleware_WithScope 测试带Scope要求的中间件
|
||||
func TestScopeRoleAuthzMiddleware_WithScope(t *testing.T) {
|
||||
// arrange
|
||||
scopeAuth := NewScopeAuthMiddleware()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"status":"ok"}`))
|
||||
})
|
||||
|
||||
// 创建一个带scope验证的handler
|
||||
wrappedHandler := scopeAuth.RequireScope("platform:write")(handler)
|
||||
|
||||
// 创建一个带有token claims的请求
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:9",
|
||||
Role: "operator",
|
||||
Scope: []string{"platform:read", "platform:write"},
|
||||
TenantID: 1,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), IAMTokenClaimsKey, *claims))
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestScopeRoleAuthzMiddleware_Denied 测试Scope不足时中间件拒绝
|
||||
func TestScopeRoleAuthzMiddleware_Denied(t *testing.T) {
|
||||
// arrange
|
||||
scopeAuth := NewScopeAuthMiddleware()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
wrappedHandler := scopeAuth.RequireScope("platform:admin")(handler)
|
||||
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:10",
|
||||
Role: "viewer",
|
||||
Scope: []string{"platform:read"}, // viewer没有platform:admin
|
||||
TenantID: 1,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), IAMTokenClaimsKey, *claims))
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
}
|
||||
|
||||
// TestScopeRoleAuthzMiddleware_MissingClaims 测试缺少Claims时中间件拒绝
|
||||
func TestScopeRoleAuthzMiddleware_MissingClaims(t *testing.T) {
|
||||
// arrange
|
||||
scopeAuth := NewScopeAuthMiddleware()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
wrappedHandler := scopeAuth.RequireScope("platform:read")(handler)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
// 不设置token claims
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
|
||||
// TestScopeRoleAuthzMiddleware_RequireAllScopes 测试要求所有Scope的中间件
|
||||
func TestScopeRoleAuthzMiddleware_RequireAllScopes(t *testing.T) {
|
||||
// arrange
|
||||
scopeAuth := NewScopeAuthMiddleware()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
wrappedHandler := scopeAuth.RequireAllScopes([]string{"platform:read", "tenant:read"})(handler)
|
||||
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:11",
|
||||
Role: "operator",
|
||||
Scope: []string{"platform:read", "platform:write", "tenant:read"},
|
||||
TenantID: 1,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), IAMTokenClaimsKey, *claims))
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
// TestScopeRoleAuthzMiddleware_RequireAllScopes_Denied 测试要求所有Scope但不足时拒绝
|
||||
func TestScopeRoleAuthzMiddleware_RequireAllScopes_Denied(t *testing.T) {
|
||||
// arrange
|
||||
scopeAuth := NewScopeAuthMiddleware()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
wrappedHandler := scopeAuth.RequireAllScopes([]string{"platform:read", "platform:admin"})(handler)
|
||||
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:12",
|
||||
Role: "viewer",
|
||||
Scope: []string{"platform:read"}, // viewer没有platform:admin
|
||||
TenantID: 1,
|
||||
}
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = req.WithContext(context.WithValue(req.Context(), IAMTokenClaimsKey, *claims))
|
||||
|
||||
// act
|
||||
rec := httptest.NewRecorder()
|
||||
wrappedHandler.ServeHTTP(rec, req)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, http.StatusForbidden, rec.Code)
|
||||
}
|
||||
|
||||
// TestScopeAuth_HasRoleLevel 测试角色层级检查
|
||||
func TestScopeAuth_HasRoleLevel(t *testing.T) {
|
||||
// arrange
|
||||
testCases := []struct {
|
||||
role string
|
||||
minLevel int
|
||||
expected bool
|
||||
}{
|
||||
{"super_admin", 50, true},
|
||||
{"super_admin", 100, true},
|
||||
{"org_admin", 50, true},
|
||||
{"org_admin", 60, false},
|
||||
{"operator", 30, true},
|
||||
{"operator", 40, false},
|
||||
{"viewer", 10, true},
|
||||
{"viewer", 20, false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:test",
|
||||
Role: tc.role,
|
||||
Scope: []string{},
|
||||
TenantID: 1,
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), IAMTokenClaimsKey, *claims)
|
||||
|
||||
// act
|
||||
result := HasRoleLevel(ctx, tc.minLevel)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, tc.expected, result, "role=%s, minLevel=%d", tc.role, tc.minLevel)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetRoleLevel 测试获取角色层级
|
||||
func TestGetRoleLevel(t *testing.T) {
|
||||
testCases := []struct {
|
||||
role string
|
||||
expected int
|
||||
}{
|
||||
{"super_admin", 100},
|
||||
{"org_admin", 50},
|
||||
{"supply_admin", 40},
|
||||
{"operator", 30},
|
||||
{"developer", 20},
|
||||
{"viewer", 10},
|
||||
{"unknown_role", 0},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
// act
|
||||
level := GetRoleLevel(tc.role)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, tc.expected, level, "role=%s", tc.role)
|
||||
}
|
||||
}
|
||||
|
||||
// TestScopeAuth_WithIAMClaims 测试设置IAM Claims到Context
|
||||
func TestScopeAuth_WithIAMClaims(t *testing.T) {
|
||||
// arrange
|
||||
claims := &IAMTokenClaims{
|
||||
SubjectID: "user:13",
|
||||
Role: "org_admin",
|
||||
Scope: []string{"platform:read"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
// act
|
||||
ctx := WithIAMClaims(context.Background(), claims)
|
||||
retrievedClaims := GetIAMTokenClaims(ctx)
|
||||
|
||||
// assert
|
||||
assert.NotNil(t, retrievedClaims)
|
||||
assert.Equal(t, claims.SubjectID, retrievedClaims.SubjectID)
|
||||
assert.Equal(t, claims.Role, retrievedClaims.Role)
|
||||
}
|
||||
|
||||
// TestGetClaimsFromLegacy 测试从原有TokenClaims转换
|
||||
func TestGetClaimsFromLegacy(t *testing.T) {
|
||||
// arrange
|
||||
legacyClaims := &middleware.TokenClaims{
|
||||
SubjectID: "user:14",
|
||||
Role: "viewer",
|
||||
Scope: []string{"platform:read"},
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
// act
|
||||
iamClaims := GetClaimsFromLegacy(legacyClaims)
|
||||
|
||||
// assert
|
||||
assert.NotNil(t, iamClaims)
|
||||
assert.Equal(t, legacyClaims.SubjectID, iamClaims.SubjectID)
|
||||
assert.Equal(t, legacyClaims.Role, iamClaims.Role)
|
||||
assert.Equal(t, legacyClaims.Scope, iamClaims.Scope)
|
||||
assert.Equal(t, legacyClaims.TenantID, iamClaims.TenantID)
|
||||
}
|
||||
211
supply-api/internal/iam/model/role.go
Normal file
211
supply-api/internal/iam/model/role.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 角色类型常量
|
||||
const (
|
||||
RoleTypePlatform = "platform"
|
||||
RoleTypeSupply = "supply"
|
||||
RoleTypeConsumer = "consumer"
|
||||
)
|
||||
|
||||
// 角色层级常量(用于权限优先级判断)
|
||||
const (
|
||||
LevelSuperAdmin = 100
|
||||
LevelOrgAdmin = 50
|
||||
LevelSupplyAdmin = 40
|
||||
LevelOperator = 30
|
||||
LevelDeveloper = 20
|
||||
LevelFinops = 20
|
||||
LevelViewer = 10
|
||||
)
|
||||
|
||||
// 角色错误定义
|
||||
var (
|
||||
ErrInvalidRoleCode = errors.New("invalid role code: cannot be empty")
|
||||
ErrInvalidRoleType = errors.New("invalid role type: must be platform, supply, or consumer")
|
||||
ErrInvalidLevel = errors.New("invalid level: must be non-negative")
|
||||
)
|
||||
|
||||
// Role 角色模型
|
||||
// 对应数据库 iam_roles 表
|
||||
type Role struct {
|
||||
ID int64 // 主键ID
|
||||
Code string // 角色代码 (unique)
|
||||
Name string // 角色名称
|
||||
Type string // 角色类型: platform, supply, consumer
|
||||
ParentRoleID *int64 // 父角色ID(用于继承关系)
|
||||
Level int // 权限层级
|
||||
Description string // 描述
|
||||
IsActive bool // 是否激活
|
||||
|
||||
// 审计字段
|
||||
RequestID string // 请求追踪ID
|
||||
CreatedIP string // 创建者IP
|
||||
UpdatedIP string // 更新者IP
|
||||
Version int // 乐观锁版本号
|
||||
|
||||
// 时间戳
|
||||
CreatedAt *time.Time // 创建时间
|
||||
UpdatedAt *time.Time // 更新时间
|
||||
|
||||
// 关联的Scope列表(运行时填充,不存储在iam_roles表)
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
}
|
||||
|
||||
// NewRole 创建新角色(基础构造函数)
|
||||
func NewRole(code, name, roleType string, level int) *Role {
|
||||
now := time.Now()
|
||||
return &Role{
|
||||
Code: code,
|
||||
Name: name,
|
||||
Type: roleType,
|
||||
Level: level,
|
||||
IsActive: true,
|
||||
RequestID: generateRequestID(),
|
||||
Version: 1,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRoleWithParent 创建带父角色的角色
|
||||
func NewRoleWithParent(code, name, roleType string, level int, parentRoleID int64) *Role {
|
||||
role := NewRole(code, name, roleType, level)
|
||||
role.ParentRoleID = &parentRoleID
|
||||
return role
|
||||
}
|
||||
|
||||
// NewRoleWithRequestID 创建带指定RequestID的角色
|
||||
func NewRoleWithRequestID(code, name, roleType string, level int, requestID string) *Role {
|
||||
role := NewRole(code, name, roleType, level)
|
||||
role.RequestID = requestID
|
||||
return role
|
||||
}
|
||||
|
||||
// NewRoleWithAudit 创建带审计信息的角色
|
||||
func NewRoleWithAudit(code, name, roleType string, level int, requestID, createdIP, updatedIP string) *Role {
|
||||
role := NewRole(code, name, roleType, level)
|
||||
role.RequestID = requestID
|
||||
role.CreatedIP = createdIP
|
||||
role.UpdatedIP = updatedIP
|
||||
return role
|
||||
}
|
||||
|
||||
// NewRoleWithValidation 创建角色并进行验证
|
||||
func NewRoleWithValidation(code, name, roleType string, level int) (*Role, error) {
|
||||
// 验证角色代码
|
||||
if code == "" {
|
||||
return nil, ErrInvalidRoleCode
|
||||
}
|
||||
|
||||
// 验证角色类型
|
||||
if roleType != RoleTypePlatform && roleType != RoleTypeSupply && roleType != RoleTypeConsumer {
|
||||
return nil, ErrInvalidRoleType
|
||||
}
|
||||
|
||||
// 验证层级
|
||||
if level < 0 {
|
||||
return nil, ErrInvalidLevel
|
||||
}
|
||||
|
||||
role := NewRole(code, name, roleType, level)
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// Activate 激活角色
|
||||
func (r *Role) Activate() {
|
||||
r.IsActive = true
|
||||
r.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// Deactivate 停用角色
|
||||
func (r *Role) Deactivate() {
|
||||
r.IsActive = false
|
||||
r.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// IncrementVersion 递增版本号(用于乐观锁)
|
||||
func (r *Role) IncrementVersion() {
|
||||
r.Version++
|
||||
r.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// SetParentRole 设置父角色
|
||||
func (r *Role) SetParentRole(parentID int64) {
|
||||
r.ParentRoleID = &parentID
|
||||
}
|
||||
|
||||
// SetScopes 设置角色关联的Scope列表
|
||||
func (r *Role) SetScopes(scopes []string) {
|
||||
r.Scopes = scopes
|
||||
}
|
||||
|
||||
// AddScope 添加一个Scope
|
||||
func (r *Role) AddScope(scope string) {
|
||||
for _, s := range r.Scopes {
|
||||
if s == scope {
|
||||
return
|
||||
}
|
||||
}
|
||||
r.Scopes = append(r.Scopes, scope)
|
||||
}
|
||||
|
||||
// RemoveScope 移除一个Scope
|
||||
func (r *Role) RemoveScope(scope string) {
|
||||
newScopes := make([]string, 0, len(r.Scopes))
|
||||
for _, s := range r.Scopes {
|
||||
if s != scope {
|
||||
newScopes = append(newScopes, s)
|
||||
}
|
||||
}
|
||||
r.Scopes = newScopes
|
||||
}
|
||||
|
||||
// HasScope 检查角色是否拥有指定Scope
|
||||
func (r *Role) HasScope(scope string) bool {
|
||||
for _, s := range r.Scopes {
|
||||
if s == scope || s == "*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ToRoleScopeInfo 转换为RoleScopeInfo结构(用于API响应)
|
||||
func (r *Role) ToRoleScopeInfo() *RoleScopeInfo {
|
||||
return &RoleScopeInfo{
|
||||
RoleCode: r.Code,
|
||||
RoleName: r.Name,
|
||||
RoleType: r.Type,
|
||||
Level: r.Level,
|
||||
Scopes: r.Scopes,
|
||||
}
|
||||
}
|
||||
|
||||
// RoleScopeInfo 角色的Scope信息(用于API响应)
|
||||
type RoleScopeInfo struct {
|
||||
RoleCode string `json:"role_code"`
|
||||
RoleName string `json:"role_name"`
|
||||
RoleType string `json:"role_type"`
|
||||
Level int `json:"level"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求追踪ID
|
||||
func generateRequestID() string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// nowPtr 返回当前时间的指针
|
||||
func nowPtr() *time.Time {
|
||||
t := time.Now()
|
||||
return &t
|
||||
}
|
||||
152
supply-api/internal/iam/model/role_scope.go
Normal file
152
supply-api/internal/iam/model/role_scope.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// RoleScopeMapping 角色-Scope关联模型
|
||||
// 对应数据库 iam_role_scopes 表
|
||||
type RoleScopeMapping struct {
|
||||
ID int64 // 主键ID
|
||||
RoleID int64 // 角色ID (FK -> iam_roles.id)
|
||||
ScopeID int64 // ScopeID (FK -> iam_scopes.id)
|
||||
IsActive bool // 是否激活
|
||||
|
||||
// 审计字段
|
||||
RequestID string // 请求追踪ID
|
||||
CreatedIP string // 创建者IP
|
||||
Version int // 乐观锁版本号
|
||||
|
||||
// 时间戳
|
||||
CreatedAt *time.Time // 创建时间
|
||||
}
|
||||
|
||||
// NewRoleScopeMapping 创建新的角色-Scope映射
|
||||
func NewRoleScopeMapping(roleID, scopeID int64) *RoleScopeMapping {
|
||||
now := time.Now()
|
||||
return &RoleScopeMapping{
|
||||
RoleID: roleID,
|
||||
ScopeID: scopeID,
|
||||
IsActive: true,
|
||||
RequestID: generateRequestID(),
|
||||
Version: 1,
|
||||
CreatedAt: &now,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRoleScopeMappingWithAudit 创建带审计信息的角色-Scope映射
|
||||
func NewRoleScopeMappingWithAudit(roleID, scopeID int64, requestID, createdIP string) *RoleScopeMapping {
|
||||
now := time.Now()
|
||||
return &RoleScopeMapping{
|
||||
RoleID: roleID,
|
||||
ScopeID: scopeID,
|
||||
IsActive: true,
|
||||
RequestID: requestID,
|
||||
CreatedIP: createdIP,
|
||||
Version: 1,
|
||||
CreatedAt: &now,
|
||||
}
|
||||
}
|
||||
|
||||
// Revoke 撤销角色-Scope映射
|
||||
func (m *RoleScopeMapping) Revoke() {
|
||||
m.IsActive = false
|
||||
}
|
||||
|
||||
// Grant 授予角色-Scope映射
|
||||
func (m *RoleScopeMapping) Grant() {
|
||||
m.IsActive = true
|
||||
}
|
||||
|
||||
// IncrementVersion 递增版本号
|
||||
func (m *RoleScopeMapping) IncrementVersion() {
|
||||
m.Version++
|
||||
}
|
||||
|
||||
// GrantScopeList 批量授予Scope
|
||||
func GrantScopeList(roleID int64, scopeIDs []int64) []*RoleScopeMapping {
|
||||
mappings := make([]*RoleScopeMapping, 0, len(scopeIDs))
|
||||
for _, scopeID := range scopeIDs {
|
||||
mapping := NewRoleScopeMapping(roleID, scopeID)
|
||||
mappings = append(mappings, mapping)
|
||||
}
|
||||
return mappings
|
||||
}
|
||||
|
||||
// RevokeAll 撤销所有映射
|
||||
func RevokeAll(mappings []*RoleScopeMapping) {
|
||||
for _, mapping := range mappings {
|
||||
mapping.Revoke()
|
||||
}
|
||||
}
|
||||
|
||||
// GetActiveScopeIDs 从映射列表中获取活跃的Scope ID列表
|
||||
func GetActiveScopeIDs(mappings []*RoleScopeMapping) []int64 {
|
||||
activeIDs := make([]int64, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
if mapping.IsActive {
|
||||
activeIDs = append(activeIDs, mapping.ScopeID)
|
||||
}
|
||||
}
|
||||
return activeIDs
|
||||
}
|
||||
|
||||
// GetInactiveScopeIDs 从映射列表中获取非活跃的Scope ID列表
|
||||
func GetInactiveScopeIDs(mappings []*RoleScopeMapping) []int64 {
|
||||
inactiveIDs := make([]int64, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
if !mapping.IsActive {
|
||||
inactiveIDs = append(inactiveIDs, mapping.ScopeID)
|
||||
}
|
||||
}
|
||||
return inactiveIDs
|
||||
}
|
||||
|
||||
// FilterActiveMappings 过滤出活跃的映射
|
||||
func FilterActiveMappings(mappings []*RoleScopeMapping) []*RoleScopeMapping {
|
||||
active := make([]*RoleScopeMapping, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
if mapping.IsActive {
|
||||
active = append(active, mapping)
|
||||
}
|
||||
}
|
||||
return active
|
||||
}
|
||||
|
||||
// FilterMappingsByRole 过滤出指定角色的映射
|
||||
func FilterMappingsByRole(mappings []*RoleScopeMapping, roleID int64) []*RoleScopeMapping {
|
||||
filtered := make([]*RoleScopeMapping, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
if mapping.RoleID == roleID {
|
||||
filtered = append(filtered, mapping)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// FilterMappingsByScope 过滤出指定Scope的映射
|
||||
func FilterMappingsByScope(mappings []*RoleScopeMapping, scopeID int64) []*RoleScopeMapping {
|
||||
filtered := make([]*RoleScopeMapping, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
if mapping.ScopeID == scopeID {
|
||||
filtered = append(filtered, mapping)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// RoleScopeMappingInfo 角色-Scope映射信息(用于API响应)
|
||||
type RoleScopeMappingInfo struct {
|
||||
RoleID int64 `json:"role_id"`
|
||||
ScopeID int64 `json:"scope_id"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// ToInfo 转换为映射信息
|
||||
func (m *RoleScopeMapping) ToInfo() *RoleScopeMappingInfo {
|
||||
return &RoleScopeMappingInfo{
|
||||
RoleID: m.RoleID,
|
||||
ScopeID: m.ScopeID,
|
||||
IsActive: m.IsActive,
|
||||
}
|
||||
}
|
||||
157
supply-api/internal/iam/model/role_scope_test.go
Normal file
157
supply-api/internal/iam/model/role_scope_test.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestRoleScopeMapping_GrantScope 测试授予Scope
|
||||
func TestRoleScopeMapping_GrantScope(t *testing.T) {
|
||||
// arrange
|
||||
role := NewRole("operator", "运维人员", RoleTypePlatform, 30)
|
||||
role.ID = 1
|
||||
scope1 := NewScope("platform:read", "读取平台配置", ScopeTypePlatform)
|
||||
scope1.ID = 1
|
||||
scope2 := NewScope("platform:write", "修改平台配置", ScopeTypePlatform)
|
||||
scope2.ID = 2
|
||||
|
||||
// act
|
||||
roleScopeMapping := NewRoleScopeMapping(role.ID, scope1.ID)
|
||||
roleScopeMapping2 := NewRoleScopeMapping(role.ID, scope2.ID)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, role.ID, roleScopeMapping.RoleID)
|
||||
assert.Equal(t, scope1.ID, roleScopeMapping.ScopeID)
|
||||
assert.NotEmpty(t, roleScopeMapping.RequestID)
|
||||
assert.Equal(t, 1, roleScopeMapping.Version)
|
||||
|
||||
assert.Equal(t, role.ID, roleScopeMapping2.RoleID)
|
||||
assert.Equal(t, scope2.ID, roleScopeMapping2.ScopeID)
|
||||
}
|
||||
|
||||
// TestRoleScopeMapping_RevokeScope 测试撤销Scope
|
||||
func TestRoleScopeMapping_RevokeScope(t *testing.T) {
|
||||
// arrange
|
||||
role := NewRole("viewer", "查看者", RoleTypePlatform, 10)
|
||||
role.ID = 1
|
||||
scope := NewScope("platform:read", "读取平台配置", ScopeTypePlatform)
|
||||
scope.ID = 1
|
||||
|
||||
// act
|
||||
roleScopeMapping := NewRoleScopeMapping(role.ID, scope.ID)
|
||||
roleScopeMapping.Revoke()
|
||||
|
||||
// assert
|
||||
assert.False(t, roleScopeMapping.IsActive, "revoked mapping should be inactive")
|
||||
}
|
||||
|
||||
// TestRoleScopeMapping_WithAudit 测试带审计字段的映射
|
||||
func TestRoleScopeMapping_WithAudit(t *testing.T) {
|
||||
// arrange
|
||||
roleID := int64(1)
|
||||
scopeID := int64(2)
|
||||
requestID := "req-role-scope-123"
|
||||
createdIP := "192.168.1.100"
|
||||
|
||||
// act
|
||||
mapping := NewRoleScopeMappingWithAudit(roleID, scopeID, requestID, createdIP)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, roleID, mapping.RoleID)
|
||||
assert.Equal(t, scopeID, mapping.ScopeID)
|
||||
assert.Equal(t, requestID, mapping.RequestID)
|
||||
assert.Equal(t, createdIP, mapping.CreatedIP)
|
||||
assert.True(t, mapping.IsActive)
|
||||
}
|
||||
|
||||
// TestRoleScopeMapping_IncrementVersion 测试版本号递增
|
||||
func TestRoleScopeMapping_IncrementVersion(t *testing.T) {
|
||||
// arrange
|
||||
mapping := NewRoleScopeMapping(1, 1)
|
||||
originalVersion := mapping.Version
|
||||
|
||||
// act
|
||||
mapping.IncrementVersion()
|
||||
|
||||
// assert
|
||||
assert.Equal(t, originalVersion+1, mapping.Version)
|
||||
}
|
||||
|
||||
// TestRoleScopeMapping_IsActive 测试活跃状态
|
||||
func TestRoleScopeMapping_IsActive(t *testing.T) {
|
||||
// arrange
|
||||
mapping := NewRoleScopeMapping(1, 1)
|
||||
|
||||
// assert - 默认应该激活
|
||||
assert.True(t, mapping.IsActive)
|
||||
}
|
||||
|
||||
// TestRoleScopeMapping_UniqueConstraint 测试唯一性(同一个角色和Scope组合)
|
||||
func TestRoleScopeMapping_UniqueConstraint(t *testing.T) {
|
||||
// arrange
|
||||
roleID := int64(1)
|
||||
scopeID := int64(1)
|
||||
|
||||
// act
|
||||
mapping1 := NewRoleScopeMapping(roleID, scopeID)
|
||||
mapping2 := NewRoleScopeMapping(roleID, scopeID)
|
||||
|
||||
// assert - 两个映射应该有相同的 RoleID 和 ScopeID(代表唯一约束)
|
||||
assert.Equal(t, mapping1.RoleID, mapping2.RoleID)
|
||||
assert.Equal(t, mapping1.ScopeID, mapping2.ScopeID)
|
||||
}
|
||||
|
||||
// TestRoleScopeMapping_GrantScopeList 测试批量授予Scope
|
||||
func TestRoleScopeMapping_GrantScopeList(t *testing.T) {
|
||||
// arrange
|
||||
roleID := int64(1)
|
||||
scopeIDs := []int64{1, 2, 3, 4, 5}
|
||||
|
||||
// act
|
||||
mappings := GrantScopeList(roleID, scopeIDs)
|
||||
|
||||
// assert
|
||||
assert.Len(t, mappings, len(scopeIDs))
|
||||
for i, scopeID := range scopeIDs {
|
||||
assert.Equal(t, roleID, mappings[i].RoleID)
|
||||
assert.Equal(t, scopeID, mappings[i].ScopeID)
|
||||
assert.True(t, mappings[i].IsActive)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoleScopeMapping_RevokeAll 测试撤销所有Scope(针对某个角色)
|
||||
func TestRoleScopeMapping_RevokeAll(t *testing.T) {
|
||||
// arrange
|
||||
roleID := int64(1)
|
||||
scopeIDs := []int64{1, 2, 3}
|
||||
mappings := GrantScopeList(roleID, scopeIDs)
|
||||
|
||||
// act
|
||||
RevokeAll(mappings)
|
||||
|
||||
// assert
|
||||
for _, mapping := range mappings {
|
||||
assert.False(t, mapping.IsActive, "all mappings should be revoked")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoleScopeMapping_GetActiveScopes 测试获取活跃的Scope列表
|
||||
func TestRoleScopeMapping_GetActiveScopes(t *testing.T) {
|
||||
// arrange
|
||||
roleID := int64(1)
|
||||
scopeIDs := []int64{1, 2, 3}
|
||||
mappings := GrantScopeList(roleID, scopeIDs)
|
||||
|
||||
// 撤销中间的Scope
|
||||
mappings[1].Revoke()
|
||||
|
||||
// act
|
||||
activeScopes := GetActiveScopeIDs(mappings)
|
||||
|
||||
// assert
|
||||
assert.Len(t, activeScopes, 2)
|
||||
assert.Contains(t, activeScopes, int64(1))
|
||||
assert.Contains(t, activeScopes, int64(3))
|
||||
assert.NotContains(t, activeScopes, int64(2))
|
||||
}
|
||||
244
supply-api/internal/iam/model/role_test.go
Normal file
244
supply-api/internal/iam/model/role_test.go
Normal file
@@ -0,0 +1,244 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestRoleModel_NewRole_ValidInput 测试创建角色 - 有效输入
|
||||
func TestRoleModel_NewRole_ValidInput(t *testing.T) {
|
||||
// arrange
|
||||
roleCode := "org_admin"
|
||||
roleName := "组织管理员"
|
||||
roleType := "platform"
|
||||
level := 50
|
||||
|
||||
// act
|
||||
role := NewRole(roleCode, roleName, roleType, level)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, roleCode, role.Code)
|
||||
assert.Equal(t, roleName, role.Name)
|
||||
assert.Equal(t, roleType, role.Type)
|
||||
assert.Equal(t, level, role.Level)
|
||||
assert.True(t, role.IsActive)
|
||||
assert.NotEmpty(t, role.RequestID)
|
||||
assert.Equal(t, 1, role.Version)
|
||||
}
|
||||
|
||||
// TestRoleModel_NewRole_DefaultFields 测试创建角色 - 验证默认字段
|
||||
func TestRoleModel_NewRole_DefaultFields(t *testing.T) {
|
||||
// arrange
|
||||
roleCode := "viewer"
|
||||
roleName := "查看者"
|
||||
roleType := "platform"
|
||||
level := 10
|
||||
|
||||
// act
|
||||
role := NewRole(roleCode, roleName, roleType, level)
|
||||
|
||||
// assert - 验证默认字段
|
||||
assert.Equal(t, 1, role.Version, "version should default to 1")
|
||||
assert.NotEmpty(t, role.RequestID, "request_id should be auto-generated")
|
||||
assert.True(t, role.IsActive, "is_active should default to true")
|
||||
assert.Nil(t, role.ParentRoleID, "parent_role_id should be nil for root roles")
|
||||
}
|
||||
|
||||
// TestRoleModel_NewRole_WithParent 测试创建角色 - 带父角色
|
||||
func TestRoleModel_NewRole_WithParent(t *testing.T) {
|
||||
// arrange
|
||||
parentRole := NewRole("viewer", "查看者", "platform", 10)
|
||||
parentRole.ID = 1
|
||||
|
||||
// act
|
||||
childRole := NewRoleWithParent("developer", "开发者", "platform", 20, parentRole.ID)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, "developer", childRole.Code)
|
||||
assert.Equal(t, 20, childRole.Level)
|
||||
assert.NotNil(t, childRole.ParentRoleID)
|
||||
assert.Equal(t, parentRole.ID, *childRole.ParentRoleID)
|
||||
}
|
||||
|
||||
// TestRoleModel_NewRole_WithRequestID 测试创建角色 - 指定RequestID
|
||||
func TestRoleModel_NewRole_WithRequestID(t *testing.T) {
|
||||
// arrange
|
||||
requestID := "req-12345"
|
||||
|
||||
// act
|
||||
role := NewRoleWithRequestID("org_admin", "组织管理员", "platform", 50, requestID)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, requestID, role.RequestID)
|
||||
}
|
||||
|
||||
// TestRoleModel_NewRole_AuditFields 测试创建角色 - 审计字段
|
||||
func TestRoleModel_NewRole_AuditFields(t *testing.T) {
|
||||
// arrange
|
||||
createdIP := "192.168.1.1"
|
||||
updatedIP := "192.168.1.2"
|
||||
|
||||
// act
|
||||
role := NewRoleWithAudit("supply_admin", "供应方管理员", "supply", 40, "req-123", createdIP, updatedIP)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, createdIP, role.CreatedIP)
|
||||
assert.Equal(t, updatedIP, role.UpdatedIP)
|
||||
assert.Equal(t, 1, role.Version)
|
||||
}
|
||||
|
||||
// TestRoleModel_NewRole_Timestamps 测试创建角色 - 时间戳
|
||||
func TestRoleModel_NewRole_Timestamps(t *testing.T) {
|
||||
// arrange
|
||||
beforeCreate := time.Now()
|
||||
|
||||
// act
|
||||
role := NewRole("test_role", "测试角色", "platform", 10)
|
||||
_ = time.Now() // afterCreate not needed
|
||||
|
||||
// assert
|
||||
assert.NotNil(t, role.CreatedAt)
|
||||
assert.NotNil(t, role.UpdatedAt)
|
||||
assert.True(t, role.CreatedAt.After(beforeCreate) || role.CreatedAt.Equal(beforeCreate))
|
||||
assert.True(t, role.UpdatedAt.After(beforeCreate) || role.UpdatedAt.Equal(beforeCreate))
|
||||
}
|
||||
|
||||
// TestRoleModel_Activate 测试激活角色
|
||||
func TestRoleModel_Activate(t *testing.T) {
|
||||
// arrange
|
||||
role := NewRole("inactive_role", "非活跃角色", "platform", 10)
|
||||
role.IsActive = false
|
||||
|
||||
// act
|
||||
role.Activate()
|
||||
|
||||
// assert
|
||||
assert.True(t, role.IsActive)
|
||||
}
|
||||
|
||||
// TestRoleModel_Deactivate 测试停用角色
|
||||
func TestRoleModel_Deactivate(t *testing.T) {
|
||||
// arrange
|
||||
role := NewRole("active_role", "活跃角色", "platform", 10)
|
||||
|
||||
// act
|
||||
role.Deactivate()
|
||||
|
||||
// assert
|
||||
assert.False(t, role.IsActive)
|
||||
}
|
||||
|
||||
// TestRoleModel_IncrementVersion 测试版本号递增
|
||||
func TestRoleModel_IncrementVersion(t *testing.T) {
|
||||
// arrange
|
||||
role := NewRole("test_role", "测试角色", "platform", 10)
|
||||
originalVersion := role.Version
|
||||
|
||||
// act
|
||||
role.IncrementVersion()
|
||||
|
||||
// assert
|
||||
assert.Equal(t, originalVersion+1, role.Version)
|
||||
}
|
||||
|
||||
// TestRoleModel_RoleType_Platform 测试平台角色类型
|
||||
func TestRoleModel_RoleType_Platform(t *testing.T) {
|
||||
// arrange & act
|
||||
role := NewRole("super_admin", "超级管理员", RoleTypePlatform, 100)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, RoleTypePlatform, role.Type)
|
||||
}
|
||||
|
||||
// TestRoleModel_RoleType_Supply 测试供应方角色类型
|
||||
func TestRoleModel_RoleType_Supply(t *testing.T) {
|
||||
// arrange & act
|
||||
role := NewRole("supply_admin", "供应方管理员", RoleTypeSupply, 40)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, RoleTypeSupply, role.Type)
|
||||
}
|
||||
|
||||
// TestRoleModel_RoleType_Consumer 测试需求方角色类型
|
||||
func TestRoleModel_RoleType_Consumer(t *testing.T) {
|
||||
// arrange & act
|
||||
role := NewRole("consumer_admin", "需求方管理员", RoleTypeConsumer, 40)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, RoleTypeConsumer, role.Type)
|
||||
}
|
||||
|
||||
// TestRoleModel_LevelHierarchy 测试角色层级关系
|
||||
func TestRoleModel_LevelHierarchy(t *testing.T) {
|
||||
// 测试设计文档中的层级关系
|
||||
// super_admin(100) > org_admin(50) > supply_admin(40) > operator(30) > developer/finops(20) > viewer(10)
|
||||
|
||||
// arrange
|
||||
superAdmin := NewRole("super_admin", "超级管理员", RoleTypePlatform, 100)
|
||||
orgAdmin := NewRole("org_admin", "组织管理员", RoleTypePlatform, 50)
|
||||
supplyAdmin := NewRole("supply_admin", "供应方管理员", RoleTypeSupply, 40)
|
||||
operator := NewRole("operator", "运维人员", RoleTypePlatform, 30)
|
||||
developer := NewRole("developer", "开发者", RoleTypePlatform, 20)
|
||||
viewer := NewRole("viewer", "查看者", RoleTypePlatform, 10)
|
||||
|
||||
// assert - 验证层级数值
|
||||
assert.Greater(t, superAdmin.Level, orgAdmin.Level)
|
||||
assert.Greater(t, orgAdmin.Level, supplyAdmin.Level)
|
||||
assert.Greater(t, supplyAdmin.Level, operator.Level)
|
||||
assert.Greater(t, operator.Level, developer.Level)
|
||||
assert.Greater(t, developer.Level, viewer.Level)
|
||||
}
|
||||
|
||||
// TestRoleModel_NewRole_EmptyCode 测试创建角色 - 空角色代码(应返回错误)
|
||||
func TestRoleModel_NewRole_EmptyCode(t *testing.T) {
|
||||
// arrange & act
|
||||
role, err := NewRoleWithValidation("", "测试角色", "platform", 10)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, role)
|
||||
assert.Equal(t, ErrInvalidRoleCode, err)
|
||||
}
|
||||
|
||||
// TestRoleModel_NewRole_InvalidRoleType 测试创建角色 - 无效角色类型
|
||||
func TestRoleModel_NewRole_InvalidRoleType(t *testing.T) {
|
||||
// arrange & act
|
||||
role, err := NewRoleWithValidation("test_role", "测试角色", "invalid_type", 10)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, role)
|
||||
assert.Equal(t, ErrInvalidRoleType, err)
|
||||
}
|
||||
|
||||
// TestRoleModel_NewRole_NegativeLevel 测试创建角色 - 负数层级
|
||||
func TestRoleModel_NewRole_NegativeLevel(t *testing.T) {
|
||||
// arrange & act
|
||||
role, err := NewRoleWithValidation("test_role", "测试角色", "platform", -1)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, role)
|
||||
assert.Equal(t, ErrInvalidLevel, err)
|
||||
}
|
||||
|
||||
// TestRoleModel_ToRoleScopeInfo 测试角色转换为RoleScopeInfo
|
||||
func TestRoleModel_ToRoleScopeInfo(t *testing.T) {
|
||||
// arrange
|
||||
role := NewRole("org_admin", "组织管理员", RoleTypePlatform, 50)
|
||||
role.ID = 1
|
||||
role.Scopes = []string{"platform:read", "platform:write"}
|
||||
|
||||
// act
|
||||
roleScopeInfo := role.ToRoleScopeInfo()
|
||||
|
||||
// assert
|
||||
assert.Equal(t, "org_admin", roleScopeInfo.RoleCode)
|
||||
assert.Equal(t, "组织管理员", roleScopeInfo.RoleName)
|
||||
assert.Equal(t, 50, roleScopeInfo.Level)
|
||||
assert.Len(t, roleScopeInfo.Scopes, 2)
|
||||
assert.Contains(t, roleScopeInfo.Scopes, "platform:read")
|
||||
assert.Contains(t, roleScopeInfo.Scopes, "platform:write")
|
||||
}
|
||||
225
supply-api/internal/iam/model/scope.go
Normal file
225
supply-api/internal/iam/model/scope.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Scope类型常量
|
||||
const (
|
||||
ScopeTypePlatform = "platform"
|
||||
ScopeTypeSupply = "supply"
|
||||
ScopeTypeConsumer = "consumer"
|
||||
ScopeTypeRouter = "router"
|
||||
ScopeTypeBilling = "billing"
|
||||
)
|
||||
|
||||
// Scope错误定义
|
||||
var (
|
||||
ErrInvalidScopeCode = errors.New("invalid scope code: cannot be empty")
|
||||
ErrInvalidScopeType = errors.New("invalid scope type: must be platform, supply, consumer, router, or billing")
|
||||
)
|
||||
|
||||
// Scope Scope模型
|
||||
// 对应数据库 iam_scopes 表
|
||||
type Scope struct {
|
||||
ID int64 // 主键ID
|
||||
Code string // Scope代码 (unique): platform:read, supply:account:write
|
||||
Name string // Scope名称
|
||||
Type string // Scope类型: platform, supply, consumer, router, billing
|
||||
Description string // 描述
|
||||
IsActive bool // 是否激活
|
||||
|
||||
// 审计字段
|
||||
RequestID string // 请求追踪ID
|
||||
CreatedIP string // 创建者IP
|
||||
UpdatedIP string // 更新者IP
|
||||
Version int // 乐观锁版本号
|
||||
|
||||
// 时间戳
|
||||
CreatedAt *time.Time // 创建时间
|
||||
UpdatedAt *time.Time // 更新时间
|
||||
}
|
||||
|
||||
// NewScope 创建新Scope(基础构造函数)
|
||||
func NewScope(code, name, scopeType string) *Scope {
|
||||
now := time.Now()
|
||||
return &Scope{
|
||||
Code: code,
|
||||
Name: name,
|
||||
Type: scopeType,
|
||||
IsActive: true,
|
||||
RequestID: generateRequestID(),
|
||||
Version: 1,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
}
|
||||
|
||||
// NewScopeWithRequestID 创建带指定RequestID的Scope
|
||||
func NewScopeWithRequestID(code, name, scopeType string, requestID string) *Scope {
|
||||
scope := NewScope(code, name, scopeType)
|
||||
scope.RequestID = requestID
|
||||
return scope
|
||||
}
|
||||
|
||||
// NewScopeWithAudit 创建带审计信息的Scope
|
||||
func NewScopeWithAudit(code, name, scopeType string, requestID, createdIP, updatedIP string) *Scope {
|
||||
scope := NewScope(code, name, scopeType)
|
||||
scope.RequestID = requestID
|
||||
scope.CreatedIP = createdIP
|
||||
scope.UpdatedIP = updatedIP
|
||||
return scope
|
||||
}
|
||||
|
||||
// NewScopeWithValidation 创建Scope并进行验证
|
||||
func NewScopeWithValidation(code, name, scopeType string) (*Scope, error) {
|
||||
// 验证Scope代码
|
||||
if code == "" {
|
||||
return nil, ErrInvalidScopeCode
|
||||
}
|
||||
|
||||
// 验证Scope类型
|
||||
if !IsValidScopeType(scopeType) {
|
||||
return nil, ErrInvalidScopeType
|
||||
}
|
||||
|
||||
scope := NewScope(code, name, scopeType)
|
||||
return scope, nil
|
||||
}
|
||||
|
||||
// Activate 激活Scope
|
||||
func (s *Scope) Activate() {
|
||||
s.IsActive = true
|
||||
s.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// Deactivate 停用Scope
|
||||
func (s *Scope) Deactivate() {
|
||||
s.IsActive = false
|
||||
s.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// IncrementVersion 递增版本号(用于乐观锁)
|
||||
func (s *Scope) IncrementVersion() {
|
||||
s.Version++
|
||||
s.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// IsWildcard 检查是否为通配符Scope
|
||||
func (s *Scope) IsWildcard() bool {
|
||||
return s.Code == "*"
|
||||
}
|
||||
|
||||
// ToScopeInfo 转换为ScopeInfo结构(用于API响应)
|
||||
func (s *Scope) ToScopeInfo() *ScopeInfo {
|
||||
return &ScopeInfo{
|
||||
ScopeCode: s.Code,
|
||||
ScopeName: s.Name,
|
||||
ScopeType: s.Type,
|
||||
IsActive: s.IsActive,
|
||||
}
|
||||
}
|
||||
|
||||
// ScopeInfo Scope信息(用于API响应)
|
||||
type ScopeInfo struct {
|
||||
ScopeCode string `json:"scope_code"`
|
||||
ScopeName string `json:"scope_name"`
|
||||
ScopeType string `json:"scope_type"`
|
||||
IsActive bool `json:"is_active"`
|
||||
}
|
||||
|
||||
// IsValidScopeType 验证Scope类型是否有效
|
||||
func IsValidScopeType(scopeType string) bool {
|
||||
switch scopeType {
|
||||
case ScopeTypePlatform, ScopeTypeSupply, ScopeTypeConsumer, ScopeTypeRouter, ScopeTypeBilling:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// GetScopeTypeFromCode 从Scope Code推断Scope类型
|
||||
// 例如: platform:read -> platform, supply:account:write -> supply, consumer:apikey:create -> consumer
|
||||
func GetScopeTypeFromCode(scopeCode string) string {
|
||||
parts := strings.SplitN(scopeCode, ":", 2)
|
||||
if len(parts) < 1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
prefix := parts[0]
|
||||
switch prefix {
|
||||
case "platform", "tenant", "billing":
|
||||
return ScopeTypePlatform
|
||||
case "supply":
|
||||
return ScopeTypeSupply
|
||||
case "consumer":
|
||||
return ScopeTypeConsumer
|
||||
case "router":
|
||||
return ScopeTypeRouter
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// PredefinedScopes 预定义的Scope列表
|
||||
var PredefinedScopes = []*Scope{
|
||||
// Platform Scopes
|
||||
{Code: "platform:read", Name: "读取平台配置", Type: ScopeTypePlatform},
|
||||
{Code: "platform:write", Name: "修改平台配置", Type: ScopeTypePlatform},
|
||||
{Code: "platform:admin", Name: "平台级管理", Type: ScopeTypePlatform},
|
||||
{Code: "platform:audit:read", Name: "读取审计日志", Type: ScopeTypePlatform},
|
||||
{Code: "platform:audit:export", Name: "导出审计日志", Type: ScopeTypePlatform},
|
||||
|
||||
// Tenant Scopes (属于platform类型)
|
||||
{Code: "tenant:read", Name: "读取租户信息", Type: ScopeTypePlatform},
|
||||
{Code: "tenant:write", Name: "修改租户配置", Type: ScopeTypePlatform},
|
||||
{Code: "tenant:member:manage", Name: "管理租户成员", Type: ScopeTypePlatform},
|
||||
{Code: "tenant:billing:write", Name: "修改账单设置", Type: ScopeTypePlatform},
|
||||
|
||||
// Supply Scopes
|
||||
{Code: "supply:account:read", Name: "读取供应账号", Type: ScopeTypeSupply},
|
||||
{Code: "supply:account:write", Name: "管理供应账号", Type: ScopeTypeSupply},
|
||||
{Code: "supply:package:read", Name: "读取套餐信息", Type: ScopeTypeSupply},
|
||||
{Code: "supply:package:write", Name: "管理套餐", Type: ScopeTypeSupply},
|
||||
{Code: "supply:package:publish", Name: "发布套餐", Type: ScopeTypeSupply},
|
||||
{Code: "supply:package:offline", Name: "下架套餐", Type: ScopeTypeSupply},
|
||||
{Code: "supply:settlement:withdraw", Name: "提现", Type: ScopeTypeSupply},
|
||||
{Code: "supply:credential:manage", Name: "管理凭证", Type: ScopeTypeSupply},
|
||||
|
||||
// Consumer Scopes
|
||||
{Code: "consumer:account:read", Name: "读取账户信息", Type: ScopeTypeConsumer},
|
||||
{Code: "consumer:account:write", Name: "管理账户", Type: ScopeTypeConsumer},
|
||||
{Code: "consumer:apikey:create", Name: "创建API Key", Type: ScopeTypeConsumer},
|
||||
{Code: "consumer:apikey:read", Name: "读取API Key", Type: ScopeTypeConsumer},
|
||||
{Code: "consumer:apikey:revoke", Name: "吊销API Key", Type: ScopeTypeConsumer},
|
||||
{Code: "consumer:usage:read", Name: "读取使用量", Type: ScopeTypeConsumer},
|
||||
|
||||
// Billing Scopes
|
||||
{Code: "billing:read", Name: "读取账单", Type: ScopeTypeBilling},
|
||||
{Code: "billing:write", Name: "修改账单设置", Type: ScopeTypeBilling},
|
||||
|
||||
// Router Scopes
|
||||
{Code: "router:invoke", Name: "调用模型", Type: ScopeTypeRouter},
|
||||
{Code: "router:model:list", Name: "列出可用模型", Type: ScopeTypeRouter},
|
||||
{Code: "router:model:config", Name: "配置路由策略", Type: ScopeTypeRouter},
|
||||
|
||||
// Wildcard Scope
|
||||
{Code: "*", Name: "通配符", Type: ScopeTypePlatform},
|
||||
}
|
||||
|
||||
// GetPredefinedScopeByCode 根据Code获取预定义Scope
|
||||
func GetPredefinedScopeByCode(code string) *Scope {
|
||||
for _, scope := range PredefinedScopes {
|
||||
if scope.Code == code {
|
||||
return scope
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsPredefinedScope 检查是否为预定义Scope
|
||||
func IsPredefinedScope(code string) bool {
|
||||
return GetPredefinedScopeByCode(code) != nil
|
||||
}
|
||||
247
supply-api/internal/iam/model/scope_test.go
Normal file
247
supply-api/internal/iam/model/scope_test.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestScopeModel_NewScope_ValidInput 测试创建Scope - 有效输入
|
||||
func TestScopeModel_NewScope_ValidInput(t *testing.T) {
|
||||
// arrange
|
||||
scopeCode := "platform:read"
|
||||
scopeName := "读取平台配置"
|
||||
scopeType := "platform"
|
||||
|
||||
// act
|
||||
scope := NewScope(scopeCode, scopeName, scopeType)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, scopeCode, scope.Code)
|
||||
assert.Equal(t, scopeName, scope.Name)
|
||||
assert.Equal(t, scopeType, scope.Type)
|
||||
assert.True(t, scope.IsActive)
|
||||
assert.NotEmpty(t, scope.RequestID)
|
||||
assert.Equal(t, 1, scope.Version)
|
||||
}
|
||||
|
||||
// TestScopeModel_ScopeCategories 测试Scope分类
|
||||
func TestScopeModel_ScopeCategories(t *testing.T) {
|
||||
// arrange & act
|
||||
testCases := []struct {
|
||||
scopeCode string
|
||||
expectedType string
|
||||
}{
|
||||
// platform:* 分类
|
||||
{"platform:read", ScopeTypePlatform},
|
||||
{"platform:write", ScopeTypePlatform},
|
||||
{"platform:admin", ScopeTypePlatform},
|
||||
{"platform:audit:read", ScopeTypePlatform},
|
||||
{"platform:audit:export", ScopeTypePlatform},
|
||||
|
||||
// tenant:* 分类
|
||||
{"tenant:read", ScopeTypePlatform},
|
||||
{"tenant:write", ScopeTypePlatform},
|
||||
{"tenant:member:manage", ScopeTypePlatform},
|
||||
|
||||
// supply:* 分类
|
||||
{"supply:account:read", ScopeTypeSupply},
|
||||
{"supply:account:write", ScopeTypeSupply},
|
||||
{"supply:package:read", ScopeTypeSupply},
|
||||
{"supply:package:write", ScopeTypeSupply},
|
||||
|
||||
// consumer:* 分类
|
||||
{"consumer:account:read", ScopeTypeConsumer},
|
||||
{"consumer:apikey:create", ScopeTypeConsumer},
|
||||
|
||||
// billing:* 分类
|
||||
{"billing:read", ScopeTypePlatform},
|
||||
|
||||
// router:* 分类
|
||||
{"router:invoke", ScopeTypeRouter},
|
||||
{"router:model:list", ScopeTypeRouter},
|
||||
}
|
||||
|
||||
// assert
|
||||
for _, tc := range testCases {
|
||||
scope := NewScope(tc.scopeCode, tc.scopeCode, tc.expectedType)
|
||||
assert.Equal(t, tc.expectedType, scope.Type, "scope %s should be type %s", tc.scopeCode, tc.expectedType)
|
||||
}
|
||||
}
|
||||
|
||||
// TestScopeModel_NewScope_DefaultFields 测试创建Scope - 默认字段
|
||||
func TestScopeModel_NewScope_DefaultFields(t *testing.T) {
|
||||
// arrange
|
||||
scopeCode := "tenant:read"
|
||||
scopeName := "读取租户信息"
|
||||
scopeType := ScopeTypePlatform
|
||||
|
||||
// act
|
||||
scope := NewScope(scopeCode, scopeName, scopeType)
|
||||
|
||||
// assert - 验证默认字段
|
||||
assert.Equal(t, 1, scope.Version, "version should default to 1")
|
||||
assert.NotEmpty(t, scope.RequestID, "request_id should be auto-generated")
|
||||
assert.True(t, scope.IsActive, "is_active should default to true")
|
||||
}
|
||||
|
||||
// TestScopeModel_NewScope_WithRequestID 测试创建Scope - 指定RequestID
|
||||
func TestScopeModel_NewScope_WithRequestID(t *testing.T) {
|
||||
// arrange
|
||||
requestID := "req-54321"
|
||||
|
||||
// act
|
||||
scope := NewScopeWithRequestID("platform:read", "读取平台配置", ScopeTypePlatform, requestID)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, requestID, scope.RequestID)
|
||||
}
|
||||
|
||||
// TestScopeModel_NewScope_AuditFields 测试创建Scope - 审计字段
|
||||
func TestScopeModel_NewScope_AuditFields(t *testing.T) {
|
||||
// arrange
|
||||
createdIP := "10.0.0.1"
|
||||
updatedIP := "10.0.0.2"
|
||||
|
||||
// act
|
||||
scope := NewScopeWithAudit("billing:read", "读取账单", ScopeTypePlatform, "req-789", createdIP, updatedIP)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, createdIP, scope.CreatedIP)
|
||||
assert.Equal(t, updatedIP, scope.UpdatedIP)
|
||||
assert.Equal(t, 1, scope.Version)
|
||||
}
|
||||
|
||||
// TestScopeModel_Activate 测试激活Scope
|
||||
func TestScopeModel_Activate(t *testing.T) {
|
||||
// arrange
|
||||
scope := NewScope("test:scope", "测试Scope", ScopeTypePlatform)
|
||||
scope.IsActive = false
|
||||
|
||||
// act
|
||||
scope.Activate()
|
||||
|
||||
// assert
|
||||
assert.True(t, scope.IsActive)
|
||||
}
|
||||
|
||||
// TestScopeModel_Deactivate 测试停用Scope
|
||||
func TestScopeModel_Deactivate(t *testing.T) {
|
||||
// arrange
|
||||
scope := NewScope("test:scope", "测试Scope", ScopeTypePlatform)
|
||||
|
||||
// act
|
||||
scope.Deactivate()
|
||||
|
||||
// assert
|
||||
assert.False(t, scope.IsActive)
|
||||
}
|
||||
|
||||
// TestScopeModel_IncrementVersion 测试版本号递增
|
||||
func TestScopeModel_IncrementVersion(t *testing.T) {
|
||||
// arrange
|
||||
scope := NewScope("test:scope", "测试Scope", ScopeTypePlatform)
|
||||
originalVersion := scope.Version
|
||||
|
||||
// act
|
||||
scope.IncrementVersion()
|
||||
|
||||
// assert
|
||||
assert.Equal(t, originalVersion+1, scope.Version)
|
||||
}
|
||||
|
||||
// TestScopeModel_ScopeType_Platform 测试平台Scope类型
|
||||
func TestScopeModel_ScopeType_Platform(t *testing.T) {
|
||||
// arrange & act
|
||||
scope := NewScope("platform:admin", "平台管理", ScopeTypePlatform)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, ScopeTypePlatform, scope.Type)
|
||||
}
|
||||
|
||||
// TestScopeModel_ScopeType_Supply 测试供应方Scope类型
|
||||
func TestScopeModel_ScopeType_Supply(t *testing.T) {
|
||||
// arrange & act
|
||||
scope := NewScope("supply:account:write", "管理供应账号", ScopeTypeSupply)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, ScopeTypeSupply, scope.Type)
|
||||
}
|
||||
|
||||
// TestScopeModel_ScopeType_Consumer 测试需求方Scope类型
|
||||
func TestScopeModel_ScopeType_Consumer(t *testing.T) {
|
||||
// arrange & act
|
||||
scope := NewScope("consumer:apikey:create", "创建API Key", ScopeTypeConsumer)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, ScopeTypeConsumer, scope.Type)
|
||||
}
|
||||
|
||||
// TestScopeModel_ScopeType_Router 测试路由Scope类型
|
||||
func TestScopeModel_ScopeType_Router(t *testing.T) {
|
||||
// arrange & act
|
||||
scope := NewScope("router:invoke", "调用模型", ScopeTypeRouter)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, ScopeTypeRouter, scope.Type)
|
||||
}
|
||||
|
||||
// TestScopeModel_NewScope_EmptyCode 测试创建Scope - 空Scope代码(应返回错误)
|
||||
func TestScopeModel_NewScope_EmptyCode(t *testing.T) {
|
||||
// arrange & act
|
||||
scope, err := NewScopeWithValidation("", "测试Scope", ScopeTypePlatform)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, scope)
|
||||
assert.Equal(t, ErrInvalidScopeCode, err)
|
||||
}
|
||||
|
||||
// TestScopeModel_NewScope_InvalidScopeType 测试创建Scope - 无效Scope类型
|
||||
func TestScopeModel_NewScope_InvalidScopeType(t *testing.T) {
|
||||
// arrange & act
|
||||
scope, err := NewScopeWithValidation("test:scope", "测试Scope", "invalid_type")
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, scope)
|
||||
assert.Equal(t, ErrInvalidScopeType, err)
|
||||
}
|
||||
|
||||
// TestScopeModel_ToScopeInfo 测试Scope转换为ScopeInfo
|
||||
func TestScopeModel_ToScopeInfo(t *testing.T) {
|
||||
// arrange
|
||||
scope := NewScope("platform:read", "读取平台配置", ScopeTypePlatform)
|
||||
scope.ID = 1
|
||||
|
||||
// act
|
||||
scopeInfo := scope.ToScopeInfo()
|
||||
|
||||
// assert
|
||||
assert.Equal(t, "platform:read", scopeInfo.ScopeCode)
|
||||
assert.Equal(t, "读取平台配置", scopeInfo.ScopeName)
|
||||
assert.Equal(t, ScopeTypePlatform, scopeInfo.ScopeType)
|
||||
assert.True(t, scopeInfo.IsActive)
|
||||
}
|
||||
|
||||
// TestScopeModel_GetScopeTypeFromCode 测试从Scope Code推断类型
|
||||
func TestScopeModel_GetScopeTypeFromCode(t *testing.T) {
|
||||
// arrange & act & assert
|
||||
assert.Equal(t, ScopeTypePlatform, GetScopeTypeFromCode("platform:read"))
|
||||
assert.Equal(t, ScopeTypePlatform, GetScopeTypeFromCode("tenant:read"))
|
||||
assert.Equal(t, ScopeTypeSupply, GetScopeTypeFromCode("supply:account:read"))
|
||||
assert.Equal(t, ScopeTypeConsumer, GetScopeTypeFromCode("consumer:apikey:read"))
|
||||
assert.Equal(t, ScopeTypeRouter, GetScopeTypeFromCode("router:invoke"))
|
||||
assert.Equal(t, ScopeTypePlatform, GetScopeTypeFromCode("billing:read"))
|
||||
}
|
||||
|
||||
// TestScopeModel_IsWildcardScope 测试通配符Scope
|
||||
func TestScopeModel_IsWildcardScope(t *testing.T) {
|
||||
// arrange
|
||||
wildcardScope := NewScope("*", "通配符", ScopeTypePlatform)
|
||||
normalScope := NewScope("platform:read", "读取平台配置", ScopeTypePlatform)
|
||||
|
||||
// assert
|
||||
assert.True(t, wildcardScope.IsWildcard())
|
||||
assert.False(t, normalScope.IsWildcard())
|
||||
}
|
||||
172
supply-api/internal/iam/model/user_role.go
Normal file
172
supply-api/internal/iam/model/user_role.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// UserRoleMapping 用户-角色关联模型
|
||||
// 对应数据库 iam_user_roles 表
|
||||
type UserRoleMapping struct {
|
||||
ID int64 // 主键ID
|
||||
UserID int64 // 用户ID
|
||||
RoleID int64 // 角色ID (FK -> iam_roles.id)
|
||||
TenantID int64 // 租户范围(NULL表示全局,0也代表全局)
|
||||
GrantedBy int64 // 授权人ID
|
||||
ExpiresAt *time.Time // 角色过期时间(nil表示永不过期)
|
||||
IsActive bool // 是否激活
|
||||
|
||||
// 审计字段
|
||||
RequestID string // 请求追踪ID
|
||||
CreatedIP string // 创建者IP
|
||||
UpdatedIP string // 更新者IP
|
||||
Version int // 乐观锁版本号
|
||||
|
||||
// 时间戳
|
||||
CreatedAt *time.Time // 创建时间
|
||||
UpdatedAt *time.Time // 更新时间
|
||||
GrantedAt *time.Time // 授权时间
|
||||
}
|
||||
|
||||
// NewUserRoleMapping 创建新的用户-角色映射
|
||||
func NewUserRoleMapping(userID, roleID, tenantID int64) *UserRoleMapping {
|
||||
now := time.Now()
|
||||
return &UserRoleMapping{
|
||||
UserID: userID,
|
||||
RoleID: roleID,
|
||||
TenantID: tenantID,
|
||||
IsActive: true,
|
||||
RequestID: generateRequestID(),
|
||||
Version: 1,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUserRoleMappingWithGrant 创建带授权信息的用户-角色映射
|
||||
func NewUserRoleMappingWithGrant(userID, roleID, tenantID, grantedBy int64, expiresAt *time.Time) *UserRoleMapping {
|
||||
now := time.Now()
|
||||
return &UserRoleMapping{
|
||||
UserID: userID,
|
||||
RoleID: roleID,
|
||||
TenantID: tenantID,
|
||||
GrantedBy: grantedBy,
|
||||
ExpiresAt: expiresAt,
|
||||
GrantedAt: &now,
|
||||
IsActive: true,
|
||||
RequestID: generateRequestID(),
|
||||
Version: 1,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
}
|
||||
|
||||
// HasRole 检查用户是否拥有指定角色
|
||||
func (m *UserRoleMapping) HasRole(roleID int64) bool {
|
||||
return m.RoleID == roleID && m.IsActive
|
||||
}
|
||||
|
||||
// IsGlobalRole 检查是否为全局角色(租户ID为0或nil)
|
||||
func (m *UserRoleMapping) IsGlobalRole() bool {
|
||||
return m.TenantID == 0
|
||||
}
|
||||
|
||||
// IsExpired 检查角色是否已过期
|
||||
func (m *UserRoleMapping) IsExpired() bool {
|
||||
if m.ExpiresAt == nil {
|
||||
return false // 永不过期
|
||||
}
|
||||
return time.Now().After(*m.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsValid 检查角色分配是否有效(激活且未过期)
|
||||
func (m *UserRoleMapping) IsValid() bool {
|
||||
return m.IsActive && !m.IsExpired()
|
||||
}
|
||||
|
||||
// Revoke 撤销角色分配
|
||||
func (m *UserRoleMapping) Revoke() {
|
||||
m.IsActive = false
|
||||
m.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// Grant 重新授予角色
|
||||
func (m *UserRoleMapping) Grant() {
|
||||
m.IsActive = true
|
||||
m.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// IncrementVersion 递增版本号
|
||||
func (m *UserRoleMapping) IncrementVersion() {
|
||||
m.Version++
|
||||
m.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// ExtendExpiration 延长过期时间
|
||||
func (m *UserRoleMapping) ExtendExpiration(newExpiresAt *time.Time) {
|
||||
m.ExpiresAt = newExpiresAt
|
||||
m.UpdatedAt = nowPtr()
|
||||
}
|
||||
|
||||
// UserRoleMappingInfo 用户-角色映射信息(用于API响应)
|
||||
type UserRoleMappingInfo struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
RoleID int64 `json:"role_id"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
IsActive bool `json:"is_active"`
|
||||
ExpiresAt *string `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
// ToInfo 转换为映射信息
|
||||
func (m *UserRoleMapping) ToInfo() *UserRoleMappingInfo {
|
||||
info := &UserRoleMappingInfo{
|
||||
UserID: m.UserID,
|
||||
RoleID: m.RoleID,
|
||||
TenantID: m.TenantID,
|
||||
IsActive: m.IsActive,
|
||||
}
|
||||
if m.ExpiresAt != nil {
|
||||
expStr := m.ExpiresAt.Format(time.RFC3339)
|
||||
info.ExpiresAt = &expStr
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// UserRoleAssignmentInfo 用户角色分配详情(用于API响应)
|
||||
type UserRoleAssignmentInfo struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
RoleCode string `json:"role_code"`
|
||||
RoleName string `json:"role_name"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
GrantedBy int64 `json:"granted_by"`
|
||||
GrantedAt string `json:"granted_at"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"`
|
||||
IsActive bool `json:"is_active"`
|
||||
IsExpired bool `json:"is_expired"`
|
||||
}
|
||||
|
||||
// UserRoleWithDetails 用户角色分配(含角色详情)
|
||||
type UserRoleWithDetails struct {
|
||||
*UserRoleMapping
|
||||
RoleCode string
|
||||
RoleName string
|
||||
}
|
||||
|
||||
// ToAssignmentInfo 转换为分配详情
|
||||
func (m *UserRoleWithDetails) ToAssignmentInfo() *UserRoleAssignmentInfo {
|
||||
info := &UserRoleAssignmentInfo{
|
||||
UserID: m.UserID,
|
||||
RoleCode: m.RoleCode,
|
||||
RoleName: m.RoleName,
|
||||
TenantID: m.TenantID,
|
||||
GrantedBy: m.GrantedBy,
|
||||
IsActive: m.IsActive,
|
||||
IsExpired: m.IsExpired(),
|
||||
}
|
||||
if m.GrantedAt != nil {
|
||||
info.GrantedAt = m.GrantedAt.Format(time.RFC3339)
|
||||
}
|
||||
if m.ExpiresAt != nil {
|
||||
info.ExpiresAt = m.ExpiresAt.Format(time.RFC3339)
|
||||
}
|
||||
return info
|
||||
}
|
||||
254
supply-api/internal/iam/model/user_role_test.go
Normal file
254
supply-api/internal/iam/model/user_role_test.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestUserRoleMapping_AssignRole 测试分配角色
|
||||
func TestUserRoleMapping_AssignRole(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
roleID := int64(1)
|
||||
tenantID := int64(1)
|
||||
|
||||
// act
|
||||
userRole := NewUserRoleMapping(userID, roleID, tenantID)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, userID, userRole.UserID)
|
||||
assert.Equal(t, roleID, userRole.RoleID)
|
||||
assert.Equal(t, tenantID, userRole.TenantID)
|
||||
assert.True(t, userRole.IsActive)
|
||||
assert.NotEmpty(t, userRole.RequestID)
|
||||
assert.Equal(t, 1, userRole.Version)
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_HasRole 测试用户是否拥有角色
|
||||
func TestUserRoleMapping_HasRole(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
role := NewRole("org_admin", "组织管理员", RoleTypePlatform, 50)
|
||||
role.ID = 1
|
||||
|
||||
// act
|
||||
userRole := NewUserRoleMapping(userID, role.ID, 0) // 0 表示全局角色
|
||||
|
||||
// assert
|
||||
assert.True(t, userRole.HasRole(role.ID))
|
||||
assert.False(t, userRole.HasRole(999)) // 不存在的角色ID
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_GlobalRole 测试全局角色(tenantID为0)
|
||||
func TestUserRoleMapping_GlobalRole(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
roleID := int64(1)
|
||||
|
||||
// act - 全局角色
|
||||
userRole := NewUserRoleMapping(userID, roleID, 0)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, int64(0), userRole.TenantID)
|
||||
assert.True(t, userRole.IsGlobalRole())
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_TenantRole 测试租户角色
|
||||
func TestUserRoleMapping_TenantRole(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
roleID := int64(1)
|
||||
tenantID := int64(123)
|
||||
|
||||
// act
|
||||
userRole := NewUserRoleMapping(userID, roleID, tenantID)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, tenantID, userRole.TenantID)
|
||||
assert.False(t, userRole.IsGlobalRole())
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_WithGrantInfo 测试带授权信息的分配
|
||||
func TestUserRoleMapping_WithGrantInfo(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
roleID := int64(1)
|
||||
tenantID := int64(1)
|
||||
grantedBy := int64(1)
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
|
||||
// act
|
||||
userRole := NewUserRoleMappingWithGrant(userID, roleID, tenantID, grantedBy, &expiresAt)
|
||||
|
||||
// assert
|
||||
assert.Equal(t, userID, userRole.UserID)
|
||||
assert.Equal(t, roleID, userRole.RoleID)
|
||||
assert.Equal(t, grantedBy, userRole.GrantedBy)
|
||||
assert.NotNil(t, userRole.ExpiresAt)
|
||||
assert.NotNil(t, userRole.GrantedAt)
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_Expired 测试过期角色
|
||||
func TestUserRoleMapping_Expired(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
roleID := int64(1)
|
||||
expiresAt := time.Now().Add(-1 * time.Hour) // 已过期
|
||||
|
||||
// act
|
||||
userRole := NewUserRoleMappingWithGrant(userID, roleID, 0, 1, &expiresAt)
|
||||
|
||||
// assert
|
||||
assert.True(t, userRole.IsExpired())
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_NotExpired 测试未过期角色
|
||||
func TestUserRoleMapping_NotExpired(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
roleID := int64(1)
|
||||
expiresAt := time.Now().Add(24 * time.Hour) // 未过期
|
||||
|
||||
// act
|
||||
userRole := NewUserRoleMappingWithGrant(userID, roleID, 0, 1, &expiresAt)
|
||||
|
||||
// assert
|
||||
assert.False(t, userRole.IsExpired())
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_NoExpiration 测试永不过期角色
|
||||
func TestUserRoleMapping_NoExpiration(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
roleID := int64(1)
|
||||
|
||||
// act
|
||||
userRole := NewUserRoleMapping(userID, roleID, 0)
|
||||
|
||||
// assert
|
||||
assert.Nil(t, userRole.ExpiresAt)
|
||||
assert.False(t, userRole.IsExpired())
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_Revoke 测试撤销角色
|
||||
func TestUserRoleMapping_Revoke(t *testing.T) {
|
||||
// arrange
|
||||
userRole := NewUserRoleMapping(100, 1, 0)
|
||||
|
||||
// act
|
||||
userRole.Revoke()
|
||||
|
||||
// assert
|
||||
assert.False(t, userRole.IsActive)
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_Grant 测试重新授予角色
|
||||
func TestUserRoleMapping_Grant(t *testing.T) {
|
||||
// arrange
|
||||
userRole := NewUserRoleMapping(100, 1, 0)
|
||||
userRole.Revoke()
|
||||
|
||||
// act
|
||||
userRole.Grant()
|
||||
|
||||
// assert
|
||||
assert.True(t, userRole.IsActive)
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_IncrementVersion 测试版本号递增
|
||||
func TestUserRoleMapping_IncrementVersion(t *testing.T) {
|
||||
// arrange
|
||||
userRole := NewUserRoleMapping(100, 1, 0)
|
||||
originalVersion := userRole.Version
|
||||
|
||||
// act
|
||||
userRole.IncrementVersion()
|
||||
|
||||
// assert
|
||||
assert.Equal(t, originalVersion+1, userRole.Version)
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_Valid 测试有效角色
|
||||
func TestUserRoleMapping_Valid(t *testing.T) {
|
||||
// arrange - 活跃且未过期的角色
|
||||
userRole := NewUserRoleMapping(100, 1, 0)
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
userRole.ExpiresAt = &expiresAt
|
||||
|
||||
// act & assert
|
||||
assert.True(t, userRole.IsValid())
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_InvalidInactive 测试无效角色 - 未激活
|
||||
func TestUserRoleMapping_InvalidInactive(t *testing.T) {
|
||||
// arrange
|
||||
userRole := NewUserRoleMapping(100, 1, 0)
|
||||
userRole.Revoke()
|
||||
|
||||
// assert
|
||||
assert.False(t, userRole.IsValid())
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_Valid_ExpiredButActive 测试过期但激活的角色
|
||||
func TestUserRoleMapping_Valid_ExpiredButActive(t *testing.T) {
|
||||
// arrange - 已过期但仍然激活的角色(应该无效)
|
||||
userRole := NewUserRoleMapping(100, 1, 0)
|
||||
expiresAt := time.Now().Add(-1 * time.Hour)
|
||||
userRole.ExpiresAt = &expiresAt
|
||||
|
||||
// assert - 即使IsActive为true,过期角色也应该无效
|
||||
assert.False(t, userRole.IsValid())
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_UniqueConstraint 测试唯一性约束
|
||||
func TestUserRoleMapping_UniqueConstraint(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
roleID := int64(1)
|
||||
tenantID := int64(0) // 全局角色
|
||||
|
||||
// act
|
||||
userRole1 := NewUserRoleMapping(userID, roleID, tenantID)
|
||||
userRole2 := NewUserRoleMapping(userID, roleID, tenantID)
|
||||
|
||||
// assert - 同一个用户、角色、租户组合应该唯一
|
||||
assert.Equal(t, userRole1.UserID, userRole2.UserID)
|
||||
assert.Equal(t, userRole1.RoleID, userRole2.RoleID)
|
||||
assert.Equal(t, userRole1.TenantID, userRole2.TenantID)
|
||||
}
|
||||
|
||||
// TestUserRoleMapping_DifferentTenants 测试不同租户可以有相同角色
|
||||
func TestUserRoleMapping_DifferentTenants(t *testing.T) {
|
||||
// arrange
|
||||
userID := int64(100)
|
||||
roleID := int64(1)
|
||||
tenantID1 := int64(1)
|
||||
tenantID2 := int64(2)
|
||||
|
||||
// act
|
||||
userRole1 := NewUserRoleMapping(userID, roleID, tenantID1)
|
||||
userRole2 := NewUserRoleMapping(userID, roleID, tenantID2)
|
||||
|
||||
// assert - 不同租户的角色分配互不影响
|
||||
assert.Equal(t, tenantID1, userRole1.TenantID)
|
||||
assert.Equal(t, tenantID2, userRole2.TenantID)
|
||||
assert.NotEqual(t, userRole1.TenantID, userRole2.TenantID)
|
||||
}
|
||||
|
||||
// TestUserRoleMappingInfo_ToInfo 测试转换为UserRoleMappingInfo
|
||||
func TestUserRoleMappingInfo_ToInfo(t *testing.T) {
|
||||
// arrange
|
||||
userRole := NewUserRoleMapping(100, 1, 0)
|
||||
userRole.ID = 1
|
||||
|
||||
// act
|
||||
info := userRole.ToInfo()
|
||||
|
||||
// assert
|
||||
assert.Equal(t, int64(100), info.UserID)
|
||||
assert.Equal(t, int64(1), info.RoleID)
|
||||
assert.Equal(t, int64(0), info.TenantID)
|
||||
assert.True(t, info.IsActive)
|
||||
}
|
||||
291
supply-api/internal/iam/service/iam_service.go
Normal file
291
supply-api/internal/iam/service/iam_service.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
var (
|
||||
ErrRoleNotFound = errors.New("role not found")
|
||||
ErrDuplicateRoleCode = errors.New("role code already exists")
|
||||
ErrDuplicateAssignment = errors.New("user already has this role")
|
||||
ErrInvalidRequest = errors.New("invalid request")
|
||||
)
|
||||
|
||||
// Role 角色(简化的服务层模型)
|
||||
type Role struct {
|
||||
Code string
|
||||
Name string
|
||||
Type string
|
||||
Level int
|
||||
Description string
|
||||
IsActive bool
|
||||
Version int
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// UserRole 用户角色(简化的服务层模型)
|
||||
type UserRole struct {
|
||||
UserID int64
|
||||
RoleCode string
|
||||
TenantID int64
|
||||
IsActive bool
|
||||
ExpiresAt *time.Time
|
||||
}
|
||||
|
||||
// CreateRoleRequest 创建角色请求
|
||||
type CreateRoleRequest struct {
|
||||
Code string
|
||||
Name string
|
||||
Type string
|
||||
Level int
|
||||
Description string
|
||||
Scopes []string
|
||||
ParentCode string
|
||||
}
|
||||
|
||||
// UpdateRoleRequest 更新角色请求
|
||||
type UpdateRoleRequest struct {
|
||||
Code string
|
||||
Name string
|
||||
Description string
|
||||
Scopes []string
|
||||
IsActive *bool
|
||||
}
|
||||
|
||||
// AssignRoleRequest 分配角色请求
|
||||
type AssignRoleRequest struct {
|
||||
UserID int64
|
||||
RoleCode string
|
||||
TenantID int64
|
||||
GrantedBy int64
|
||||
ExpiresAt *time.Time
|
||||
}
|
||||
|
||||
// IAMServiceInterface IAM服务接口
|
||||
type IAMServiceInterface interface {
|
||||
CreateRole(ctx context.Context, req *CreateRoleRequest) (*Role, error)
|
||||
GetRole(ctx context.Context, roleCode string) (*Role, error)
|
||||
UpdateRole(ctx context.Context, req *UpdateRoleRequest) (*Role, error)
|
||||
DeleteRole(ctx context.Context, roleCode string) error
|
||||
ListRoles(ctx context.Context, roleType string) ([]*Role, error)
|
||||
|
||||
AssignRole(ctx context.Context, req *AssignRoleRequest) (*UserRole, error)
|
||||
RevokeRole(ctx context.Context, userID int64, roleCode string, tenantID int64) error
|
||||
GetUserRoles(ctx context.Context, userID int64) ([]*UserRole, error)
|
||||
|
||||
CheckScope(ctx context.Context, userID int64, requiredScope string) (bool, error)
|
||||
GetUserScopes(ctx context.Context, userID int64) ([]string, error)
|
||||
}
|
||||
|
||||
// DefaultIAMService 默认IAM服务实现
|
||||
type DefaultIAMService struct {
|
||||
// 角色存储
|
||||
roleStore map[string]*Role
|
||||
// 用户角色存储: userID -> []*UserRole
|
||||
userRoleStore map[int64][]*UserRole
|
||||
// 角色Scope存储: roleCode -> []scopeCode
|
||||
roleScopeStore map[string][]string
|
||||
}
|
||||
|
||||
// NewDefaultIAMService 创建默认IAM服务
|
||||
func NewDefaultIAMService() *DefaultIAMService {
|
||||
return &DefaultIAMService{
|
||||
roleStore: make(map[string]*Role),
|
||||
userRoleStore: make(map[int64][]*UserRole),
|
||||
roleScopeStore: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRole 创建角色
|
||||
func (s *DefaultIAMService) CreateRole(ctx context.Context, req *CreateRoleRequest) (*Role, error) {
|
||||
// 检查是否重复
|
||||
if _, exists := s.roleStore[req.Code]; exists {
|
||||
return nil, ErrDuplicateRoleCode
|
||||
}
|
||||
|
||||
// 验证角色类型
|
||||
if req.Type != "platform" && req.Type != "supply" && req.Type != "consumer" {
|
||||
return nil, ErrInvalidRequest
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
role := &Role{
|
||||
Code: req.Code,
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Level: req.Level,
|
||||
Description: req.Description,
|
||||
IsActive: true,
|
||||
Version: 1,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
// 存储角色
|
||||
s.roleStore[req.Code] = role
|
||||
|
||||
// 存储角色Scope关联
|
||||
if len(req.Scopes) > 0 {
|
||||
s.roleScopeStore[req.Code] = req.Scopes
|
||||
}
|
||||
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// GetRole 获取角色
|
||||
func (s *DefaultIAMService) GetRole(ctx context.Context, roleCode string) (*Role, error) {
|
||||
role, exists := s.roleStore[roleCode]
|
||||
if !exists {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// UpdateRole 更新角色
|
||||
func (s *DefaultIAMService) UpdateRole(ctx context.Context, req *UpdateRoleRequest) (*Role, error) {
|
||||
role, exists := s.roleStore[req.Code]
|
||||
if !exists {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != "" {
|
||||
role.Name = req.Name
|
||||
}
|
||||
if req.Description != "" {
|
||||
role.Description = req.Description
|
||||
}
|
||||
if req.Scopes != nil {
|
||||
s.roleScopeStore[req.Code] = req.Scopes
|
||||
}
|
||||
if req.IsActive != nil {
|
||||
role.IsActive = *req.IsActive
|
||||
}
|
||||
|
||||
// 递增版本
|
||||
role.Version++
|
||||
role.UpdatedAt = time.Now()
|
||||
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// DeleteRole 删除角色(软删除)
|
||||
func (s *DefaultIAMService) DeleteRole(ctx context.Context, roleCode string) error {
|
||||
role, exists := s.roleStore[roleCode]
|
||||
if !exists {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
|
||||
role.IsActive = false
|
||||
role.UpdatedAt = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoles 列出角色
|
||||
func (s *DefaultIAMService) ListRoles(ctx context.Context, roleType string) ([]*Role, error) {
|
||||
var roles []*Role
|
||||
for _, role := range s.roleStore {
|
||||
if roleType == "" || role.Type == roleType {
|
||||
roles = append(roles, role)
|
||||
}
|
||||
}
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// AssignRole 分配角色
|
||||
func (s *DefaultIAMService) AssignRole(ctx context.Context, req *AssignRoleRequest) (*UserRole, error) {
|
||||
// 检查角色是否存在
|
||||
if _, exists := s.roleStore[req.RoleCode]; !exists {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
|
||||
// 检查是否已分配
|
||||
for _, ur := range s.userRoleStore[req.UserID] {
|
||||
if ur.RoleCode == req.RoleCode && ur.TenantID == req.TenantID && ur.IsActive {
|
||||
return nil, ErrDuplicateAssignment
|
||||
}
|
||||
}
|
||||
|
||||
userRole := &UserRole{
|
||||
UserID: req.UserID,
|
||||
RoleCode: req.RoleCode,
|
||||
TenantID: req.TenantID,
|
||||
IsActive: true,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
}
|
||||
|
||||
// 存储映射
|
||||
s.userRoleStore[req.UserID] = append(s.userRoleStore[req.UserID], userRole)
|
||||
|
||||
return userRole, nil
|
||||
}
|
||||
|
||||
// RevokeRole 撤销角色
|
||||
func (s *DefaultIAMService) RevokeRole(ctx context.Context, userID int64, roleCode string, tenantID int64) error {
|
||||
for _, ur := range s.userRoleStore[userID] {
|
||||
if ur.RoleCode == roleCode && ur.TenantID == tenantID {
|
||||
ur.IsActive = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
|
||||
// GetUserRoles 获取用户角色
|
||||
func (s *DefaultIAMService) GetUserRoles(ctx context.Context, userID int64) ([]*UserRole, error) {
|
||||
var userRoles []*UserRole
|
||||
for _, ur := range s.userRoleStore[userID] {
|
||||
if ur.IsActive {
|
||||
userRoles = append(userRoles, ur)
|
||||
}
|
||||
}
|
||||
return userRoles, nil
|
||||
}
|
||||
|
||||
// CheckScope 检查用户是否有指定Scope
|
||||
func (s *DefaultIAMService) CheckScope(ctx context.Context, userID int64, requiredScope string) (bool, error) {
|
||||
scopes, err := s.GetUserScopes(ctx, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, scope := range scopes {
|
||||
if scope == requiredScope || scope == "*" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetUserScopes 获取用户所有Scope
|
||||
func (s *DefaultIAMService) GetUserScopes(ctx context.Context, userID int64) ([]string, error) {
|
||||
var allScopes []string
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, ur := range s.userRoleStore[userID] {
|
||||
if ur.IsActive && (ur.ExpiresAt == nil || ur.ExpiresAt.After(time.Now())) {
|
||||
if scopes, exists := s.roleScopeStore[ur.RoleCode]; exists {
|
||||
for _, scope := range scopes {
|
||||
if !seen[scope] {
|
||||
seen[scope] = true
|
||||
allScopes = append(allScopes, scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return allScopes, nil
|
||||
}
|
||||
|
||||
// IsExpired 检查用户角色是否过期
|
||||
func (ur *UserRole) IsExpired() bool {
|
||||
if ur.ExpiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(*ur.ExpiresAt)
|
||||
}
|
||||
432
supply-api/internal/iam/service/iam_service_test.go
Normal file
432
supply-api/internal/iam/service/iam_service_test.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// MockIAMService 模拟IAM服务(用于测试)
|
||||
type MockIAMService struct {
|
||||
roles map[string]*Role
|
||||
userRoles map[int64][]*UserRole
|
||||
roleScopes map[string][]string
|
||||
}
|
||||
|
||||
func NewMockIAMService() *MockIAMService {
|
||||
return &MockIAMService{
|
||||
roles: make(map[string]*Role),
|
||||
userRoles: make(map[int64][]*UserRole),
|
||||
roleScopes: make(map[string][]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockIAMService) CreateRole(ctx context.Context, req *CreateRoleRequest) (*Role, error) {
|
||||
if _, exists := m.roles[req.Code]; exists {
|
||||
return nil, ErrDuplicateRoleCode
|
||||
}
|
||||
role := &Role{
|
||||
Code: req.Code,
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Level: req.Level,
|
||||
IsActive: true,
|
||||
Version: 1,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
m.roles[req.Code] = role
|
||||
if len(req.Scopes) > 0 {
|
||||
m.roleScopes[req.Code] = req.Scopes
|
||||
}
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (m *MockIAMService) GetRole(ctx context.Context, roleCode string) (*Role, error) {
|
||||
if role, exists := m.roles[roleCode]; exists {
|
||||
return role, nil
|
||||
}
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
|
||||
func (m *MockIAMService) UpdateRole(ctx context.Context, req *UpdateRoleRequest) (*Role, error) {
|
||||
role, exists := m.roles[req.Code]
|
||||
if !exists {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
if req.Name != "" {
|
||||
role.Name = req.Name
|
||||
}
|
||||
if req.Description != "" {
|
||||
role.Description = req.Description
|
||||
}
|
||||
if req.Scopes != nil {
|
||||
m.roleScopes[req.Code] = req.Scopes
|
||||
}
|
||||
role.Version++
|
||||
role.UpdatedAt = time.Now()
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (m *MockIAMService) DeleteRole(ctx context.Context, roleCode string) error {
|
||||
role, exists := m.roles[roleCode]
|
||||
if !exists {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
role.IsActive = false
|
||||
role.UpdatedAt = time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIAMService) ListRoles(ctx context.Context, roleType string) ([]*Role, error) {
|
||||
var roles []*Role
|
||||
for _, role := range m.roles {
|
||||
if roleType == "" || role.Type == roleType {
|
||||
roles = append(roles, role)
|
||||
}
|
||||
}
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
func (m *MockIAMService) AssignRole(ctx context.Context, req *AssignRoleRequest) (*modelUserRoleMapping, error) {
|
||||
for _, ur := range m.userRoles[req.UserID] {
|
||||
if ur.RoleCode == req.RoleCode && ur.TenantID == req.TenantID && ur.IsActive {
|
||||
return nil, ErrDuplicateAssignment
|
||||
}
|
||||
}
|
||||
mapping := &modelUserRoleMapping{
|
||||
UserID: req.UserID,
|
||||
RoleCode: req.RoleCode,
|
||||
TenantID: req.TenantID,
|
||||
IsActive: true,
|
||||
}
|
||||
m.userRoles[req.UserID] = append(m.userRoles[req.UserID], &UserRole{
|
||||
UserID: req.UserID,
|
||||
RoleCode: req.RoleCode,
|
||||
TenantID: req.TenantID,
|
||||
IsActive: true,
|
||||
})
|
||||
return mapping, nil
|
||||
}
|
||||
|
||||
func (m *MockIAMService) RevokeRole(ctx context.Context, userID int64, roleCode string, tenantID int64) error {
|
||||
for _, ur := range m.userRoles[userID] {
|
||||
if ur.RoleCode == roleCode && ur.TenantID == tenantID {
|
||||
ur.IsActive = false
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
|
||||
func (m *MockIAMService) GetUserRoles(ctx context.Context, userID int64) ([]*UserRole, error) {
|
||||
var userRoles []*UserRole
|
||||
for _, ur := range m.userRoles[userID] {
|
||||
if ur.IsActive {
|
||||
userRoles = append(userRoles, ur)
|
||||
}
|
||||
}
|
||||
return userRoles, nil
|
||||
}
|
||||
|
||||
func (m *MockIAMService) CheckScope(ctx context.Context, userID int64, requiredScope string) (bool, error) {
|
||||
scopes, err := m.GetUserScopes(ctx, userID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
if scope == requiredScope || scope == "*" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *MockIAMService) GetUserScopes(ctx context.Context, userID int64) ([]string, error) {
|
||||
var allScopes []string
|
||||
seen := make(map[string]bool)
|
||||
for _, ur := range m.userRoles[userID] {
|
||||
if ur.IsActive {
|
||||
if scopes, exists := m.roleScopes[ur.RoleCode]; exists {
|
||||
for _, scope := range scopes {
|
||||
if !seen[scope] {
|
||||
seen[scope] = true
|
||||
allScopes = append(allScopes, scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return allScopes, nil
|
||||
}
|
||||
|
||||
// modelUserRoleMapping 简化的用户角色映射(用于测试)
|
||||
type modelUserRoleMapping struct {
|
||||
UserID int64
|
||||
RoleCode string
|
||||
TenantID int64
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
// TestIAMService_CreateRole_Success 测试创建角色成功
|
||||
func TestIAMService_CreateRole_Success(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
req := &CreateRoleRequest{
|
||||
Code: "developer",
|
||||
Name: "开发者",
|
||||
Type: "platform",
|
||||
Level: 20,
|
||||
Scopes: []string{"platform:read", "router:invoke"},
|
||||
}
|
||||
|
||||
// act
|
||||
role, err := mockService.CreateRole(context.Background(), req)
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, role)
|
||||
assert.Equal(t, "developer", role.Code)
|
||||
assert.Equal(t, "开发者", role.Name)
|
||||
assert.Equal(t, "platform", role.Type)
|
||||
assert.Equal(t, 20, role.Level)
|
||||
assert.True(t, role.IsActive)
|
||||
}
|
||||
|
||||
// TestIAMService_CreateRole_DuplicateName 测试创建重复角色
|
||||
func TestIAMService_CreateRole_DuplicateName(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
mockService.roles["developer"] = &Role{Code: "developer", Name: "开发者", Type: "platform", Level: 20}
|
||||
|
||||
req := &CreateRoleRequest{
|
||||
Code: "developer",
|
||||
Name: "开发者",
|
||||
Type: "platform",
|
||||
Level: 20,
|
||||
}
|
||||
|
||||
// act
|
||||
role, err := mockService.CreateRole(context.Background(), req)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, role)
|
||||
assert.Equal(t, ErrDuplicateRoleCode, err)
|
||||
}
|
||||
|
||||
// TestIAMService_UpdateRole_Success 测试更新角色成功
|
||||
func TestIAMService_UpdateRole_Success(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
existingRole := &Role{
|
||||
Code: "developer",
|
||||
Name: "开发者",
|
||||
Type: "platform",
|
||||
Level: 20,
|
||||
IsActive: true,
|
||||
Version: 1,
|
||||
}
|
||||
mockService.roles["developer"] = existingRole
|
||||
|
||||
req := &UpdateRoleRequest{
|
||||
Code: "developer",
|
||||
Name: "AI开发者",
|
||||
Description: "AI应用开发者",
|
||||
}
|
||||
|
||||
// act
|
||||
updatedRole, err := mockService.UpdateRole(context.Background(), req)
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, updatedRole)
|
||||
assert.Equal(t, "AI开发者", updatedRole.Name)
|
||||
assert.Equal(t, "AI应用开发者", updatedRole.Description)
|
||||
assert.Equal(t, 2, updatedRole.Version) // version 应该递增
|
||||
}
|
||||
|
||||
// TestIAMService_UpdateRole_NotFound 测试更新不存在的角色
|
||||
func TestIAMService_UpdateRole_NotFound(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
|
||||
req := &UpdateRoleRequest{
|
||||
Code: "nonexistent",
|
||||
Name: "不存在",
|
||||
}
|
||||
|
||||
// act
|
||||
role, err := mockService.UpdateRole(context.Background(), req)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, role)
|
||||
assert.Equal(t, ErrRoleNotFound, err)
|
||||
}
|
||||
|
||||
// TestIAMService_DeleteRole_Success 测试删除角色成功
|
||||
func TestIAMService_DeleteRole_Success(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
mockService.roles["developer"] = &Role{Code: "developer", Name: "开发者", IsActive: true}
|
||||
|
||||
// act
|
||||
err := mockService.DeleteRole(context.Background(), "developer")
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, mockService.roles["developer"].IsActive) // 应该被停用而不是删除
|
||||
}
|
||||
|
||||
// TestIAMService_ListRoles 测试列出角色
|
||||
func TestIAMService_ListRoles(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
mockService.roles["viewer"] = &Role{Code: "viewer", Type: "platform", Level: 10}
|
||||
mockService.roles["operator"] = &Role{Code: "operator", Type: "platform", Level: 30}
|
||||
mockService.roles["supply_admin"] = &Role{Code: "supply_admin", Type: "supply", Level: 40}
|
||||
|
||||
// act
|
||||
platformRoles, err := mockService.ListRoles(context.Background(), "platform")
|
||||
supplyRoles, err2 := mockService.ListRoles(context.Background(), "supply")
|
||||
allRoles, err3 := mockService.ListRoles(context.Background(), "")
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, platformRoles, 2)
|
||||
|
||||
assert.NoError(t, err2)
|
||||
assert.Len(t, supplyRoles, 1)
|
||||
|
||||
assert.NoError(t, err3)
|
||||
assert.Len(t, allRoles, 3)
|
||||
}
|
||||
|
||||
// TestIAMService_AssignRole 测试分配角色
|
||||
func TestIAMService_AssignRole(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
mockService.roles["viewer"] = &Role{Code: "viewer", Type: "platform", Level: 10}
|
||||
|
||||
req := &AssignRoleRequest{
|
||||
UserID: 100,
|
||||
RoleCode: "viewer",
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
// act
|
||||
mapping, err := mockService.AssignRole(context.Background(), req)
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, mapping)
|
||||
assert.Equal(t, int64(100), mapping.UserID)
|
||||
assert.Equal(t, "viewer", mapping.RoleCode)
|
||||
assert.True(t, mapping.IsActive)
|
||||
}
|
||||
|
||||
// TestIAMService_AssignRole_Duplicate 测试重复分配角色
|
||||
func TestIAMService_AssignRole_Duplicate(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
mockService.roles["viewer"] = &Role{Code: "viewer", Type: "platform", Level: 10}
|
||||
mockService.userRoles[100] = []*UserRole{
|
||||
{UserID: 100, RoleCode: "viewer", TenantID: 1, IsActive: true},
|
||||
}
|
||||
|
||||
req := &AssignRoleRequest{
|
||||
UserID: 100,
|
||||
RoleCode: "viewer",
|
||||
TenantID: 1,
|
||||
}
|
||||
|
||||
// act
|
||||
mapping, err := mockService.AssignRole(context.Background(), req)
|
||||
|
||||
// assert
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, mapping)
|
||||
assert.Equal(t, ErrDuplicateAssignment, err)
|
||||
}
|
||||
|
||||
// TestIAMService_RevokeRole 测试撤销角色
|
||||
func TestIAMService_RevokeRole(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
mockService.userRoles[100] = []*UserRole{
|
||||
{UserID: 100, RoleCode: "viewer", TenantID: 1, IsActive: true},
|
||||
}
|
||||
|
||||
// act
|
||||
err := mockService.RevokeRole(context.Background(), 100, "viewer", 1)
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, mockService.userRoles[100][0].IsActive)
|
||||
}
|
||||
|
||||
// TestIAMService_GetUserRoles 测试获取用户角色
|
||||
func TestIAMService_GetUserRoles(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
mockService.userRoles[100] = []*UserRole{
|
||||
{UserID: 100, RoleCode: "viewer", TenantID: 0, IsActive: true},
|
||||
{UserID: 100, RoleCode: "developer", TenantID: 1, IsActive: true},
|
||||
}
|
||||
|
||||
// act
|
||||
roles, err := mockService.GetUserRoles(context.Background(), 100)
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.Len(t, roles, 2)
|
||||
}
|
||||
|
||||
// TestIAMService_CheckScope 测试检查用户Scope
|
||||
func TestIAMService_CheckScope(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
mockService.roles["viewer"] = &Role{Code: "viewer", Type: "platform", Level: 10}
|
||||
mockService.roleScopes["viewer"] = []string{"platform:read", "tenant:read"}
|
||||
mockService.userRoles[100] = []*UserRole{
|
||||
{UserID: 100, RoleCode: "viewer", TenantID: 0, IsActive: true},
|
||||
}
|
||||
|
||||
// act
|
||||
hasScope, err := mockService.CheckScope(context.Background(), 100, "platform:read")
|
||||
noScope, err2 := mockService.CheckScope(context.Background(), 100, "platform:write")
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, hasScope)
|
||||
|
||||
assert.NoError(t, err2)
|
||||
assert.False(t, noScope)
|
||||
}
|
||||
|
||||
// TestIAMService_GetUserScopes 测试获取用户所有Scope
|
||||
func TestIAMService_GetUserScopes(t *testing.T) {
|
||||
// arrange
|
||||
mockService := NewMockIAMService()
|
||||
mockService.roles["viewer"] = &Role{Code: "viewer", Type: "platform", Level: 10}
|
||||
mockService.roles["developer"] = &Role{Code: "developer", Type: "platform", Level: 20}
|
||||
mockService.roleScopes["viewer"] = []string{"platform:read", "tenant:read"}
|
||||
mockService.roleScopes["developer"] = []string{"router:invoke", "router:model:list"}
|
||||
mockService.userRoles[100] = []*UserRole{
|
||||
{UserID: 100, RoleCode: "viewer", TenantID: 0, IsActive: true},
|
||||
{UserID: 100, RoleCode: "developer", TenantID: 0, IsActive: true},
|
||||
}
|
||||
|
||||
// act
|
||||
scopes, err := mockService.GetUserScopes(context.Background(), 100)
|
||||
|
||||
// assert
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, scopes, "platform:read")
|
||||
assert.Contains(t, scopes, "tenant:read")
|
||||
assert.Contains(t, scopes, "router:invoke")
|
||||
assert.Contains(t, scopes, "router:model:list")
|
||||
}
|
||||
Reference in New Issue
Block a user