From 1547bbd17767d83e8f3db4d814a45195b4eb9da5 Mon Sep 17 00:00:00 2001 From: hahwu <31872165+hahwu@users.noreply.github.com> Date: Tue, 12 May 2026 17:56:08 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E7=9F=A5=E8=AF=86=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/knowledge.go | 103 ++++++++++++++++++++++++++++++++++++++++ main.go | 5 +- 2 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 controller/knowledge.go diff --git a/controller/knowledge.go b/controller/knowledge.go new file mode 100644 index 0000000..1bbdf0e --- /dev/null +++ b/controller/knowledge.go @@ -0,0 +1,103 @@ +package controller + +import ( + "bytes" + "crypto/tls" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +// 反向代理到本机 RAG FastAPI (127.0.0.1:8765),转发 SSE 流。 +// 不在 Go 端重做 BM25/LLM;上游已封装。 +const defaultRagUpstream = "http://8.159.157.223:8765/api/chat" + +func getRagUpstream() string { + raw := strings.TrimSpace(os.Getenv("RAG_UPSTREAM")) + if raw == "" { + return defaultRagUpstream + } + return raw +} + +func newRagHTTPClient(upstream string) *http.Client { + transport := http.DefaultTransport.(*http.Transport).Clone() + if parsed, err := url.Parse(upstream); err == nil && parsed.Scheme == "https" { + if host := parsed.Hostname(); net.ParseIP(host) != nil { + // 上游若使用 https://IP 形式且证书没有 IP SAN,会触发 x509 校验失败。 + // 这里仅对 IP 直连场景放宽校验,域名场景仍保持默认校验。 + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + } + return &http.Client{ + Timeout: 5 * time.Minute, + Transport: transport, + } +} + +func KnowledgeChat(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "read body: " + err.Error()}) + return + } + + // 长超时:流式问答可能持续 30s+;不限上游耗时由上游自己控 + upstream := getRagUpstream() + client := newRagHTTPClient(upstream) + req, err := http.NewRequest(http.MethodPost, upstream, bytes.NewReader(body)) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + + resp, err := client.Do(req) + if err != nil { + c.JSON(http.StatusBadGateway, gin.H{"error": "rag upstream: " + err.Error()}) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), b) + return + } + + // 透传 SSE 头:禁用 Nginx 缓冲,否则前端无法获得实时分片 + h := c.Writer.Header() + h.Set("Content-Type", "text/event-stream") + h.Set("Cache-Control", "no-cache") + h.Set("Connection", "keep-alive") + h.Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming unsupported"}) + return + } + + // 小缓冲逐块转发;上游每发一帧即 flush,让浏览器立即看到 token + buf := make([]byte, 4096) + for { + n, rerr := resp.Body.Read(buf) + if n > 0 { + if _, werr := c.Writer.Write(buf[:n]); werr != nil { + return + } + flusher.Flush() + } + if rerr != nil { + return + } + } +} diff --git a/main.go b/main.go index 751ecda..bb7432c 100644 --- a/main.go +++ b/main.go @@ -223,7 +223,10 @@ func main() { api.PUT("/config/notification/update", controller.NotificationConfigSave) } - // 自动化脚本 + knowledgeApi := r.Group("/api/knowledge", middleware.ValidateToken()) + { + knowledgeApi.POST("/chat", controller.KnowledgeChat) + } scripts := r.Group("/api/scripts", middleware.ValidateToken()) { scripts.POST("/copywriting", controller.Copywriting) // 下载文案文件