688 lines
17 KiB
Go
688 lines
17 KiB
Go
package controller
|
|
|
|
import (
|
|
"backend/util"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
abtestdao "backend/sdk/abtest/dao"
|
|
abtestmodel "backend/sdk/abtest/model"
|
|
abtestservice "backend/sdk/abtest/service"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
var (
|
|
experimentServiceOnce sync.Once
|
|
experimentService *abtestservice.Service
|
|
experimentServiceErr error
|
|
)
|
|
|
|
type experimentCreateVariantRequest struct {
|
|
Name string `json:"name" binding:"required"`
|
|
Description string `json:"description"`
|
|
Weight int `json:"weight"`
|
|
Params json.RawMessage `json:"params"`
|
|
}
|
|
|
|
type experimentCreateWhitelistRequest struct {
|
|
UserID string `json:"user_id" binding:"required"`
|
|
VariantName string `json:"variant_name" binding:"required"`
|
|
}
|
|
|
|
type experimentCreateRequest struct {
|
|
Name string `json:"name" binding:"required"`
|
|
Description string `json:"description"`
|
|
StartTime string `json:"start_time"`
|
|
EndTime string `json:"end_time"`
|
|
Status *int `json:"status"`
|
|
Variants []experimentCreateVariantRequest `json:"variants"`
|
|
Whitelist []experimentCreateWhitelistRequest `json:"whitelist"`
|
|
}
|
|
|
|
type optionalStringField struct {
|
|
Set bool
|
|
Value *string
|
|
}
|
|
|
|
func (f *optionalStringField) UnmarshalJSON(data []byte) error {
|
|
f.Set = true
|
|
if string(data) == "null" {
|
|
f.Value = nil
|
|
return nil
|
|
}
|
|
var value string
|
|
if err := json.Unmarshal(data, &value); err != nil {
|
|
return err
|
|
}
|
|
f.Value = &value
|
|
return nil
|
|
}
|
|
|
|
type experimentUpdateRequest struct {
|
|
Name *string `json:"name"`
|
|
Description *string `json:"description"`
|
|
StartTime optionalStringField `json:"start_time"`
|
|
EndTime optionalStringField `json:"end_time"`
|
|
Status *int `json:"status"`
|
|
Variants []experimentCreateVariantRequest `json:"variants"`
|
|
Whitelist []experimentCreateWhitelistRequest `json:"whitelist"`
|
|
}
|
|
|
|
// experimentServiceMu serialises lazy construction of experimentService.
var experimentServiceMu sync.Mutex

// getExperimentService returns the shared A/B-test service. On first use it
// connects to MySQL via util.ConnectMysql("log", "abtest") and wraps the
// connection in the abtest DAO and service.
//
// Only a successfully built service is cached. A failed connection attempt
// returns its error and the next call tries again. With sync.Once the first
// error was cached for the life of the process, so a transient database
// outage at first use disabled every experiment endpoint until restart.
// Holding the mutex across the connect also keeps concurrent first callers
// from opening several connections at once.
func getExperimentService() (*abtestservice.Service, error) {
	experimentServiceMu.Lock()
	defer experimentServiceMu.Unlock()

	if experimentService != nil {
		return experimentService, nil
	}

	db, err := util.ConnectMysql("log", "abtest")
	if err != nil {
		return nil, err
	}
	experimentService = abtestservice.New(abtestdao.New(db))
	return experimentService, nil
}
|
|
|
|
// ExperimentList responds with one page of experiments.
//
// Query parameters:
//   - page: 1-based page number, default 1.
//   - page_size: number of items per page, default 20.
//   - status: optional integer filter; a non-integer value is rejected with
//     "invalid status".
//
// A page or page_size that is missing, not an integer, or below 1 falls back
// to its default. Such values are never forwarded to the service as 0 or a
// negative number; for example, "page=abc" means page 1. The response echoes
// the effective page and page_size together with the list and the total count.
func ExperimentList(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	page, convErr := strconv.Atoi(c.DefaultQuery("page", "1"))
	if convErr != nil || page < 1 {
		page = 1
	}
	pageSize, convErr := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	if convErr != nil || pageSize < 1 {
		pageSize = 20
	}

	// status is optional; nil means "no filter".
	var statusPtr *int
	if statusRaw := c.Query("status"); statusRaw != "" {
		status, statusErr := strconv.Atoi(statusRaw)
		if statusErr != nil {
			failed(c, "invalid status")
			return
		}
		statusPtr = &status
	}

	list, total, err := svc.ListExperiments(statusPtr, page, pageSize)
	if err != nil {
		failed(c, err.Error())
		return
	}
	success(c, gin.H{"list": list, "total": total, "page": page, "page_size": pageSize})
}
|
|
|
|
func ExperimentCreate(c *gin.Context) {
|
|
svc, err := getExperimentService()
|
|
if err != nil {
|
|
failed(c, err.Error())
|
|
return
|
|
}
|
|
|
|
var req experimentCreateRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
failed(c, err.Error())
|
|
return
|
|
}
|
|
|
|
result, err := createExperimentAggregate(svc, &req)
|
|
if err != nil {
|
|
failed(c, err.Error())
|
|
return
|
|
}
|
|
util.AddAdminLog(c, "创建AB实验", req)
|
|
success(c, result)
|
|
}
|
|
|
|
// createExperimentAggregate creates an experiment together with the optional
// variants and whitelist entries described by req.
//
// The whole request is validated before anything is written, so malformed
// input never creates rows:
//   - the experiment name must be non-empty after trimming;
//   - the time window must parse and be ordered;
//   - each variant name must be non-empty after trimming and unique;
//   - each whitelist entry needs a non-empty user_id and a variant_name that
//     matches one of the request's variants.
//
// Once the experiment row exists, any later failure (or a panic) triggers a
// best-effort rollback via svc.DeleteExperiment. If that rollback also fails,
// its error is appended to the returned error instead of being dropped.
// NOTE(review): whether DeleteExperiment also removes variants/whitelist rows
// created before the failure depends on the service; confirm it cascades.
//
// On success the returned map has the keys "experiment", "variants" and
// "whitelist", holding the created records.
func createExperimentAggregate(svc *abtestservice.Service, req *experimentCreateRequest) (result gin.H, err error) {
	name := strings.TrimSpace(req.Name)
	if name == "" {
		return nil, fmt.Errorf("name is required")
	}

	startTime, err := parseExperimentTime(req.StartTime)
	if err != nil {
		return nil, fmt.Errorf("invalid start_time: %w", err)
	}
	endTime, err := parseExperimentTime(req.EndTime)
	if err != nil {
		return nil, fmt.Errorf("invalid end_time: %w", err)
	}
	if startTime != nil && endTime != nil && startTime.After(*endTime) {
		return nil, fmt.Errorf("start_time must be earlier than end_time")
	}

	// Validate variants and whitelist before any write. variantNames holds
	// the trimmed names, index-aligned with req.Variants.
	variantNames := make([]string, len(req.Variants))
	declared := make(map[string]struct{}, len(req.Variants))
	for i, variantReq := range req.Variants {
		variantName := strings.TrimSpace(variantReq.Name)
		if variantName == "" {
			return nil, fmt.Errorf("variant name is required")
		}
		if _, exists := declared[variantName]; exists {
			return nil, fmt.Errorf("duplicate variant name: %s", variantName)
		}
		declared[variantName] = struct{}{}
		variantNames[i] = variantName
	}
	for _, whitelistReq := range req.Whitelist {
		if strings.TrimSpace(whitelistReq.UserID) == "" {
			return nil, fmt.Errorf("whitelist user_id is required")
		}
		variantName := strings.TrimSpace(whitelistReq.VariantName)
		if _, exists := declared[variantName]; !exists {
			return nil, fmt.Errorf("whitelist variant_name not found: %s", variantName)
		}
	}

	experiment, err := svc.CreateExperiment(&abtestmodel.CreateExperimentReq{
		Name:        name,
		Description: req.Description,
	})
	if err != nil {
		return nil, err
	}

	// Capture the ID now. experiment is reassigned below, and the rollback
	// must never dereference a result from a failed UpdateExperiment, which
	// may be nil.
	experimentID := experiment.ID
	committed := false
	defer func() {
		if committed {
			return
		}
		if rollbackErr := svc.DeleteExperiment(experimentID); rollbackErr != nil && err != nil {
			err = fmt.Errorf("%w (rollback of experiment %v failed: %v)", err, experimentID, rollbackErr)
		}
	}()

	// CreateExperimentReq carries only name/description, so the remaining
	// scalar fields are applied with a follow-up update.
	if req.Status != nil || startTime != nil || endTime != nil {
		updated, updateErr := svc.UpdateExperiment(experimentID, &abtestmodel.UpdateExperimentReq{
			Status:    req.Status,
			StartTime: startTime,
			EndTime:   endTime,
		})
		if updateErr != nil {
			return nil, updateErr
		}
		experiment = updated
	}

	variants := make([]*abtestmodel.Variant, 0, len(req.Variants))
	variantsByName := make(map[string]*abtestmodel.Variant, len(req.Variants))
	for i, variantReq := range req.Variants {
		variant, createErr := svc.CreateVariant(experimentID, &abtestmodel.CreateVariantReq{
			Name:        variantNames[i],
			Description: variantReq.Description,
			Weight:      variantReq.Weight,
			Params:      variantReq.Params,
		})
		if createErr != nil {
			return nil, createErr
		}
		variants = append(variants, variant)
		variantsByName[variantNames[i]] = variant
	}

	whitelist := make([]*abtestmodel.Whitelist, 0, len(req.Whitelist))
	for _, whitelistReq := range req.Whitelist {
		// The name was checked against the declared variants above.
		variant := variantsByName[strings.TrimSpace(whitelistReq.VariantName)]
		entry, addErr := svc.AddWhitelist(experimentID, &abtestmodel.CreateWhitelistReq{
			UserID:    strings.TrimSpace(whitelistReq.UserID),
			VariantID: variant.ID,
		})
		if addErr != nil {
			return nil, addErr
		}
		whitelist = append(whitelist, entry)
	}

	committed = true
	return gin.H{
		"experiment": experiment,
		"variants":   variants,
		"whitelist":  whitelist,
	}, nil
}
|
|
|
|
// parseExperimentTime parses an optional experiment timestamp.
//
// Surrounding whitespace is ignored. An empty or all-blank input yields
// (nil, nil), meaning "not set". Otherwise the first matching layout wins:
//   - "2006-01-02 15:04:05", interpreted in time.Local
//   - time.RFC3339; the explicit offset in the input determines the instant
//   - "2006-01-02T15:04:05", interpreted in time.Local
//
// Input matching none of them returns the error "unsupported time format".
func parseExperimentTime(raw string) (*time.Time, error) {
	value := strings.TrimSpace(raw)
	if len(value) == 0 {
		return nil, nil
	}

	for _, layout := range [...]string{"2006-01-02 15:04:05", time.RFC3339, "2006-01-02T15:04:05"} {
		if t, parseErr := time.ParseInLocation(layout, value, time.Local); parseErr == nil {
			return &t, nil
		}
	}
	return nil, fmt.Errorf("unsupported time format")
}
|
|
|
|
// ExperimentUpdate applies a partial update (experimentUpdateRequest) to the
// experiment named by the ":id" path parameter.
//
// A non-integer id is rejected with "invalid id", and an unbindable body with
// the bind error, both through failed. Errors from updateExperimentAggregate
// are answered with {"code": 1, "message": ...}: HTTP 404 when the error is
// abtestservice.ErrExperimentNotFound, HTTP 400 otherwise. Successful
// updates are written to the admin log.
func ExperimentUpdate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	var payload experimentUpdateRequest
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	updated, updateErr := updateExperimentAggregate(svc, experimentID, &payload)
	if updateErr != nil {
		httpStatus := http.StatusNotFound
		if updateErr != abtestservice.ErrExperimentNotFound {
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": updateErr.Error()})
		return
	}

	util.AddAdminLog(c, "更新AB实验", gin.H{"id": experimentID, "payload": payload})
	success(c, updated)
}
|
|
|
|
// updateExperimentAggregate applies a partial update to experiment
// experimentID. When the request carries them, it also replaces the
// experiment's variants and/or whitelist.
//
// Request semantics:
//   - Name, Description, Status: nil leaves the stored value unchanged.
//   - StartTime / EndTime: handled by buildExperimentUpdateReq.
//   - Variants non-nil: ResetExperimentRelations is called and exactly the
//     given variants are recreated. Whitelist entries, if any, are resolved
//     against these new variants.
//   - Variants nil, Whitelist non-nil: the whitelist is cleared with
//     ClearWhitelist and rebuilt against the experiment's current variants.
//   - Both nil: only the scalar fields are updated.
//
// Every validation step runs before the first write: times, variant names,
// duplicate names, whitelist user IDs and variant references. The ordering
// matters because ResetExperimentRelations and ClearWhitelist are
// destructive. A bad entry found after them would leave the experiment with
// its variants or whitelist wiped and only partly rebuilt. A storage error
// during the rebuild can still leave partial state, because the service calls
// used here expose no transaction.
//
// Errors from GetExperiment are returned unwrapped, so callers can compare
// against abtestservice.ErrExperimentNotFound. When relations were rebuilt,
// the experiment is re-read before returning; otherwise the UpdateExperiment
// result is returned.
func updateExperimentAggregate(svc *abtestservice.Service, experimentID int64, req *experimentUpdateRequest) (*abtestmodel.Experiment, error) {
	if _, err := svc.GetExperiment(experimentID); err != nil {
		return nil, err
	}

	// ---- Validation phase: nothing below writes until UpdateExperiment. ----

	updateReq, err := buildExperimentUpdateReq(req)
	if err != nil {
		return nil, err
	}

	variantsProvided := req.Variants != nil
	whitelistProvided := req.Whitelist != nil

	variantNames, err := normalizeExperimentVariantNames(req.Variants)
	if err != nil {
		return nil, err
	}

	// Names a whitelist entry may reference: the replacement variants when
	// Variants is supplied, otherwise the experiment's current variants.
	referable := make(map[string]bool, len(variantNames))
	var currentByName map[string]*abtestmodel.Variant
	switch {
	case variantsProvided:
		for _, name := range variantNames {
			referable[name] = true
		}
	case whitelistProvided:
		current, listErr := svc.ListVariants(experimentID)
		if listErr != nil {
			return nil, listErr
		}
		currentByName = make(map[string]*abtestmodel.Variant, len(current))
		for _, variant := range current {
			currentByName[variant.Name] = variant
			referable[variant.Name] = true
		}
	}
	for _, whitelistReq := range req.Whitelist {
		if strings.TrimSpace(whitelistReq.UserID) == "" {
			return nil, fmt.Errorf("whitelist user_id is required")
		}
		if variantName := strings.TrimSpace(whitelistReq.VariantName); !referable[variantName] {
			return nil, fmt.Errorf("whitelist variant_name not found: %s", variantName)
		}
	}

	// ---- Write phase. ----

	experiment, err := svc.UpdateExperiment(experimentID, updateReq)
	if err != nil {
		return nil, err
	}
	if !variantsProvided && !whitelistProvided {
		return experiment, nil
	}

	variantsByName := currentByName
	if variantsProvided {
		if err := svc.ResetExperimentRelations(experimentID); err != nil {
			return nil, err
		}
		variantsByName = make(map[string]*abtestmodel.Variant, len(variantNames))
		for i, variantReq := range req.Variants {
			variant, createErr := svc.CreateVariant(experimentID, &abtestmodel.CreateVariantReq{
				Name:        variantNames[i],
				Description: variantReq.Description,
				Weight:      variantReq.Weight,
				Params:      variantReq.Params,
			})
			if createErr != nil {
				return nil, createErr
			}
			variantsByName[variantNames[i]] = variant
		}
	} else {
		// Whitelist-only replacement: the existing variants are kept.
		if err := svc.ClearWhitelist(experimentID); err != nil {
			return nil, err
		}
	}

	for _, whitelistReq := range req.Whitelist {
		// Presence in variantsByName was guaranteed by the validation phase.
		variant := variantsByName[strings.TrimSpace(whitelistReq.VariantName)]
		if _, err := svc.AddWhitelist(experimentID, &abtestmodel.CreateWhitelistReq{
			UserID:    strings.TrimSpace(whitelistReq.UserID),
			VariantID: variant.ID,
		}); err != nil {
			return nil, err
		}
	}

	return svc.GetExperiment(experimentID)
}

// buildExperimentUpdateReq converts the scalar part of an update request into
// the service's UpdateExperimentReq.
//
// For each time field:
//   - absent: leaves the stored value unchanged;
//   - explicit null: sets the matching Clear* flag;
//   - string: parsed with parseExperimentTime. A blank string parses to nil,
//     so it also leaves the stored value unchanged rather than clearing it.
//
// When both times are supplied, a start later than the end is rejected.
func buildExperimentUpdateReq(req *experimentUpdateRequest) (*abtestmodel.UpdateExperimentReq, error) {
	updateReq := &abtestmodel.UpdateExperimentReq{
		Name:        req.Name,
		Description: req.Description,
		Status:      req.Status,
	}

	if req.StartTime.Set {
		if req.StartTime.Value == nil {
			updateReq.ClearStartTime = true
		} else {
			parsed, err := parseExperimentTime(*req.StartTime.Value)
			if err != nil {
				return nil, fmt.Errorf("invalid start_time: %w", err)
			}
			updateReq.StartTime = parsed
		}
	}
	if req.EndTime.Set {
		if req.EndTime.Value == nil {
			updateReq.ClearEndTime = true
		} else {
			parsed, err := parseExperimentTime(*req.EndTime.Value)
			if err != nil {
				return nil, fmt.Errorf("invalid end_time: %w", err)
			}
			updateReq.EndTime = parsed
		}
	}

	if updateReq.StartTime != nil && updateReq.EndTime != nil && updateReq.StartTime.After(*updateReq.EndTime) {
		return nil, fmt.Errorf("start_time must be earlier than end_time")
	}
	return updateReq, nil
}

// normalizeExperimentVariantNames trims each variant name and rejects empty or
// duplicate names. The result is index-aligned with reqs.
func normalizeExperimentVariantNames(reqs []experimentCreateVariantRequest) ([]string, error) {
	names := make([]string, len(reqs))
	seen := make(map[string]struct{}, len(reqs))
	for i, variantReq := range reqs {
		name := strings.TrimSpace(variantReq.Name)
		if name == "" {
			return nil, fmt.Errorf("variant name is required")
		}
		if _, dup := seen[name]; dup {
			return nil, fmt.Errorf("duplicate variant name: %s", name)
		}
		seen[name] = struct{}{}
		names[i] = name
	}
	return names, nil
}
|
|
|
|
// ExperimentDelete deletes the experiment named by the ":id" path parameter.
//
// A non-integer id is rejected with "invalid id" through failed. A deletion
// error is answered with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrExperimentNotFound, HTTP 400 otherwise. Successful
// deletions are written to the admin log.
func ExperimentDelete(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	if deleteErr := svc.DeleteExperiment(experimentID); deleteErr != nil {
		httpStatus := http.StatusNotFound
		if deleteErr != abtestservice.ErrExperimentNotFound {
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": deleteErr.Error()})
		return
	}

	util.AddAdminLog(c, "删除AB实验", gin.H{"id": experimentID})
	success(c, nil)
}
|
|
|
|
// ExperimentVariantList responds with the variants of the experiment named by
// the ":id" path parameter. A non-integer id is rejected with "invalid id";
// service errors are reported through failed.
func ExperimentVariantList(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	variants, listErr := svc.ListVariants(experimentID)
	if listErr != nil {
		failed(c, listErr.Error())
		return
	}
	success(c, variants)
}
|
|
|
|
// ExperimentVariantCreate adds one variant, described by an
// abtestmodel.CreateVariantReq JSON body, to the experiment named by ":id".
//
// A service error is answered with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrExperimentNotFound, HTTP 400 otherwise. Successful
// creations are written to the admin log together with the request payload.
func ExperimentVariantCreate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	var body abtestmodel.CreateVariantReq
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	created, createErr := svc.CreateVariant(experimentID, &body)
	if createErr != nil {
		httpStatus := http.StatusNotFound
		if createErr != abtestservice.ErrExperimentNotFound {
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": createErr.Error()})
		return
	}

	util.AddAdminLog(c, "创建AB实验变体", gin.H{"experiment_id": experimentID, "payload": body})
	success(c, created)
}
|
|
|
|
// ExperimentVariantUpdate updates the variant named by the ":variantId" path
// parameter from an abtestmodel.UpdateVariantReq JSON body.
//
// A service error is answered with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrVariantNotFound, HTTP 400 otherwise. Successful updates
// are written to the admin log.
//
// NOTE(review): only variantId is read here. If the route also carries an
// experiment ":id", it is not checked against the variant's owner; confirm
// that UpdateVariant (or the router) enforces ownership.
func ExperimentVariantUpdate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	variantID, parseErr := strconv.ParseInt(c.Param("variantId"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid variantId")
		return
	}

	var body abtestmodel.UpdateVariantReq
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	updated, updateErr := svc.UpdateVariant(variantID, &body)
	if updateErr != nil {
		httpStatus := http.StatusNotFound
		if updateErr != abtestservice.ErrVariantNotFound {
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": updateErr.Error()})
		return
	}

	util.AddAdminLog(c, "更新AB实验变体", gin.H{"variant_id": variantID, "payload": body})
	success(c, updated)
}
|
|
|
|
// ExperimentVariantDelete deletes the variant named by the ":variantId" path
// parameter.
//
// A service error is answered with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrVariantNotFound, HTTP 400 otherwise. Successful deletions
// are written to the admin log.
//
// NOTE(review): an experiment ":id" in the route, if present, is not checked
// here; confirm that DeleteVariant or the router enforces ownership.
func ExperimentVariantDelete(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	variantID, parseErr := strconv.ParseInt(c.Param("variantId"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid variantId")
		return
	}

	if deleteErr := svc.DeleteVariant(variantID); deleteErr != nil {
		httpStatus := http.StatusNotFound
		if deleteErr != abtestservice.ErrVariantNotFound {
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": deleteErr.Error()})
		return
	}

	util.AddAdminLog(c, "删除AB实验变体", gin.H{"variant_id": variantID})
	success(c, nil)
}
|
|
|
|
// ExperimentWhitelistList responds with the whitelist entries of the
// experiment named by the ":id" path parameter. A non-integer id is rejected
// with "invalid id"; service errors are reported through failed.
func ExperimentWhitelistList(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	entries, listErr := svc.ListWhitelist(experimentID)
	if listErr != nil {
		failed(c, listErr.Error())
		return
	}
	success(c, entries)
}
|
|
|
|
// ExperimentWhitelistCreate adds one whitelist entry, described by an
// abtestmodel.CreateWhitelistReq JSON body, to the experiment named by ":id".
//
// A service error is answered with {"code": 1, "message": ...}: HTTP 404 when
// the experiment or the referenced variant is not found (the service's
// sentinel errors), HTTP 400 otherwise. Successful additions are written to
// the admin log.
func ExperimentWhitelistCreate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	var body abtestmodel.CreateWhitelistReq
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	entry, addErr := svc.AddWhitelist(experimentID, &body)
	if addErr != nil {
		var httpStatus int
		switch addErr {
		case abtestservice.ErrExperimentNotFound, abtestservice.ErrVariantNotFound:
			httpStatus = http.StatusNotFound
		default:
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": addErr.Error()})
		return
	}

	util.AddAdminLog(c, "新增AB白名单", gin.H{"experiment_id": experimentID, "payload": body})
	success(c, entry)
}
|
|
|
|
// ExperimentWhitelistBatchCreate adds several whitelist entries, described by
// an abtestmodel.BatchCreateWhitelistReq JSON body, to the experiment named
// by ":id".
//
// A service error is answered with {"code": 1, "message": ...}: HTTP 404 when
// the experiment or a referenced variant is not found, HTTP 400 otherwise.
// Successful batches are written to the admin log.
func ExperimentWhitelistBatchCreate(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	var body abtestmodel.BatchCreateWhitelistReq
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		failed(c, bindErr.Error())
		return
	}

	entries, addErr := svc.BatchAddWhitelist(experimentID, &body)
	if addErr != nil {
		var httpStatus int
		switch addErr {
		case abtestservice.ErrExperimentNotFound, abtestservice.ErrVariantNotFound:
			httpStatus = http.StatusNotFound
		default:
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": addErr.Error()})
		return
	}

	util.AddAdminLog(c, "批量新增AB白名单", gin.H{"experiment_id": experimentID, "payload": body})
	success(c, entries)
}
|
|
|
|
// ExperimentWhitelistDelete removes one user's whitelist entry from the
// experiment named by ":id".
//
// The ":userId" path value is used verbatim (not trimmed); only an empty
// value is rejected, with "invalid userId". Every service error is reported
// through failed. Successful removals are written to the admin log.
func ExperimentWhitelistDelete(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	userID := c.Param("userId")
	if len(userID) == 0 {
		failed(c, "invalid userId")
		return
	}

	if removeErr := svc.RemoveWhitelist(experimentID, userID); removeErr != nil {
		failed(c, removeErr.Error())
		return
	}

	util.AddAdminLog(c, "删除AB白名单", gin.H{"experiment_id": experimentID, "user_id": userID})
	success(c, nil)
}
|
|
|
|
// ExperimentResult responds with svc.GetExperimentResult for the experiment
// named by ":id".
//
// A service error is answered with {"code": 1, "message": ...}: HTTP 404 for
// abtestservice.ErrExperimentNotFound, HTTP 400 otherwise. This is a
// read-only endpoint and writes no admin log.
func ExperimentResult(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	experimentID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		failed(c, "invalid id")
		return
	}

	outcome, resultErr := svc.GetExperimentResult(experimentID)
	if resultErr != nil {
		httpStatus := http.StatusNotFound
		if resultErr != abtestservice.ErrExperimentNotFound {
			httpStatus = http.StatusBadRequest
		}
		c.JSON(httpStatus, gin.H{"code": 1, "message": resultErr.Error()})
		return
	}
	success(c, outcome)
}
|
|
|
|
// UserExperimentGroups responds with svc.GetUserGroups for the user named by
// the ":userId" path parameter.
//
// Surrounding whitespace is trimmed from the user ID, and a value that is
// empty after trimming is rejected with "invalid userId". Service errors are
// reported through failed.
func UserExperimentGroups(c *gin.Context) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return
	}

	userID := strings.TrimSpace(c.Param("userId"))
	if len(userID) == 0 {
		failed(c, "invalid userId")
		return
	}

	groups, groupsErr := svc.GetUserGroups(userID)
	if groupsErr != nil {
		failed(c, groupsErr.Error())
		return
	}
	success(c, groups)
}
|