|
@@ -7,6 +7,7 @@ import (
|
|
|
"net/http"
|
|
"net/http"
|
|
|
"net/url"
|
|
"net/url"
|
|
|
|
|
|
|
|
|
|
+ "github.com/go-kratos/kratos/v2/log"
|
|
|
"github.com/go-kratos/kratos/v2/transport"
|
|
"github.com/go-kratos/kratos/v2/transport"
|
|
|
|
|
|
|
|
"github.com/mark3labs/mcp-go/server"
|
|
"github.com/mark3labs/mcp-go/server"
|
|
@@ -79,14 +80,16 @@ func NewServer(name, version string, opts ...ServerOption) *Server {
|
|
|
for _, o := range opts {
|
|
for _, o := range opts {
|
|
|
o(srv)
|
|
o(srv)
|
|
|
}
|
|
}
|
|
|
- srv.MCPServer = server.NewMCPServer(name, version, srv.srvOpts...)
|
|
|
|
|
- srv.srv = &http.Server{Handler: srv.middleware(srv)}
|
|
|
|
|
- srv.streamableHttpServer = server.NewStreamableHTTPServer(srv.MCPServer, append(srv.streamableHTTPOpts, server.WithStreamableHTTPServer(srv.srv))...)
|
|
|
|
|
if srv.mux == nil {
|
|
if srv.mux == nil {
|
|
|
srv.mux = http.NewServeMux()
|
|
srv.mux = http.NewServeMux()
|
|
|
}
|
|
}
|
|
|
|
|
+ srv.MCPServer = server.NewMCPServer(name, version, srv.srvOpts...)
|
|
|
|
|
+ if srv.srv == nil {
|
|
|
|
|
+ srv.srv = &http.Server{}
|
|
|
|
|
+ }
|
|
|
|
|
+ srv.streamableHttpServer = server.NewStreamableHTTPServer(srv.MCPServer, append(srv.streamableHTTPOpts, server.WithStreamableHTTPServer(srv.srv))...)
|
|
|
srv.mux.Handle("/mcp", srv.streamableHttpServer)
|
|
srv.mux.Handle("/mcp", srv.streamableHttpServer)
|
|
|
- srv.srv.Handler = srv.mux
|
|
|
|
|
|
|
+ srv.srv.Handler = srv.middleware(srv.mux)
|
|
|
|
|
|
|
|
return srv
|
|
return srv
|
|
|
}
|
|
}
|
|
@@ -108,6 +111,7 @@ func (s *Server) Endpoint() (*url.URL, error) {
|
|
|
|
|
|
|
|
// Start start the MCP server.
|
|
// Start start the MCP server.
|
|
|
func (s *Server) Start(_ context.Context) error {
|
|
func (s *Server) Start(_ context.Context) error {
|
|
|
|
|
+ log.Infof("[MCP] server listening on: %s", s.address)
|
|
|
if err := s.streamableHttpServer.Start(s.address); err != nil {
|
|
if err := s.streamableHttpServer.Start(s.address); err != nil {
|
|
|
if !errors.Is(err, http.ErrServerClosed) {
|
|
if !errors.Is(err, http.ErrServerClosed) {
|
|
|
return err
|
|
return err
|
|
@@ -118,5 +122,15 @@ func (s *Server) Start(_ context.Context) error {
|
|
|
|
|
|
|
|
// Stop stop the MCP server.
|
|
// Stop stop the MCP server.
|
|
|
func (s *Server) Stop(ctx context.Context) error {
|
|
func (s *Server) Stop(ctx context.Context) error {
|
|
|
- return s.streamableHttpServer.Shutdown(ctx)
|
|
|
|
|
|
|
+ defer func() {
|
|
|
|
|
+ log.Info("[MCP] server stopping")
|
|
|
|
|
+ }()
|
|
|
|
|
+ err := s.streamableHttpServer.Shutdown(ctx)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ if ctx.Err() != nil {
|
|
|
|
|
+ log.Warn("[MCP] server couldn't stop gracefully in time, doing force stop")
|
|
|
|
|
+ err = s.srv.Close()
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return err
|
|
|
}
|
|
}
|