admin_backend/controller/experiment.go
hahwu ddae026231 ab test and notification
Co-authored-by: Copilot <copilot@github.com>
2026-04-24 17:23:55 +08:00

688 lines
17 KiB
Go

package controller
import (
"backend/util"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"sync"
"time"
abtestdao "backend/sdk/abtest/dao"
abtestmodel "backend/sdk/abtest/model"
abtestservice "backend/sdk/abtest/service"
"github.com/gin-gonic/gin"
)
// Package-level cache for the AB-test service that every handler in this file
// obtains through getExperimentService.
var (
	experimentServiceOnce sync.Once
	experimentService     *abtestservice.Service
	// experimentServiceErr holds the error from the most recent attempt to
	// build experimentService (nil after a successful build).
	experimentServiceErr error
)
// experimentCreateVariantRequest is one element of the "variants" array in the
// experiment create and update payloads.
//
// NOTE(review): gin's validator probably does not descend into slice elements
// unless the parent field carries a `dive` tag, so the `binding:"required"`
// on Name is likely not enforced here — confirm. Both aggregate functions in
// this file reject a blank (trimmed) name themselves.
type experimentCreateVariantRequest struct {
	Name        string `json:"name" binding:"required"`
	Description string `json:"description"`
	// Weight is forwarded to the service as-is; it is not validated here.
	Weight int `json:"weight"`
	// Params is kept as raw JSON and passed through to CreateVariantReq.Params.
	Params json.RawMessage `json:"params"`
}
// experimentCreateWhitelistRequest is one element of the "whitelist" array in
// the experiment create and update payloads. It pins UserID to a variant
// referenced by name rather than by ID; consumers trim both fields before use.
//
// NOTE(review): as with experimentCreateVariantRequest, the `binding:"required"`
// tags are probably not enforced for slice elements without `dive` — confirm.
// The aggregate functions check for a blank user_id and an unknown
// variant_name explicitly.
type experimentCreateWhitelistRequest struct {
	UserID      string `json:"user_id" binding:"required"`
	VariantName string `json:"variant_name" binding:"required"`
}
// experimentCreateRequest is the JSON payload accepted by ExperimentCreate.
//
// Name is required and must also be non-blank after trimming.
// StartTime/EndTime are optional strings parsed by parseExperimentTime; a blank
// value means "not set". Status and the times are not part of the initial
// CreateExperiment call; createExperimentAggregate applies them with a
// follow-up UpdateExperiment. Whitelist entries reference variants from the
// Variants list of this same payload by name.
type experimentCreateRequest struct {
	Name        string                             `json:"name" binding:"required"`
	Description string                             `json:"description"`
	StartTime   string                             `json:"start_time"`
	EndTime     string                             `json:"end_time"`
	Status      *int                               `json:"status"`
	Variants    []experimentCreateVariantRequest   `json:"variants"`
	Whitelist   []experimentCreateWhitelistRequest `json:"whitelist"`
}
// optionalStringField is a nullable JSON string that remembers whether its key
// was present at all. It has three observable states:
//
//	key absent        -> Set == false, Value == nil
//	"key": null       -> Set == true,  Value == nil
//	"key": "some str" -> Set == true,  Value points at the string
//
// experimentUpdateRequest uses it so that null ("clear the time") can be told
// apart from an omitted key ("leave the time unchanged").
type optionalStringField struct {
	Set   bool
	Value *string
}

// UnmarshalJSON implements json.Unmarshaler. Any call marks the field as Set,
// even when decoding the payload subsequently fails. The literal null clears
// Value; anything else must decode as a JSON string. On a decode error the
// error is returned and Value is left untouched.
func (f *optionalStringField) UnmarshalJSON(data []byte) error {
	f.Set = true
	if string(data) != "null" {
		var decoded string
		if err := json.Unmarshal(data, &decoded); err != nil {
			return err
		}
		f.Value = &decoded
		return nil
	}
	f.Value = nil
	return nil
}
// experimentUpdateRequest is the JSON payload accepted by ExperimentUpdate.
// Every field is optional:
//
//   - Name, Description, Status are forwarded as-is to UpdateExperimentReq;
//     with encoding/json both an absent key and null leave them nil
//     (presumably "unchanged" in the service — confirm).
//   - StartTime, EndTime: absent = unchanged, null = clear (ClearStartTime /
//     ClearEndTime), string = new value parsed by parseExperimentTime.
//   - Variants: nil (absent or null) = untouched; any non-nil slice, even [],
//     triggers ResetExperimentRelations followed by recreating the listed
//     variants.
//   - Whitelist: nil = untouched; non-nil means the entries are re-added after
//     the existing whitelist is cleared (by ClearWhitelist, or by
//     ResetExperimentRelations when Variants is also sent).
type experimentUpdateRequest struct {
	Name        *string                            `json:"name"`
	Description *string                            `json:"description"`
	StartTime   optionalStringField                `json:"start_time"`
	EndTime     optionalStringField                `json:"end_time"`
	Status      *int                               `json:"status"`
	Variants    []experimentCreateVariantRequest   `json:"variants"`
	Whitelist   []experimentCreateWhitelistRequest `json:"whitelist"`
}
// experimentServiceMu serialises lazy construction of experimentService.
//
// The package-level experimentServiceOnce is intentionally not used here. A
// sync.Once runs its function only once, so a failed connection attempt would
// be cached for the lifetime of the process. That declaration can be removed.
var experimentServiceMu sync.Mutex

// getExperimentService returns the shared AB-test service. It connects with
// util.ConnectMysql("log", "abtest") the first time it is called successfully.
//
// A successfully built service is cached and reused by all later calls. When
// the connection fails, the error is returned (and stored in
// experimentServiceErr) but not cached: the next call tries again. This way a
// transient database outage does not permanently disable every experiment
// endpoint. Concurrent first callers are serialised by experimentServiceMu,
// so at most one connection attempt is in flight at a time.
func getExperimentService() (*abtestservice.Service, error) {
	experimentServiceMu.Lock()
	defer experimentServiceMu.Unlock()

	if experimentService != nil {
		return experimentService, nil
	}
	db, err := util.ConnectMysql("log", "abtest")
	if err != nil {
		experimentServiceErr = err
		return nil, err
	}
	experimentService = abtestservice.New(abtestdao.New(db))
	experimentServiceErr = nil
	return experimentService, nil
}
// ExperimentList handles the paginated experiment listing.
//
// Query parameters:
//   - page: 1-based page number, default 1.
//   - page_size: page size, default 20.
//   - status: optional integer status filter.
//
// A page or page_size that is missing, empty, non-numeric, or < 1 falls back
// to its default. It is no longer silently forwarded as 0 or a negative number,
// so the echoed page/page_size always equal what was passed to the service.
// An unparsable status is rejected with "invalid status".
//
// On success it responds with {"list", "total", "page", "page_size"}.
func ExperimentList(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	page := 1
	if v, convErr := strconv.Atoi(c.Query("page")); convErr == nil && v > 0 {
		page = v
	}
	pageSize := 20
	if v, convErr := strconv.Atoi(c.Query("page_size")); convErr == nil && v > 0 {
		pageSize = v
	}

	var statusPtr *int
	if statusRaw := c.Query("status"); statusRaw != "" {
		status, convErr := strconv.Atoi(statusRaw)
		if convErr != nil {
			failed(c, "invalid status")
			return
		}
		statusPtr = &status
	}

	list, total, err := svc.ListExperiments(statusPtr, page, pageSize)
	if err != nil {
		failed(c, err.Error())
		return
	}
	success(c, gin.H{"list": list, "total": total, "page": page, "page_size": pageSize})
}
// ExperimentCreate handles creation of an experiment together with its
// variants and whitelist in a single request (see createExperimentAggregate).
//
// Service-initialisation, JSON-binding and aggregate errors are all reported
// through failed. On success the original request is written to the admin log
// and the created aggregate is returned through success.
func ExperimentCreate(c *gin.Context) {
	var req experimentCreateRequest

	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	created, createErr := createExperimentAggregate(svc, &req)
	if createErr != nil {
		failed(c, createErr.Error())
		return
	}
	// Admin log action text means "create AB experiment".
	util.AddAdminLog(c, "创建AB实验", req)
	success(c, created)
}
// createExperimentAggregate creates an experiment with its variants and
// whitelist entries from one request.
//
// The whole request is validated before anything is written:
//   - the trimmed name must be non-blank;
//   - start_time/end_time must parse (see parseExperimentTime), and start must
//     not be after end;
//   - variant names must be non-blank and unique after trimming;
//   - every whitelist entry needs a non-blank user_id and a variant_name that
//     matches one of the requested variants.
//
// Writes then happen in order: CreateExperiment, an optional UpdateExperiment
// for status/times, CreateVariant per variant, and AddWhitelist per entry. If
// any write fails, the experiment is deleted again. This is a best-effort
// compensation, not a transaction. If that delete also fails, its error is
// appended to the returned error, so the caller knows an orphaned experiment
// may remain.
//
// On success it returns {"experiment", "variants", "whitelist"}.
func createExperimentAggregate(svc *abtestservice.Service, req *experimentCreateRequest) (_ gin.H, err error) {
	name := strings.TrimSpace(req.Name)
	if name == "" {
		return nil, fmt.Errorf("name is required")
	}
	startTime, err := parseExperimentTime(req.StartTime)
	if err != nil {
		return nil, fmt.Errorf("invalid start_time: %w", err)
	}
	endTime, err := parseExperimentTime(req.EndTime)
	if err != nil {
		return nil, fmt.Errorf("invalid end_time: %w", err)
	}
	if startTime != nil && endTime != nil && startTime.After(*endTime) {
		return nil, fmt.Errorf("start_time must be earlier than end_time")
	}

	// Validate relations up front so a malformed payload never causes a write.
	variantNames := make([]string, len(req.Variants))
	requested := make(map[string]struct{}, len(req.Variants))
	for i, variantReq := range req.Variants {
		variantName := strings.TrimSpace(variantReq.Name)
		if variantName == "" {
			return nil, fmt.Errorf("variant name is required")
		}
		if _, dup := requested[variantName]; dup {
			return nil, fmt.Errorf("duplicate variant name: %s", variantName)
		}
		requested[variantName] = struct{}{}
		variantNames[i] = variantName
	}
	for _, whitelistReq := range req.Whitelist {
		if strings.TrimSpace(whitelistReq.UserID) == "" {
			return nil, fmt.Errorf("whitelist user_id is required")
		}
		variantName := strings.TrimSpace(whitelistReq.VariantName)
		if _, ok := requested[variantName]; !ok {
			return nil, fmt.Errorf("whitelist variant_name not found: %s", variantName)
		}
	}

	experiment, err := svc.CreateExperiment(&abtestmodel.CreateExperimentReq{
		Name:        name,
		Description: req.Description,
	})
	if err != nil {
		return nil, err
	}

	// Capture the ID once. The rollback must not dereference `experiment`,
	// which is replaced after UpdateExperiment and could be nil if that call
	// were allowed to overwrite it on failure.
	experimentID := experiment.ID
	committed := false
	defer func() {
		if committed {
			return
		}
		if rollbackErr := svc.DeleteExperiment(experimentID); rollbackErr != nil {
			err = fmt.Errorf("%w (rollback of experiment %d failed: %v)", err, experimentID, rollbackErr)
		}
	}()

	if req.Status != nil || startTime != nil || endTime != nil {
		updated, updateErr := svc.UpdateExperiment(experimentID, &abtestmodel.UpdateExperimentReq{
			Status:    req.Status,
			StartTime: startTime,
			EndTime:   endTime,
		})
		if updateErr != nil {
			return nil, updateErr
		}
		experiment = updated
	}

	variants := make([]*abtestmodel.Variant, 0, len(req.Variants))
	variantsByName := make(map[string]*abtestmodel.Variant, len(req.Variants))
	for i, variantReq := range req.Variants {
		variant, createErr := svc.CreateVariant(experimentID, &abtestmodel.CreateVariantReq{
			Name:        variantNames[i],
			Description: variantReq.Description,
			Weight:      variantReq.Weight,
			Params:      variantReq.Params,
		})
		if createErr != nil {
			return nil, createErr
		}
		variants = append(variants, variant)
		// Key by the validated request name so whitelist lookups below match
		// exactly what was checked above.
		variantsByName[variantNames[i]] = variant
	}

	whitelist := make([]*abtestmodel.Whitelist, 0, len(req.Whitelist))
	for _, whitelistReq := range req.Whitelist {
		entry, addErr := svc.AddWhitelist(experimentID, &abtestmodel.CreateWhitelistReq{
			UserID:    strings.TrimSpace(whitelistReq.UserID),
			VariantID: variantsByName[strings.TrimSpace(whitelistReq.VariantName)].ID,
		})
		if addErr != nil {
			return nil, addErr
		}
		whitelist = append(whitelist, entry)
	}

	committed = true
	return gin.H{
		"experiment": experiment,
		"variants":   variants,
		"whitelist":  whitelist,
	}, nil
}
// parseExperimentTime parses an optional experiment boundary time.
//
// Surrounding whitespace is ignored. A blank input yields (nil, nil), meaning
// "not provided". Otherwise the layouts below are tried in order and the first
// that parses wins:
//
//	"2006-01-02 15:04:05"  interpreted in time.Local
//	time.RFC3339           zone taken from the offset in the input
//	"2006-01-02T15:04:05"  interpreted in time.Local
//
// Input matching none of them returns the error "unsupported time format".
func parseExperimentTime(raw string) (*time.Time, error) {
	trimmed := strings.TrimSpace(raw)
	if len(trimmed) == 0 {
		return nil, nil
	}
	candidates := [...]string{"2006-01-02 15:04:05", time.RFC3339, "2006-01-02T15:04:05"}
	for i := range candidates {
		if t, parseErr := time.ParseInLocation(candidates[i], trimmed, time.Local); parseErr == nil {
			return &t, nil
		}
	}
	return nil, fmt.Errorf("unsupported time format")
}
// ExperimentUpdate handles a partial update of experiment :id, optionally
// replacing its variants and whitelist (see updateExperimentAggregate).
//
// An unparsable id or body is reported through failed. Aggregate errors are
// answered directly with {"code": 1, "message": ...}: HTTP 404 when the error
// is abtestservice.ErrExperimentNotFound, 400 otherwise. On success the id and
// payload are written to the admin log and the updated experiment is returned.
//
// NOTE(review): the sentinel is matched with ==, so a wrapped
// ErrExperimentNotFound would be reported as 400. errors.Is would be sturdier.
func ExperimentUpdate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	var req experimentUpdateRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	updated, updateErr := updateExperimentAggregate(svc, experimentID, &req)
	if updateErr != nil {
		httpStatus := http.StatusNotFound
		if updateErr != abtestservice.ErrExperimentNotFound {
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": updateErr.Error()})
		return
	}

	// Admin log action text means "update AB experiment".
	util.AddAdminLog(c, "更新AB实验", gin.H{"id": experimentID, "payload": req})
	success(c, updated)
}
// updateExperimentAggregate applies a partial update to experiment
// experimentID and, when requested, rebuilds its variants and/or whitelist.
//
// Request semantics (see experimentUpdateRequest):
//   - Name, Description, Status are forwarded to UpdateExperiment as given.
//   - start_time / end_time: absent = unchanged, null = cleared, string =
//     parsed with parseExperimentTime.
//   - variants non-nil: ResetExperimentRelations is called and the listed
//     variants are created from scratch.
//   - whitelist non-nil without variants: the whitelist is cleared and rebuilt
//     against the experiment's current variants.
//
// All validation (times, variant names, whitelist references) runs before the
// first write. A malformed payload is therefore rejected before
// ResetExperimentRelations or ClearWhitelist can wipe existing data. The
// writes themselves are still sequential service calls rather than a
// transaction, so a service error midway can leave partial state.
//
// NOTE(review): the start/end ordering check only applies when both times are
// sent in this request; it is not checked against the stored values.
//
// It returns the experiment re-read after all writes. Service errors (e.g.
// abtestservice.ErrExperimentNotFound from GetExperiment) are returned
// unwrapped so callers can match them.
func updateExperimentAggregate(svc *abtestservice.Service, experimentID int64, req *experimentUpdateRequest) (*abtestmodel.Experiment, error) {
	if _, err := svc.GetExperiment(experimentID); err != nil {
		return nil, err
	}
	updateReq, err := buildExperimentUpdateReq(req)
	if err != nil {
		return nil, err
	}

	variantsProvided := req.Variants != nil
	whitelistProvided := req.Whitelist != nil

	// Phase 1: validate relations. Only reads happen here.
	var (
		allowedVariants  map[string]struct{}
		existingVariants []*abtestmodel.Variant
	)
	if variantsProvided {
		if allowedVariants, err = validateUpdateVariantNames(req.Variants); err != nil {
			return nil, err
		}
	} else if whitelistProvided {
		if existingVariants, err = svc.ListVariants(experimentID); err != nil {
			return nil, err
		}
		allowedVariants = make(map[string]struct{}, len(existingVariants))
		for _, variant := range existingVariants {
			allowedVariants[variant.Name] = struct{}{}
		}
	}
	if whitelistProvided {
		if err = validateUpdateWhitelist(req.Whitelist, allowedVariants); err != nil {
			return nil, err
		}
	}

	// Phase 2: writes.
	experiment, err := svc.UpdateExperiment(experimentID, updateReq)
	if err != nil {
		return nil, err
	}
	if !variantsProvided && !whitelistProvided {
		return experiment, nil
	}

	variantsByName := make(map[string]*abtestmodel.Variant)
	if variantsProvided {
		if err = svc.ResetExperimentRelations(experimentID); err != nil {
			return nil, err
		}
		for _, variantReq := range req.Variants {
			variantName := strings.TrimSpace(variantReq.Name)
			variant, createErr := svc.CreateVariant(experimentID, &abtestmodel.CreateVariantReq{
				Name:        variantName,
				Description: variantReq.Description,
				Weight:      variantReq.Weight,
				Params:      variantReq.Params,
			})
			if createErr != nil {
				return nil, createErr
			}
			// Key by the validated request name so the whitelist lookup below
			// matches what validateUpdateWhitelist accepted.
			variantsByName[variantName] = variant
		}
	} else {
		for _, variant := range existingVariants {
			variantsByName[variant.Name] = variant
		}
		if err = svc.ClearWhitelist(experimentID); err != nil {
			return nil, err
		}
	}

	if whitelistProvided {
		for _, whitelistReq := range req.Whitelist {
			// Presence was guaranteed by validateUpdateWhitelist.
			variant := variantsByName[strings.TrimSpace(whitelistReq.VariantName)]
			if _, addErr := svc.AddWhitelist(experimentID, &abtestmodel.CreateWhitelistReq{
				UserID:    strings.TrimSpace(whitelistReq.UserID),
				VariantID: variant.ID,
			}); addErr != nil {
				return nil, addErr
			}
		}
	}
	return svc.GetExperiment(experimentID)
}

// buildExperimentUpdateReq converts the scalar part of req into the service
// request. It parses the optional times and rejects a start after the end.
func buildExperimentUpdateReq(req *experimentUpdateRequest) (*abtestmodel.UpdateExperimentReq, error) {
	updateReq := &abtestmodel.UpdateExperimentReq{
		Name:        req.Name,
		Description: req.Description,
		Status:      req.Status,
	}
	if req.StartTime.Set {
		if req.StartTime.Value == nil {
			updateReq.ClearStartTime = true
		} else {
			parsed, err := parseExperimentTime(*req.StartTime.Value)
			if err != nil {
				return nil, fmt.Errorf("invalid start_time: %w", err)
			}
			updateReq.StartTime = parsed
		}
	}
	if req.EndTime.Set {
		if req.EndTime.Value == nil {
			updateReq.ClearEndTime = true
		} else {
			parsed, err := parseExperimentTime(*req.EndTime.Value)
			if err != nil {
				return nil, fmt.Errorf("invalid end_time: %w", err)
			}
			updateReq.EndTime = parsed
		}
	}
	if updateReq.StartTime != nil && updateReq.EndTime != nil && updateReq.StartTime.After(*updateReq.EndTime) {
		return nil, fmt.Errorf("start_time must be earlier than end_time")
	}
	return updateReq, nil
}

// validateUpdateVariantNames checks that every trimmed variant name is
// non-blank and unique. It returns the set of names.
func validateUpdateVariantNames(variants []experimentCreateVariantRequest) (map[string]struct{}, error) {
	names := make(map[string]struct{}, len(variants))
	for _, variantReq := range variants {
		variantName := strings.TrimSpace(variantReq.Name)
		if variantName == "" {
			return nil, fmt.Errorf("variant name is required")
		}
		if _, dup := names[variantName]; dup {
			return nil, fmt.Errorf("duplicate variant name: %s", variantName)
		}
		names[variantName] = struct{}{}
	}
	return names, nil
}

// validateUpdateWhitelist checks that every entry has a non-blank user_id and
// references a variant name contained in allowed.
func validateUpdateWhitelist(entries []experimentCreateWhitelistRequest, allowed map[string]struct{}) error {
	for _, whitelistReq := range entries {
		if strings.TrimSpace(whitelistReq.UserID) == "" {
			return fmt.Errorf("whitelist user_id is required")
		}
		variantName := strings.TrimSpace(whitelistReq.VariantName)
		if _, ok := allowed[variantName]; !ok {
			return fmt.Errorf("whitelist variant_name not found: %s", variantName)
		}
	}
	return nil
}
// ExperimentDelete handles deletion of experiment :id.
//
// An unparsable id is reported through failed("invalid id"). A service error
// is answered directly with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrExperimentNotFound (matched with ==), 400 otherwise.
// Successful deletions are written to the admin log.
func ExperimentDelete(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	if deleteErr := svc.DeleteExperiment(experimentID); deleteErr != nil {
		httpStatus := http.StatusBadRequest
		switch deleteErr {
		case abtestservice.ErrExperimentNotFound:
			httpStatus = http.StatusNotFound
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": deleteErr.Error()})
		return
	}

	// Admin log action text means "delete AB experiment".
	util.AddAdminLog(c, "删除AB实验", gin.H{"id": experimentID})
	success(c, nil)
}
// ExperimentVariantList returns the variants of experiment :id as listed by the
// service. Every error, including an unparsable id, is reported through failed.
func ExperimentVariantList(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	experimentID, err := strconv.ParseInt(c.Param("id"), 10, 64)
	if err != nil {
		failed(c, "invalid id")
		return
	}
	variants, err := svc.ListVariants(experimentID)
	if err != nil {
		failed(c, err.Error())
		return
	}
	success(c, variants)
}
// ExperimentVariantCreate adds one variant to experiment :id. The body is bound
// directly into abtestmodel.CreateVariantReq.
//
// An unparsable id or body is reported through failed. A service error is
// answered with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrExperimentNotFound (matched with ==), 400 otherwise. On
// success the request is written to the admin log and the variant is returned.
func ExperimentVariantCreate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	var req abtestmodel.CreateVariantReq
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	created, createErr := svc.CreateVariant(experimentID, &req)
	if createErr != nil {
		httpStatus := http.StatusNotFound
		if createErr != abtestservice.ErrExperimentNotFound {
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": createErr.Error()})
		return
	}

	// Admin log action text means "create AB experiment variant".
	util.AddAdminLog(c, "创建AB实验变体", gin.H{"experiment_id": experimentID, "payload": req})
	success(c, created)
}
// ExperimentVariantUpdate updates variant :variantId from an
// abtestmodel.UpdateVariantReq body.
//
// An unparsable variantId or body is reported through failed. A service error
// is answered with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrVariantNotFound (matched with ==), 400 otherwise. On success
// the request is written to the admin log and the updated variant is returned.
func ExperimentVariantUpdate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	variantID, parseErr := strconv.ParseInt(c.Param("variantId"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid variantId")
		return
	}

	var req abtestmodel.UpdateVariantReq
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	updated, updateErr := svc.UpdateVariant(variantID, &req)
	if updateErr != nil {
		httpStatus := http.StatusBadRequest
		switch updateErr {
		case abtestservice.ErrVariantNotFound:
			httpStatus = http.StatusNotFound
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": updateErr.Error()})
		return
	}

	// Admin log action text means "update AB experiment variant".
	util.AddAdminLog(c, "更新AB实验变体", gin.H{"variant_id": variantID, "payload": req})
	success(c, updated)
}
// ExperimentVariantDelete deletes variant :variantId.
//
// An unparsable variantId is reported through failed. A service error is
// answered with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrVariantNotFound (matched with ==), 400 otherwise.
// Successful deletions are written to the admin log.
func ExperimentVariantDelete(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	variantID, parseErr := strconv.ParseInt(c.Param("variantId"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid variantId")
		return
	}

	if deleteErr := svc.DeleteVariant(variantID); deleteErr != nil {
		httpStatus := http.StatusNotFound
		if deleteErr != abtestservice.ErrVariantNotFound {
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": deleteErr.Error()})
		return
	}

	// Admin log action text means "delete AB experiment variant".
	util.AddAdminLog(c, "删除AB实验变体", gin.H{"variant_id": variantID})
	success(c, nil)
}
// ExperimentWhitelistList returns the whitelist entries of experiment :id as
// listed by the service. Every error, including an unparsable id, is reported
// through failed.
func ExperimentWhitelistList(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	experimentID, err := strconv.ParseInt(c.Param("id"), 10, 64)
	if err != nil {
		failed(c, "invalid id")
		return
	}
	entries, err := svc.ListWhitelist(experimentID)
	if err != nil {
		failed(c, err.Error())
		return
	}
	success(c, entries)
}
// ExperimentWhitelistCreate adds a single whitelist entry to experiment :id.
// The body is bound directly into abtestmodel.CreateWhitelistReq.
//
// An unparsable id or body is reported through failed. A service error is
// answered with {"code": 1, "message": ...}: HTTP 404 when it is
// abtestservice.ErrExperimentNotFound or ErrVariantNotFound (matched with ==),
// 400 otherwise. On success the request is written to the admin log and the
// created entry is returned.
func ExperimentWhitelistCreate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	var req abtestmodel.CreateWhitelistReq
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	entry, addErr := svc.AddWhitelist(experimentID, &req)
	if addErr != nil {
		httpStatus := http.StatusBadRequest
		switch addErr {
		case abtestservice.ErrExperimentNotFound, abtestservice.ErrVariantNotFound:
			httpStatus = http.StatusNotFound
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": addErr.Error()})
		return
	}

	// Admin log action text means "add AB whitelist entry".
	util.AddAdminLog(c, "新增AB白名单", gin.H{"experiment_id": experimentID, "payload": req})
	success(c, entry)
}
// ExperimentWhitelistBatchCreate adds several whitelist entries to experiment
// :id in one call, using an abtestmodel.BatchCreateWhitelistReq body.
//
// An unparsable id or body is reported through failed. A service error is
// answered with {"code": 1, "message": ...}: HTTP 404 when it is
// abtestservice.ErrExperimentNotFound or ErrVariantNotFound (matched with ==),
// 400 otherwise. On success the request is written to the admin log and the
// created entries are returned.
func ExperimentWhitelistBatchCreate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	var req abtestmodel.BatchCreateWhitelistReq
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	entries, batchErr := svc.BatchAddWhitelist(experimentID, &req)
	if batchErr != nil {
		notFound := batchErr == abtestservice.ErrExperimentNotFound || batchErr == abtestservice.ErrVariantNotFound
		httpStatus := http.StatusBadRequest
		if notFound {
			httpStatus = http.StatusNotFound
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": batchErr.Error()})
		return
	}

	// Admin log action text means "batch add AB whitelist entries".
	util.AddAdminLog(c, "批量新增AB白名单", gin.H{"experiment_id": experimentID, "payload": req})
	success(c, entries)
}
// ExperimentWhitelistDelete removes user :userId from the whitelist of
// experiment :id.
//
// The userId path value is used as-is (not trimmed). Every error, including
// service errors, is reported through failed; no 404 mapping is done here.
// Successful removals are written to the admin log.
func ExperimentWhitelistDelete(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	experimentID, err := strconv.ParseInt(c.Param("id"), 10, 64)
	if err != nil {
		failed(c, "invalid id")
		return
	}

	userID := c.Param("userId")
	if len(userID) == 0 {
		failed(c, "invalid userId")
		return
	}

	if err = svc.RemoveWhitelist(experimentID, userID); err != nil {
		failed(c, err.Error())
		return
	}
	// Admin log action text means "delete AB whitelist entry".
	util.AddAdminLog(c, "删除AB白名单", gin.H{"experiment_id": experimentID, "user_id": userID})
	success(c, nil)
}
// ExperimentResult returns the result data the service computes for
// experiment :id.
//
// An unparsable id is reported through failed. A service error is answered
// with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrExperimentNotFound (matched with ==), 400 otherwise.
func ExperimentResult(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}
	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	result, resultErr := svc.GetExperimentResult(experimentID)
	if resultErr == nil {
		success(c, result)
		return
	}
	httpStatus := http.StatusBadRequest
	if resultErr == abtestservice.ErrExperimentNotFound {
		httpStatus = http.StatusNotFound
	}
	c.JSON(httpStatus, gin.H{"code": 1, "message": resultErr.Error()})
}
// UserExperimentGroups returns the experiment groups the service reports for
// user :userId. The path value is trimmed, and a blank result is rejected with
// "invalid userId". All other errors are reported through failed.
func UserExperimentGroups(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	userID := strings.TrimSpace(c.Param("userId"))
	if len(userID) == 0 {
		failed(c, "invalid userId")
		return
	}

	groups, err := svc.GetUserGroups(userID)
	if err != nil {
		failed(c, err.Error())
		return
	}
	success(c, groups)
}