package util import ( "backend/Type" "backend/common" "context" "fmt" "log" "net" "strings" "sync" "github.com/go-sql-driver/mysql" // _ "github.com/go-sql-driver/mysql" "github.com/jmoiron/sqlx" "golang.org/x/crypto/ssh" ) type MysqlPool struct { lock sync.Mutex poolList map[string]*poolInfo } type poolInfo struct { DbList []*Db Addr string } type Db struct { *sqlx.DB key string ssh *ssh.Client } func (db *Db) Close() error { if db == nil { return nil } MPool.putDb(db.key, db) return nil } func (m *MysqlPool) getDb(key string) *Db { m.lock.Lock() pool, ok := m.poolList[key] if !ok || len(pool.DbList) == 0 { m.lock.Unlock() return nil } db := pool.DbList[0] pool.DbList = pool.DbList[1:] m.lock.Unlock() // 在锁释放后再做 Ping 与关闭操作,避免死锁 if err := db.Ping(); err != nil { // 底层连接失效,关闭 SSH 隧道和数据库连接,不放回池中 if db.ssh != nil { db.ssh.Close() } db.DB.Close() return nil } return db } func (m *MysqlPool) putDb(key string, db *Db) { m.lock.Lock() pool := m.poolList[key] if pool == nil { pool = &poolInfo{Addr: key} m.poolList[key] = pool } if len(pool.DbList) >= 5 { // 池已满,释放锁后再关闭连接,避免递归调用 putDb 导致死锁 m.lock.Unlock() if db.ssh != nil { db.ssh.Close() } db.DB.Close() return } pool.DbList = append(pool.DbList, db) m.lock.Unlock() } func (m *MysqlPool) GetMysqlDB(AppCnf *Type.App, ServerId int) *Db { key := fmt.Sprintf("%s_%s_%d", AppCnf.NodeName, AppCnf.MysqlName, ServerId) ADb := m.getDb(key) if ADb != nil { return ADb } var SQLDb *sqlx.DB var err error SQLDb, sshConn, err := connectToMySQLViaSSH(AppCnf, ServerId) if err != nil { log.Printf("failed to connect to mysql: %v", err) return nil } ADb = &Db{SQLDb, key, sshConn} // m.putDb(key, Db) return ADb } func (m *MysqlPool) GetGameDB() *Db { Server, Mysql := "log", "log" key := fmt.Sprintf("%s_%s_game", Server, Mysql) ADb := m.getDb(key) if ADb != nil { return ADb } SQLDb, err := ConnectMysql(Mysql, "game") if err != nil { return nil } ADb = &Db{SQLDb, key, nil} // m.putDb(key, Db) return ADb } func (m *MysqlPool) GetTopicDB(Topic string) *Db { Server, Mysql := "log", "log" key := fmt.Sprintf("%s_%s_%s", Server, Mysql, Topic) ADb := m.getDb(key) if ADb != nil { return ADb } SQLDb, err := ConnectMysql(Mysql, Topic) if err != nil { return nil } ADb = &Db{SQLDb, key, nil} // m.putDb(key, Db) return ADb } func init() { MPool = &MysqlPool{poolList: make(map[string]*poolInfo)} } var MPool *MysqlPool func connectToMySQLViaSSH(AppCnf *Type.App, ServerId int) (*sqlx.DB, *ssh.Client, error) { MysqlConfig, err := common.GetMysqlConfig(AppCnf.MysqlName) if err != nil { return nil, nil, fmt.Errorf("failed to get MySQL config: %v", err) } if common.GetSsh() { SshConfig, err := common.GetServerConfig(AppCnf.NodeName) if err != nil { return nil, nil, fmt.Errorf("failed to get SSH config: %v", err) } SP, _ := Decrypt(SshConfig.Password) // 创建 SSH 客户端配置 sshConfig := &ssh.ClientConfig{ User: SshConfig.Username, Auth: []ssh.AuthMethod{ ssh.Password(SP), }, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } // 连接到 SSH 服务器 sshConn, err := ssh.Dial("tcp", fmt.Sprintf("%s:%d", SshConfig.Host, SshConfig.Port), sshConfig) if err != nil { return nil, nil, fmt.Errorf("failed to dial SSH: %v", err) } // defer sshConn.Close() // 创建到 MySQL 服务器的隧道 mysqlConn, err := sshConn.Dial("tcp", fmt.Sprintf("%s:%d", MysqlConfig.Host, MysqlConfig.Port)) if err != nil { return nil, nil, fmt.Errorf("failed to dial MySQL: %v", err) } // 注册 MySQL 驱动 mysql.RegisterDialContext("mysql+tcp", func(ctx context.Context, addr string) (net.Conn, error) { return mysqlConn, nil }) var Database string if strings.Contains(AppCnf.Database, "%") { Database = fmt.Sprintf(AppCnf.Database, ServerId) } else { Database = AppCnf.Database } MP, _ := Decrypt(MysqlConfig.Password) // 连接到 MySQL 数据库 dsn := fmt.Sprintf("%s:%s@mysql+tcp(%s:%d)/%s", MysqlConfig.Username, MP, MysqlConfig.Host, MysqlConfig.Port, Database) db, err := sqlx.Open("mysql", dsn) if err != nil { return nil, nil, fmt.Errorf("failed to open MySQL: %v", err) } return db, sshConn, nil } else { var Database string if strings.Contains(AppCnf.Database, "%") { Database = fmt.Sprintf(AppCnf.Database, ServerId) } else { Database = AppCnf.Database } MP, _ := Decrypt(MysqlConfig.Password) dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", MysqlConfig.Username, MP, MysqlConfig.Host, MysqlConfig.Port, Database) db, err := sqlx.Open("mysql", dsn) if err != nil { return nil, nil, fmt.Errorf("failed to open MySQL: %v", err) } return db, nil, nil } } func ConnectMysql(MysqlName, DataBase string) (*sqlx.DB, error) { MysqlConfig, err := common.GetMysqlConfig(MysqlName) if err != nil { return nil, fmt.Errorf("failed to get MySQL config: %v", err) } MP, _ := Decrypt(MysqlConfig.Password) // 连接到 MySQL 数据库 dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", MysqlConfig.Username, MP, MysqlConfig.Host, MysqlConfig.Port, DataBase) db, err := sqlx.Open("mysql", dsn) if err != nil { return nil, fmt.Errorf("failed to open MySQL: %v", err) } return db, nil } func GetAppConfig(AppId int) (*Type.App, error) { Db := MPool.GetTopicDB("game") defer Db.Close() var app Type.App err := Db.Get(&app, "SELECT * FROM app WHERE `AppId` = ?", AppId) if err != nil { return nil, fmt.Errorf("failed to scan rows: %v", err) } return &app, nil } func GetServerConfig(AppId, ServerId int) (*Type.ServerInfo, error) { Db := MPool.GetTopicDB("game") defer Db.Close() var server Type.ServerInfo err := Db.Get(&server, "SELECT `AppId`,`ServerId`,`Host`,`Status`,`ws_port`,`ecs`, `grpc_port` FROM server WHERE `AppId` = ? AND `ServerId` = ?", AppId, ServerId) if err != nil { return nil, fmt.Errorf("failed to scan rows: %v", err) } return &server, nil }