|
|
@@ -0,0 +1,165 @@
|
|
|
+package mcp
|
|
|
+
|
|
|
+import (
|
|
|
+ "context"
|
|
|
+ "encoding/json"
|
|
|
+ "errors"
|
|
|
+ "strings"
|
|
|
+
|
|
|
+ "git.ikuban.com/server/base-protobuf/kuban/options"
|
|
|
+ openapi_v3 "github.com/google/gnostic/openapiv3"
|
|
|
+ mcp2 "github.com/mark3labs/mcp-go/mcp"
|
|
|
+ "github.com/mark3labs/mcp-go/server"
|
|
|
+ "google.golang.org/genproto/googleapis/api/annotations"
|
|
|
+ "google.golang.org/grpc"
|
|
|
+ "google.golang.org/protobuf/proto"
|
|
|
+ "google.golang.org/protobuf/reflect/protoreflect"
|
|
|
+ "google.golang.org/protobuf/reflect/protoregistry"
|
|
|
+)
|
|
|
+
|
|
|
+func ServerAddTools(s *server.MCPServer, srv any, svcDesc grpc.ServiceDesc) error {
|
|
|
+ serviceName := strings.ReplaceAll(svcDesc.ServiceName, ".", "_")
|
|
|
+
|
|
|
+ handlerMap := make(map[string]grpc.MethodDesc)
|
|
|
+
|
|
|
+ for _, _v := range svcDesc.Methods {
|
|
|
+ v := _v
|
|
|
+ mapK := serviceName + "_" + v.MethodName
|
|
|
+ handlerMap[mapK] = v
|
|
|
+ }
|
|
|
+ d, err := protoregistry.GlobalFiles.FindFileByPath(svcDesc.Metadata.(string))
|
|
|
+ if err != nil {
|
|
|
+
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ if d.Services().Len() == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ ser := d.Services().Get(0)
|
|
|
+
|
|
|
+ for j := 0; j < ser.Methods().Len(); j++ {
|
|
|
+ method := ser.Methods().Get(j)
|
|
|
+ t, h := serverAddToolsByMethod(serviceName, srv, method, handlerMap)
|
|
|
+ s.AddTool(*t, h)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func serverAddToolsByMethod(serviceName string, srv any, method protoreflect.MethodDescriptor, handlerMap map[string]grpc.MethodDesc) (*mcp2.Tool, server.ToolHandlerFunc) {
|
|
|
+ methodMcpOpts, _ := proto.GetExtension(method.Options(), options.E_McpOptions).(*options.McpOptions)
|
|
|
+ if methodMcpOpts == nil || !methodMcpOpts.Enabled {
|
|
|
+ return nil, nil
|
|
|
+ }
|
|
|
+ methodOperation, _ := proto.GetExtension(method.Options(), openapi_v3.E_Operation).(*openapi_v3.Operation)
|
|
|
+ description := ""
|
|
|
+ if methodOperation != nil {
|
|
|
+ description = methodOperation.Description
|
|
|
+ if description == "" {
|
|
|
+ description = methodOperation.Summary
|
|
|
+ }
|
|
|
+ }
|
|
|
+ toolOptions := []mcp2.ToolOption{mcp2.WithDescription(description)}
|
|
|
+ for k := 0; k < method.Input().Fields().Len(); k++ {
|
|
|
+ input := method.Input().Fields().Get(k)
|
|
|
+ inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
|
|
|
+ inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)
|
|
|
+ inputDescription := ""
|
|
|
+ if inputOperation != nil {
|
|
|
+ inputDescription = inputOperation.GetDescription()
|
|
|
+ }
|
|
|
+ propertyOption := []mcp2.PropertyOption{mcp2.Description(inputDescription)}
|
|
|
+ if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {
|
|
|
+ propertyOption = append(propertyOption, mcp2.Required())
|
|
|
+ }
|
|
|
+ switch input.Kind() {
|
|
|
+ case protoreflect.StringKind:
|
|
|
+ toolOptions = append(toolOptions, mcp2.WithString(string(input.Name()), propertyOption...))
|
|
|
+ case protoreflect.BoolKind:
|
|
|
+ toolOptions = append(toolOptions, mcp2.WithBoolean(string(input.Name()), propertyOption...))
|
|
|
+ case protoreflect.DoubleKind, protoreflect.FloatKind,
|
|
|
+ protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
|
|
|
+ protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
|
|
|
+ protoreflect.Sint64Kind, protoreflect.Sint32Kind,
|
|
|
+ protoreflect.Uint64Kind, protoreflect.Uint32Kind,
|
|
|
+ protoreflect.Int64Kind, protoreflect.Int32Kind:
|
|
|
+ toolOptions = append(toolOptions, mcp2.WithNumber(string(input.Name()), propertyOption...))
|
|
|
+ case protoreflect.MessageKind:
|
|
|
+ propertyOption = append(propertyOption, mcp2.Properties(getFiledMessageParamProperties(input.Message())))
|
|
|
+ toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ toolName := serviceName + "_" + string(method.Name())
|
|
|
+ t := mcp2.NewTool(toolName, toolOptions...)
|
|
|
+
|
|
|
+ h := func(ctx context.Context, request mcp2.CallToolRequest) (*mcp2.CallToolResult, error) {
|
|
|
+
|
|
|
+ if _, ok := handlerMap[toolName]; !ok {
|
|
|
+ return nil, errors.New("没有实现")
|
|
|
+ }
|
|
|
+ arg := request.GetArguments()
|
|
|
+ argJson, _ := json.Marshal(arg)
|
|
|
+ dec := func(in any) error {
|
|
|
+ decErr := json.Unmarshal(argJson, &in)
|
|
|
+ if decErr != nil {
|
|
|
+ return decErr
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
|
|
|
+ return handler(ctx, req)
|
|
|
+ }
|
|
|
+
|
|
|
+ handler := handlerMap[toolName]
|
|
|
+ out, outErr := handler.Handler(srv, ctx, dec, interceptor)
|
|
|
+ if outErr != nil {
|
|
|
+ return nil, outErr
|
|
|
+ }
|
|
|
+
|
|
|
+ outJson, _ := json.Marshal(out)
|
|
|
+
|
|
|
+ callToolResult := &mcp2.CallToolResult{
|
|
|
+ Content: []mcp2.Content{
|
|
|
+ mcp2.TextContent{
|
|
|
+ Type: "text",
|
|
|
+ Text: string(outJson),
|
|
|
+ },
|
|
|
+ },
|
|
|
+ }
|
|
|
+ return callToolResult, nil
|
|
|
+ }
|
|
|
+ return &t, h
|
|
|
+}
|
|
|
+
|
|
|
+func getFiledMessageParamProperties(message protoreflect.MessageDescriptor) map[string]any {
|
|
|
+
|
|
|
+ messageParamMap := make(map[string]any)
|
|
|
+
|
|
|
+ for i := 0; i < message.Fields().Len(); i++ {
|
|
|
+ input := message.Fields().Get(i)
|
|
|
+ paramMap := make(map[string]any)
|
|
|
+ inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
|
|
|
+ inputDescription := ""
|
|
|
+ if inputOperation != nil {
|
|
|
+ inputDescription = inputOperation.GetDescription()
|
|
|
+ }
|
|
|
+ paramMap["description"] = inputDescription
|
|
|
+
|
|
|
+ switch input.Kind() {
|
|
|
+ case protoreflect.StringKind:
|
|
|
+ paramMap["type"] = "string"
|
|
|
+ case protoreflect.BoolKind:
|
|
|
+ paramMap["type"] = "boolean"
|
|
|
+ case protoreflect.DoubleKind, protoreflect.FloatKind,
|
|
|
+ protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
|
|
|
+ protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
|
|
|
+ protoreflect.Sint64Kind, protoreflect.Sint32Kind,
|
|
|
+ protoreflect.Uint64Kind, protoreflect.Uint32Kind,
|
|
|
+ protoreflect.Int64Kind, protoreflect.Int32Kind:
|
|
|
+ paramMap["type"] = "number"
|
|
|
+ default:
|
|
|
+ break
|
|
|
+ }
|
|
|
+ messageParamMap[string(input.Name())] = paramMap
|
|
|
+ }
|
|
|
+ return messageParamMap
|
|
|
+}
|