// This file contains the gin handlers for managing A/B experiments together
// with their variants and whitelists.

package controller

import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	abtestdao "backend/sdk/abtest/dao"
	abtestmodel "backend/sdk/abtest/model"
	abtestservice "backend/sdk/abtest/service"
	"backend/util"

	"github.com/gin-gonic/gin"
)

// Lazily initialised experiment service shared by every handler in this file.
// See getExperimentService for the initialisation rules.
var (
	experimentServiceOnce sync.Once
	experimentService     *abtestservice.Service
	experimentServiceErr  error
)

// experimentCreateVariantRequest is one variant inside a create/update
// experiment payload.
type experimentCreateVariantRequest struct {
	Name        string          `json:"name" binding:"required"`
	Description string          `json:"description"`
	Weight      int             `json:"weight"`
	Params      json.RawMessage `json:"params"`
}

// experimentCreateWhitelistRequest is one whitelist entry inside a
// create/update experiment payload. The variant is referenced by name.
type experimentCreateWhitelistRequest struct {
	UserID      string `json:"user_id" binding:"required"`
	VariantName string `json:"variant_name" binding:"required"`
}

// experimentCreateRequest is the JSON body accepted by ExperimentCreate.
type experimentCreateRequest struct {
	Name        string                             `json:"name" binding:"required"`
	Description string                             `json:"description"`
	StartTime   string                             `json:"start_time"`
	EndTime     string                             `json:"end_time"`
	Status      *int                               `json:"status"`
	Variants    []experimentCreateVariantRequest   `json:"variants"`
	Whitelist   []experimentCreateWhitelistRequest `json:"whitelist"`
}

// optionalStringField distinguishes three JSON states of a string field:
// absent (Set == false), explicit null (Set == true, Value == nil) and a
// concrete string (Set == true, Value != nil).
type optionalStringField struct {
	Set   bool
	Value *string
}

// UnmarshalJSON records that the field was present in the payload. A JSON null
// leaves Value nil; anything else must decode as a JSON string.
func (f *optionalStringField) UnmarshalJSON(data []byte) error {
	f.Set = true
	if string(data) == "null" {
		f.Value = nil
		return nil
	}
	var value string
	if err := json.Unmarshal(data, &value); err != nil {
		return err
	}
	f.Value = &value
	return nil
}

// experimentUpdateRequest is the JSON body accepted by ExperimentUpdate.
// Pointer and optional fields that are absent are left untouched. A nil
// Variants or Whitelist slice means "not provided", while an empty slice
// replaces the existing relations with nothing.
type experimentUpdateRequest struct {
	Name        *string                            `json:"name"`
	Description *string                            `json:"description"`
	StartTime   optionalStringField                `json:"start_time"`
	EndTime     optionalStringField                `json:"end_time"`
	Status      *int                               `json:"status"`
	Variants    []experimentCreateVariantRequest   `json:"variants"`
	Whitelist   []experimentCreateWhitelistRequest `json:"whitelist"`
}

// getExperimentService returns the shared experiment service, connecting to
// MySQL on the first call. Initialisation runs exactly once: if the first
// connection attempt fails, that error is returned by every later call as well.
func getExperimentService() (*abtestservice.Service, error) {
	experimentServiceOnce.Do(func() {
		db, err := util.ConnectMysql("log", "abtest")
		if err != nil {
			experimentServiceErr = err
			return
		}
		experimentService = abtestservice.New(abtestdao.New(db))
	})
	return experimentService, experimentServiceErr
}

// experimentServiceOrFail returns the shared service. When it is unavailable
// the failure response is written and ok is false, so the caller only returns.
func experimentServiceOrFail(c *gin.Context) (svc *abtestservice.Service, ok bool) {
	svc, err := getExperimentService()
	if err != nil {
		failed(c, err.Error())
		return nil, false
	}
	return svc, true
}

// experimentPathID parses the named path parameter as a base-10 int64. On
// failure it writes the "invalid <param>" failure response and returns false.
func experimentPathID(c *gin.Context, param string) (int64, bool) {
	id, err := strconv.ParseInt(c.Param(param), 10, 64)
	if err != nil {
		failed(c, "invalid "+param)
		return 0, false
	}
	return id, true
}

// respondExperimentError writes the JSON error envelope for a service error.
// The HTTP status is 404 when err matches any of notFound (wrapped errors
// included) and 400 otherwise.
func respondExperimentError(c *gin.Context, err error, notFound ...error) {
	status := http.StatusBadRequest
	for _, target := range notFound {
		if errors.Is(err, target) {
			status = http.StatusNotFound
			break
		}
	}
	c.JSON(status, gin.H{"code": 1, "message": err.Error()})
}

// ExperimentList responds with one page of experiments, optionally filtered by
// the "status" query parameter. Paging comes from "page" (default 1) and
// "page_size" (default 20).
func ExperimentList(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}

	// Conversion errors are ignored on purpose to keep the existing contract:
	// an unparseable value reaches the service as 0.
	// NOTE(review): presumably the service normalises non-positive paging
	// values; confirm before tightening this.
	page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
	pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))

	var statusPtr *int
	if statusRaw := c.Query("status"); statusRaw != "" {
		status, err := strconv.Atoi(statusRaw)
		if err != nil {
			failed(c, "invalid status")
			return
		}
		statusPtr = &status
	}

	list, total, err := svc.ListExperiments(statusPtr, page, pageSize)
	if err != nil {
		failed(c, err.Error())
		return
	}
	success(c, gin.H{"list": list, "total": total, "page": page, "page_size": pageSize})
}

// ExperimentCreate creates an experiment together with its variants and
// whitelist from a single JSON payload and records an admin log entry.
func ExperimentCreate(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}

	var req experimentCreateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		failed(c, err.Error())
		return
	}

	result, err := createExperimentAggregate(svc, &req)
	if err != nil {
		failed(c, err.Error())
		return
	}

	util.AddAdminLog(c, "创建AB实验", req)
	success(c, result)
}

// validateExperimentVariants checks a variant payload without touching the
// service: every trimmed name must be non-empty and unique. It returns the set
// of trimmed names so whitelist entries can be checked against it.
func validateExperimentVariants(variants []experimentCreateVariantRequest) (map[string]struct{}, error) {
	names := make(map[string]struct{}, len(variants))
	for i := range variants {
		name := strings.TrimSpace(variants[i].Name)
		if name == "" {
			return nil, fmt.Errorf("variant name is required")
		}
		if _, dup := names[name]; dup {
			return nil, fmt.Errorf("duplicate variant name: %s", name)
		}
		names[name] = struct{}{}
	}
	return names, nil
}

// validateExperimentWhitelist checks a whitelist payload without touching the
// service: every entry needs a non-empty user id and must reference a name in
// variantNames.
func validateExperimentWhitelist(whitelist []experimentCreateWhitelistRequest, variantNames map[string]struct{}) error {
	for i := range whitelist {
		if strings.TrimSpace(whitelist[i].UserID) == "" {
			return fmt.Errorf("whitelist user_id is required")
		}
		name := strings.TrimSpace(whitelist[i].VariantName)
		if _, known := variantNames[name]; !known {
			return fmt.Errorf("whitelist variant_name not found: %s", name)
		}
	}
	return nil
}

// createExperimentVariants creates the requested variants for experimentID, in
// request order. It returns them as a slice and indexed by the name the
// service reported for each created variant.
func createExperimentVariants(svc *abtestservice.Service, experimentID int64, reqs []experimentCreateVariantRequest) ([]*abtestmodel.Variant, map[string]*abtestmodel.Variant, error) {
	variants := make([]*abtestmodel.Variant, 0, len(reqs))
	byName := make(map[string]*abtestmodel.Variant, len(reqs))
	for i := range reqs {
		name := strings.TrimSpace(reqs[i].Name)
		if name == "" {
			return nil, nil, fmt.Errorf("variant name is required")
		}
		if _, exists := byName[name]; exists {
			return nil, nil, fmt.Errorf("duplicate variant name: %s", name)
		}
		variant, err := svc.CreateVariant(experimentID, &abtestmodel.CreateVariantReq{
			Name:        name,
			Description: reqs[i].Description,
			Weight:      reqs[i].Weight,
			Params:      reqs[i].Params,
		})
		if err != nil {
			return nil, nil, err
		}
		variants = append(variants, variant)
		byName[variant.Name] = variant
	}
	return variants, byName, nil
}

// addExperimentWhitelist adds the requested whitelist entries to experimentID,
// resolving each entry's variant name through variantsByName.
func addExperimentWhitelist(svc *abtestservice.Service, experimentID int64, reqs []experimentCreateWhitelistRequest, variantsByName map[string]*abtestmodel.Variant) ([]*abtestmodel.Whitelist, error) {
	entries := make([]*abtestmodel.Whitelist, 0, len(reqs))
	for i := range reqs {
		userID := strings.TrimSpace(reqs[i].UserID)
		variantName := strings.TrimSpace(reqs[i].VariantName)
		if userID == "" {
			return nil, fmt.Errorf("whitelist user_id is required")
		}
		variant, exists := variantsByName[variantName]
		if !exists {
			return nil, fmt.Errorf("whitelist variant_name not found: %s", variantName)
		}
		entry, err := svc.AddWhitelist(experimentID, &abtestmodel.CreateWhitelistReq{
			UserID:    userID,
			VariantID: variant.ID,
		})
		if err != nil {
			return nil, err
		}
		entries = append(entries, entry)
	}
	return entries, nil
}

// createExperimentAggregate creates an experiment plus its variants and
// whitelist. The whole payload is validated before the first write. If a later
// step still fails, the freshly created experiment is deleted again on a
// best-effort basis.
//
// On success the result holds the keys "experiment", "variants" and
// "whitelist".
func createExperimentAggregate(svc *abtestservice.Service, req *experimentCreateRequest) (gin.H, error) {
	name := strings.TrimSpace(req.Name)
	if name == "" {
		return nil, fmt.Errorf("name is required")
	}

	startTime, err := parseExperimentTime(req.StartTime)
	if err != nil {
		return nil, fmt.Errorf("invalid start_time: %w", err)
	}
	endTime, err := parseExperimentTime(req.EndTime)
	if err != nil {
		return nil, fmt.Errorf("invalid end_time: %w", err)
	}
	if startTime != nil && endTime != nil && startTime.After(*endTime) {
		return nil, fmt.Errorf("start_time must be earlier than end_time")
	}

	// Reject a malformed payload before anything is written, so a bad request
	// never needs the create-then-rollback round trip.
	variantNames, err := validateExperimentVariants(req.Variants)
	if err != nil {
		return nil, err
	}
	if err := validateExperimentWhitelist(req.Whitelist, variantNames); err != nil {
		return nil, err
	}

	experiment, err := svc.CreateExperiment(&abtestmodel.CreateExperimentReq{
		Name:        name,
		Description: req.Description,
	})
	if err != nil {
		return nil, err
	}

	// Capture the ID now: the rollback must not depend on the experiment
	// variable, which is replaced further down.
	experimentID := experiment.ID
	shouldRollback := true
	defer func() {
		if shouldRollback {
			// Best-effort cleanup; the error that triggered the rollback is
			// the one reported to the caller.
			_ = svc.DeleteExperiment(experimentID)
		}
	}()

	if req.Status != nil || startTime != nil || endTime != nil {
		updated, updateErr := svc.UpdateExperiment(experimentID, &abtestmodel.UpdateExperimentReq{
			Status:    req.Status,
			StartTime: startTime,
			EndTime:   endTime,
		})
		if updateErr != nil {
			return nil, updateErr
		}
		experiment = updated
	}

	variants, variantsByName, err := createExperimentVariants(svc, experimentID, req.Variants)
	if err != nil {
		return nil, err
	}
	whitelist, err := addExperimentWhitelist(svc, experimentID, req.Whitelist, variantsByName)
	if err != nil {
		return nil, err
	}

	shouldRollback = false
	return gin.H{
		"experiment": experiment,
		"variants":   variants,
		"whitelist":  whitelist,
	}, nil
}

// parseExperimentTime parses raw using, in order, "2006-01-02 15:04:05",
// RFC 3339 and "2006-01-02T15:04:05". Values without zone information are
// interpreted in time.Local. A blank input yields (nil, nil).
func parseExperimentTime(raw string) (*time.Time, error) {
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return nil, nil
	}
	layouts := []string{
		"2006-01-02 15:04:05",
		time.RFC3339,
		"2006-01-02T15:04:05",
	}
	for _, layout := range layouts {
		parsed, err := time.ParseInLocation(layout, raw, time.Local)
		if err == nil {
			return &parsed, nil
		}
	}
	return nil, fmt.Errorf("unsupported time format")
}

// resolveOptionalExperimentTime interprets an optional time field of an update
// request. It returns (nil, false, nil) when the field is absent,
// (nil, true, nil) when it is an explicit null, and the parsed time otherwise.
// label names the field in the error message.
func resolveOptionalExperimentTime(field optionalStringField, label string) (*time.Time, bool, error) {
	if !field.Set {
		return nil, false, nil
	}
	if field.Value == nil {
		return nil, true, nil
	}
	parsed, err := parseExperimentTime(*field.Value)
	if err != nil {
		return nil, false, fmt.Errorf("invalid %s: %w", label, err)
	}
	return parsed, false, nil
}

// ExperimentUpdate updates the experiment identified by the "id" path
// parameter and, when provided, replaces its variants and/or whitelist.
func ExperimentUpdate(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	id, ok := experimentPathID(c, "id")
	if !ok {
		return
	}

	var req experimentUpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		failed(c, err.Error())
		return
	}

	experiment, err := updateExperimentAggregate(svc, id, &req)
	if err != nil {
		respondExperimentError(c, err, abtestservice.ErrExperimentNotFound)
		return
	}

	util.AddAdminLog(c, "更新AB实验", gin.H{"id": id, "payload": req})
	success(c, experiment)
}

// updateExperimentAggregate applies req to an existing experiment.
//
// Scalar fields are updated through UpdateExperiment. When req.Variants is
// non-nil, all relations of the experiment are reset and the variants are
// recreated. When only req.Whitelist is non-nil, the whitelist is cleared and
// rebuilt against the existing variants.
//
// The whole payload is validated before the first write, so a malformed
// request cannot wipe the existing variants or whitelist.
func updateExperimentAggregate(svc *abtestservice.Service, experimentID int64, req *experimentUpdateRequest) (*abtestmodel.Experiment, error) {
	if _, err := svc.GetExperiment(experimentID); err != nil {
		return nil, err
	}

	updateReq := &abtestmodel.UpdateExperimentReq{
		Name:        req.Name,
		Description: req.Description,
		Status:      req.Status,
	}

	startTime, clearStart, err := resolveOptionalExperimentTime(req.StartTime, "start_time")
	if err != nil {
		return nil, err
	}
	updateReq.StartTime = startTime
	if clearStart {
		updateReq.ClearStartTime = true
	}

	endTime, clearEnd, err := resolveOptionalExperimentTime(req.EndTime, "end_time")
	if err != nil {
		return nil, err
	}
	updateReq.EndTime = endTime
	if clearEnd {
		updateReq.ClearEndTime = true
	}

	// Only checked when both bounds are part of this request; a single new
	// bound is not compared against the stored one here.
	if updateReq.StartTime != nil && updateReq.EndTime != nil && updateReq.StartTime.After(*updateReq.EndTime) {
		return nil, fmt.Errorf("start_time must be earlier than end_time")
	}

	variantsProvided := req.Variants != nil
	whitelistProvided := req.Whitelist != nil

	// Validate the relation payload up front: the reset/clear calls below are
	// destructive and there is no rollback for them.
	var existingVariants []*abtestmodel.Variant
	switch {
	case variantsProvided:
		names, validateErr := validateExperimentVariants(req.Variants)
		if validateErr != nil {
			return nil, validateErr
		}
		if validateErr = validateExperimentWhitelist(req.Whitelist, names); validateErr != nil {
			return nil, validateErr
		}
	case whitelistProvided:
		listed, listErr := svc.ListVariants(experimentID)
		if listErr != nil {
			return nil, listErr
		}
		names := make(map[string]struct{}, len(listed))
		for _, variant := range listed {
			names[variant.Name] = struct{}{}
		}
		if validateErr := validateExperimentWhitelist(req.Whitelist, names); validateErr != nil {
			return nil, validateErr
		}
		existingVariants = listed
	}

	experiment, err := svc.UpdateExperiment(experimentID, updateReq)
	if err != nil {
		return nil, err
	}
	if !variantsProvided && !whitelistProvided {
		return experiment, nil
	}

	var variantsByName map[string]*abtestmodel.Variant
	if variantsProvided {
		if err := svc.ResetExperimentRelations(experimentID); err != nil {
			return nil, err
		}
		_, variantsByName, err = createExperimentVariants(svc, experimentID, req.Variants)
		if err != nil {
			return nil, err
		}
	} else {
		// Whitelist-only update: keep the variants and rebuild just the
		// whitelist.
		variantsByName = make(map[string]*abtestmodel.Variant, len(existingVariants))
		for _, variant := range existingVariants {
			variantsByName[variant.Name] = variant
		}
		if err := svc.ClearWhitelist(experimentID); err != nil {
			return nil, err
		}
	}

	if whitelistProvided {
		if _, err := addExperimentWhitelist(svc, experimentID, req.Whitelist, variantsByName); err != nil {
			return nil, err
		}
	}

	return svc.GetExperiment(experimentID)
}

// ExperimentDelete deletes the experiment identified by the "id" path
// parameter and records an admin log entry.
func ExperimentDelete(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	id, ok := experimentPathID(c, "id")
	if !ok {
		return
	}

	if err := svc.DeleteExperiment(id); err != nil {
		respondExperimentError(c, err, abtestservice.ErrExperimentNotFound)
		return
	}

	util.AddAdminLog(c, "删除AB实验", gin.H{"id": id})
	success(c, nil)
}

// ExperimentVariantList responds with the variants of the experiment
// identified by the "id" path parameter.
func ExperimentVariantList(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	id, ok := experimentPathID(c, "id")
	if !ok {
		return
	}

	list, err := svc.ListVariants(id)
	if err != nil {
		failed(c, err.Error())
		return
	}
	success(c, list)
}

// ExperimentVariantCreate adds a variant to the experiment identified by the
// "id" path parameter and records an admin log entry.
func ExperimentVariantCreate(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	id, ok := experimentPathID(c, "id")
	if !ok {
		return
	}

	var req abtestmodel.CreateVariantReq
	if err := c.ShouldBindJSON(&req); err != nil {
		failed(c, err.Error())
		return
	}

	variant, err := svc.CreateVariant(id, &req)
	if err != nil {
		respondExperimentError(c, err, abtestservice.ErrExperimentNotFound)
		return
	}

	util.AddAdminLog(c, "创建AB实验变体", gin.H{"experiment_id": id, "payload": req})
	success(c, variant)
}

// ExperimentVariantUpdate updates the variant identified by the "variantId"
// path parameter and records an admin log entry.
func ExperimentVariantUpdate(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	variantID, ok := experimentPathID(c, "variantId")
	if !ok {
		return
	}

	var req abtestmodel.UpdateVariantReq
	if err := c.ShouldBindJSON(&req); err != nil {
		failed(c, err.Error())
		return
	}

	variant, err := svc.UpdateVariant(variantID, &req)
	if err != nil {
		respondExperimentError(c, err, abtestservice.ErrVariantNotFound)
		return
	}

	util.AddAdminLog(c, "更新AB实验变体", gin.H{"variant_id": variantID, "payload": req})
	success(c, variant)
}

// ExperimentVariantDelete deletes the variant identified by the "variantId"
// path parameter and records an admin log entry.
func ExperimentVariantDelete(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	variantID, ok := experimentPathID(c, "variantId")
	if !ok {
		return
	}

	if err := svc.DeleteVariant(variantID); err != nil {
		respondExperimentError(c, err, abtestservice.ErrVariantNotFound)
		return
	}

	util.AddAdminLog(c, "删除AB实验变体", gin.H{"variant_id": variantID})
	success(c, nil)
}

// ExperimentWhitelistList responds with the whitelist of the experiment
// identified by the "id" path parameter.
func ExperimentWhitelistList(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	id, ok := experimentPathID(c, "id")
	if !ok {
		return
	}

	list, err := svc.ListWhitelist(id)
	if err != nil {
		failed(c, err.Error())
		return
	}
	success(c, list)
}

// ExperimentWhitelistCreate adds one whitelist entry to the experiment
// identified by the "id" path parameter and records an admin log entry.
func ExperimentWhitelistCreate(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	id, ok := experimentPathID(c, "id")
	if !ok {
		return
	}

	var req abtestmodel.CreateWhitelistReq
	if err := c.ShouldBindJSON(&req); err != nil {
		failed(c, err.Error())
		return
	}

	entry, err := svc.AddWhitelist(id, &req)
	if err != nil {
		respondExperimentError(c, err, abtestservice.ErrExperimentNotFound, abtestservice.ErrVariantNotFound)
		return
	}

	util.AddAdminLog(c, "新增AB白名单", gin.H{"experiment_id": id, "payload": req})
	success(c, entry)
}

// ExperimentWhitelistBatchCreate adds several whitelist entries to the
// experiment identified by the "id" path parameter and records an admin log
// entry.
func ExperimentWhitelistBatchCreate(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	id, ok := experimentPathID(c, "id")
	if !ok {
		return
	}

	var req abtestmodel.BatchCreateWhitelistReq
	if err := c.ShouldBindJSON(&req); err != nil {
		failed(c, err.Error())
		return
	}

	list, err := svc.BatchAddWhitelist(id, &req)
	if err != nil {
		respondExperimentError(c, err, abtestservice.ErrExperimentNotFound, abtestservice.ErrVariantNotFound)
		return
	}

	util.AddAdminLog(c, "批量新增AB白名单", gin.H{"experiment_id": id, "payload": req})
	success(c, list)
}

// ExperimentWhitelistDelete removes the user given by the "userId" path
// parameter from the whitelist of the experiment identified by "id" and
// records an admin log entry.
func ExperimentWhitelistDelete(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	id, ok := experimentPathID(c, "id")
	if !ok {
		return
	}

	userID := c.Param("userId")
	if userID == "" {
		failed(c, "invalid userId")
		return
	}

	if err := svc.RemoveWhitelist(id, userID); err != nil {
		failed(c, err.Error())
		return
	}

	util.AddAdminLog(c, "删除AB白名单", gin.H{"experiment_id": id, "user_id": userID})
	success(c, nil)
}

// ExperimentResult responds with the result of the experiment identified by
// the "id" path parameter.
func ExperimentResult(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}
	id, ok := experimentPathID(c, "id")
	if !ok {
		return
	}

	result, err := svc.GetExperimentResult(id)
	if err != nil {
		respondExperimentError(c, err, abtestservice.ErrExperimentNotFound)
		return
	}
	success(c, result)
}

// UserExperimentGroups responds with the experiment groups of the user given
// by the "userId" path parameter (surrounding whitespace is trimmed).
func UserExperimentGroups(c *gin.Context) {
	svc, ok := experimentServiceOrFail(c)
	if !ok {
		return
	}

	userID := strings.TrimSpace(c.Param("userId"))
	if userID == "" {
		failed(c, "invalid userId")
		return
	}

	groups, err := svc.GetUserGroups(userID)
	if err != nil {
		failed(c, err.Error())
		return
	}
	success(c, groups)
}