Browse Source

fix(mcp): 优化 server启动和停止逻辑

-完善 server 初始化,确保 mux 和 srv 不为 nil
-调整 middleware 应用时机,提高代码可读性
-增加日志输出,便于监控 server 状态
- 优化 Stop 方法,增加优雅停止和强制停止逻辑
dcsunny 4 months ago
parent
commit
a9af99b614
1 changed files with 19 additions and 5 deletions
  1. 19 5
      mcp/server.go

+ 19 - 5
mcp/server.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 
 
+	"github.com/go-kratos/kratos/v2/log"
 	"github.com/go-kratos/kratos/v2/transport"
 	"github.com/go-kratos/kratos/v2/transport"
 
 
 	"github.com/mark3labs/mcp-go/server"
 	"github.com/mark3labs/mcp-go/server"
@@ -79,14 +80,16 @@ func NewServer(name, version string, opts ...ServerOption) *Server {
 	for _, o := range opts {
 	for _, o := range opts {
 		o(srv)
 		o(srv)
 	}
 	}
-	srv.MCPServer = server.NewMCPServer(name, version, srv.srvOpts...)
-	srv.srv = &http.Server{Handler: srv.middleware(srv)}
-	srv.streamableHttpServer = server.NewStreamableHTTPServer(srv.MCPServer, append(srv.streamableHTTPOpts, server.WithStreamableHTTPServer(srv.srv))...)
 	if srv.mux == nil {
 	if srv.mux == nil {
 		srv.mux = http.NewServeMux()
 		srv.mux = http.NewServeMux()
 	}
 	}
+	srv.MCPServer = server.NewMCPServer(name, version, srv.srvOpts...)
+	if srv.srv == nil {
+		srv.srv = &http.Server{}
+	}
+	srv.streamableHttpServer = server.NewStreamableHTTPServer(srv.MCPServer, append(srv.streamableHTTPOpts, server.WithStreamableHTTPServer(srv.srv))...)
 	srv.mux.Handle("/mcp", srv.streamableHttpServer)
 	srv.mux.Handle("/mcp", srv.streamableHttpServer)
-	srv.srv.Handler = srv.mux
+	srv.srv.Handler = srv.middleware(srv.mux)
 
 
 	return srv
 	return srv
 }
 }
@@ -108,6 +111,7 @@ func (s *Server) Endpoint() (*url.URL, error) {
 
 
 // Start start the MCP server.
 // Start start the MCP server.
 func (s *Server) Start(_ context.Context) error {
 func (s *Server) Start(_ context.Context) error {
+	log.Infof("[MCP] server listening on: %s", s.address)
 	if err := s.streamableHttpServer.Start(s.address); err != nil {
 	if err := s.streamableHttpServer.Start(s.address); err != nil {
 		if !errors.Is(err, http.ErrServerClosed) {
 		if !errors.Is(err, http.ErrServerClosed) {
 			return err
 			return err
@@ -118,5 +122,15 @@ func (s *Server) Start(_ context.Context) error {
 
 
 // Stop stop the MCP server.
 // Stop stop the MCP server.
 func (s *Server) Stop(ctx context.Context) error {
 func (s *Server) Stop(ctx context.Context) error {
-	return s.streamableHttpServer.Shutdown(ctx)
+	defer func() {
+		log.Info("[MCP] server stopping")
+	}()
+	err := s.streamableHttpServer.Shutdown(ctx)
+	if err != nil {
+		if ctx.Err() != nil {
+			log.Warn("[MCP] server couldn't stop gracefully in time, doing force stop")
+			err = s.srv.Close()
+		}
+	}
+	return err
 }
 }