104 lines
2.7 KiB
Go
104 lines
2.7 KiB
Go
package controller
|
||
|
||
import (
	"bytes"
	"crypto/tls"
	"io"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)
|
||
|
||
// defaultRagUpstream is the RAG chat endpoint used when the RAG_UPSTREAM
// environment variable is unset or blank. KnowledgeChat reverse-proxies chat
// requests to this RAG service and relays its SSE stream back to the caller.
// Per the original note, BM25 retrieval and LLM calls are not reimplemented on
// the Go side because the upstream service already wraps them.
//
// NOTE(review): the original comment described the upstream as the local
// FastAPI at 127.0.0.1:8765, but the value below is a public IP reached over
// plain HTTP. Confirm which is intended. With this default, chat traffic
// leaves the host unencrypted unless RAG_UPSTREAM is overridden.
const defaultRagUpstream = "http://8.159.157.223:8765/api/chat"
|
||
|
||
func getRagUpstream() string {
|
||
raw := strings.TrimSpace(os.Getenv("RAG_UPSTREAM"))
|
||
if raw == "" {
|
||
return defaultRagUpstream
|
||
}
|
||
return raw
|
||
}
|
||
|
||
// newRagHTTPClient builds an HTTP client for talking to the RAG upstream.
//
// The transport is a clone of http.DefaultTransport, including its proxy,
// dialer and connection-pool settings. The client has a 5-minute overall
// Timeout. http.Client.Timeout bounds the whole exchange, including reading
// the response body, so it also caps how long one SSE stream can last.
//
// If upstream is an https URL whose host is a literal IP address (IPv4 or
// IPv6), TLS certificate verification is disabled. Such certificates often
// lack an IP SAN and would otherwise fail x509 validation. Hostname-based
// https URLs, plain http URLs and unparsable strings keep default
// verification.
//
// SECURITY: InsecureSkipVerify removes protection against man-in-the-middle
// attacks on the IP-based https path. Prefer a certificate with an IP SAN, or
// trust the upstream certificate explicitly via tls.Config.RootCAs.
func newRagHTTPClient(upstream string) *http.Client {
	var transport *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	} else {
		// http.DefaultTransport was replaced by a custom RoundTripper. An
		// unchecked assertion would panic here, so start from a fresh transport.
		transport = &http.Transport{Proxy: http.ProxyFromEnvironment}
	}

	if parsed, err := url.Parse(upstream); err == nil && parsed.Scheme == "https" && net.ParseIP(parsed.Hostname()) != nil {
		// Amend the cloned config rather than replacing it, so settings
		// inherited from the base transport survive. Clone deep-copies
		// TLSClientConfig, so http.DefaultTransport itself is never mutated.
		if transport.TLSClientConfig == nil {
			transport.TLSClientConfig = &tls.Config{}
		}
		transport.TLSClientConfig.InsecureSkipVerify = true
	}

	return &http.Client{
		Timeout:   5 * time.Minute,
		Transport: transport,
	}
}
|
||
|
||
func KnowledgeChat(c *gin.Context) {
|
||
body, err := io.ReadAll(c.Request.Body)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "read body: " + err.Error()})
|
||
return
|
||
}
|
||
|
||
// 长超时:流式问答可能持续 30s+;不限上游耗时由上游自己控
|
||
upstream := getRagUpstream()
|
||
client := newRagHTTPClient(upstream)
|
||
req, err := http.NewRequest(http.MethodPost, upstream, bytes.NewReader(body))
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("Accept", "text/event-stream")
|
||
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
c.JSON(http.StatusBadGateway, gin.H{"error": "rag upstream: " + err.Error()})
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
b, _ := io.ReadAll(resp.Body)
|
||
c.Data(resp.StatusCode, resp.Header.Get("Content-Type"), b)
|
||
return
|
||
}
|
||
|
||
// 透传 SSE 头:禁用 Nginx 缓冲,否则前端无法获得实时分片
|
||
h := c.Writer.Header()
|
||
h.Set("Content-Type", "text/event-stream")
|
||
h.Set("Cache-Control", "no-cache")
|
||
h.Set("Connection", "keep-alive")
|
||
h.Set("X-Accel-Buffering", "no")
|
||
c.Writer.WriteHeader(http.StatusOK)
|
||
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": "streaming unsupported"})
|
||
return
|
||
}
|
||
|
||
// 小缓冲逐块转发;上游每发一帧即 flush,让浏览器立即看到 token
|
||
buf := make([]byte, 4096)
|
||
for {
|
||
n, rerr := resp.Body.Read(buf)
|
||
if n > 0 {
|
||
if _, werr := c.Writer.Write(buf[:n]); werr != nil {
|
||
return
|
||
}
|
||
flusher.Flush()
|
||
}
|
||
if rerr != nil {
|
||
return
|
||
}
|
||
}
|
||
}
|