package mcp import ( "context" "encoding/json" "errors" "fmt" "strings" annotations2 "git.ikuban.com/server/kubanapis/kuban/api/annotations" openapi_v3 "github.com/google/gnostic/openapiv3" mcp2 "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "google.golang.org/genproto/googleapis/api/annotations" "google.golang.org/grpc" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" ) func ServerAddTools(s *server.MCPServer, srv any, svcDesc grpc.ServiceDesc) error { serviceName := strings.ReplaceAll(svcDesc.ServiceName, ".", "_") handlerMap := make(map[string]grpc.MethodDesc) for _, _v := range svcDesc.Methods { v := _v mapK := serviceName + "_" + v.MethodName handlerMap[mapK] = v } d, err := protoregistry.GlobalFiles.FindFileByPath(svcDesc.Metadata.(string)) if err != nil { return err } if d.Services().Len() == 0 { return nil } ser := d.Services().Get(0) for j := 0; j < ser.Methods().Len(); j++ { method := ser.Methods().Get(j) t, h := serverAddToolsByMethod(serviceName, srv, method, handlerMap) if t == nil || h == nil { continue } s.AddTool(*t, h) } return nil } func serverAddToolsByMethod(serviceName string, srv any, method protoreflect.MethodDescriptor, handlerMap map[string]grpc.MethodDesc) (*mcp2.Tool, server.ToolHandlerFunc) { methodMcpOpts, _ := proto.GetExtension(method.Options(), annotations2.E_Options).(*annotations2.Options) if methodMcpOpts == nil || methodMcpOpts.McpOptions == nil || !methodMcpOpts.GetMcpOptions().Enabled { return nil, nil } methodOperation, _ := proto.GetExtension(method.Options(), openapi_v3.E_Operation).(*openapi_v3.Operation) description := "" if methodOperation != nil { description = methodOperation.Description if description == "" { description = methodOperation.Summary } } toolOptions := []mcp2.ToolOption{mcp2.WithDescription(description)} for k := 0; k < method.Input().Fields().Len(); k++ { input := method.Input().Fields().Get(k) inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema) inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior) inputDescription := "" if inputOperation != nil { inputDescription = inputOperation.GetDescription() } propertyOption := make([]mcp2.PropertyOption, 0) if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED { propertyOption = append(propertyOption, mcp2.Required()) } if input.IsList() { propertyOption = append(propertyOption, mcp2.Items(getInputArrayItems(input))) propertyOption = append(propertyOption, mcp2.Description(inputDescription)) toolOptions = append(toolOptions, mcp2.WithArray(string(input.Name()), propertyOption...)) } else if input.IsMap() { additionalProperties, descriptionSuffix := getInputMapProperties(input) inputDescription = inputDescription + descriptionSuffix propertyOption = append(propertyOption, mcp2.Description(inputDescription)) propertyOption = append(propertyOption, mcp2.AdditionalProperties(additionalProperties)) toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...)) } else { propertyOption = append(propertyOption, mcp2.Description(inputDescription)) switch input.Kind() { case protoreflect.StringKind: toolOptions = append(toolOptions, mcp2.WithString(string(input.Name()), propertyOption...)) case protoreflect.BoolKind: toolOptions = append(toolOptions, mcp2.WithBoolean(string(input.Name()), propertyOption...)) case protoreflect.DoubleKind, protoreflect.FloatKind, protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind, protoreflect.Fixed64Kind, protoreflect.Fixed32Kind, protoreflect.Sint64Kind, protoreflect.Sint32Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.Int64Kind, protoreflect.Int32Kind: toolOptions = append(toolOptions, mcp2.WithNumber(string(input.Name()), propertyOption...)) case protoreflect.MessageKind: propertyOption = append(propertyOption, mcp2.Properties(getFiledMessageParamProperties(input.Message(), false))) toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...)) } } } toolName := serviceName + "_" + string(method.Name()) t := mcp2.NewTool(toolName, toolOptions...) h := func(ctx context.Context, request mcp2.CallToolRequest) (*mcp2.CallToolResult, error) { if _, ok := handlerMap[toolName]; !ok { return nil, errors.New("没有实现") } arg := request.GetArguments() argJson, _ := json.Marshal(arg) dec := func(in any) error { decErr := json.Unmarshal(argJson, &in) if decErr != nil { return decErr } return nil } interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { return handler(ctx, req) } handler := handlerMap[toolName] out, outErr := handler.Handler(srv, ctx, dec, interceptor) if outErr != nil { return nil, outErr } outJson, _ := json.Marshal(out) callToolResult := &mcp2.CallToolResult{ Content: []mcp2.Content{ mcp2.TextContent{ Type: "text", Text: string(outJson), }, }, } return callToolResult, nil } return &t, h } func getInputArrayItems(input protoreflect.FieldDescriptor) map[string]any { inputMap := make(map[string]any) switch input.Kind() { case protoreflect.StringKind: inputMap["type"] = "string" case protoreflect.BoolKind: inputMap["type"] = "boolean" case protoreflect.DoubleKind, protoreflect.FloatKind, protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind, protoreflect.Fixed64Kind, protoreflect.Fixed32Kind, protoreflect.Sint64Kind, protoreflect.Sint32Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.Int64Kind, protoreflect.Int32Kind: inputMap["type"] = "number" case protoreflect.MessageKind: inputMap["type"] = "object" propertiesMap := getFiledMessageParamProperties(input.Message(), true) inputMap["properties"] = propertiesMap } return inputMap } func getInputMapProperties(input protoreflect.FieldDescriptor) (map[string]any, string) { messageParamMap := make(map[string]any) paramMap := make(map[string]any) switch input.MapValue().Kind() { case protoreflect.StringKind: paramMap["type"] = "string" case protoreflect.BoolKind: paramMap["type"] = "boolean" case protoreflect.DoubleKind, protoreflect.FloatKind, protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind, protoreflect.Fixed64Kind, protoreflect.Fixed32Kind, protoreflect.Sint64Kind, protoreflect.Sint32Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.Int64Kind, protoreflect.Int32Kind: paramMap["type"] = "number" case protoreflect.MessageKind: paramMap["type"] = "object" paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false) default: break } var keyType string switch input.MapKey().Kind() { case protoreflect.StringKind: keyType = "string" case protoreflect.BoolKind: keyType = "boolean" case protoreflect.DoubleKind, protoreflect.FloatKind, protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind, protoreflect.Fixed64Kind, protoreflect.Fixed32Kind, protoreflect.Sint64Kind, protoreflect.Sint32Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.Int64Kind, protoreflect.Int32Kind: keyType = "number" } descriptionSuffix := fmt.Sprintf("(type of key: %s, type of value: %s)", keyType, paramMap["type"]) messageParamMap[string(input.Name())] = paramMap return messageParamMap, descriptionSuffix } func getFiledMessageParamProperties(message protoreflect.MessageDescriptor, needRequired bool) map[string]any { messageParamMap := make(map[string]any) for i := 0; i < message.Fields().Len(); i++ { input := message.Fields().Get(i) paramMap := make(map[string]any) inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema) inputDescription := "" if inputOperation != nil { inputDescription = inputOperation.GetDescription() } paramMap["description"] = inputDescription if needRequired { inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior) if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED { paramMap["required"] = true } } if input.IsList() { paramMap["type"] = "array" paramMap["items"] = getInputArrayItems(input) } else if input.IsMap() { paramMap["type"] = "object" var descriptionSuffix string paramMap["additionalProperties"], descriptionSuffix = getInputMapProperties(input) paramMap["description"] = inputDescription + descriptionSuffix } else { switch input.Kind() { case protoreflect.StringKind: paramMap["type"] = "string" case protoreflect.BoolKind: paramMap["type"] = "boolean" case protoreflect.DoubleKind, protoreflect.FloatKind, protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind, protoreflect.Fixed64Kind, protoreflect.Fixed32Kind, protoreflect.Sint64Kind, protoreflect.Sint32Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.Int64Kind, protoreflect.Int32Kind: paramMap["type"] = "number" case protoreflect.MessageKind: paramMap["type"] = "object" paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false) default: break } } messageParamMap[string(input.Name())] = paramMap } return messageParamMap }