| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- package mcp
- import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "strings"
- annotations2 "git.ikuban.com/server/kubanapis/kuban/api/annotations"
- openapi_v3 "github.com/google/gnostic/openapiv3"
- mcp2 "github.com/mark3labs/mcp-go/mcp"
- "github.com/mark3labs/mcp-go/server"
- "google.golang.org/genproto/googleapis/api/annotations"
- "google.golang.org/grpc"
- "google.golang.org/protobuf/proto"
- "google.golang.org/protobuf/reflect/protoreflect"
- "google.golang.org/protobuf/reflect/protoregistry"
- )
- func ServerAddTools(s *server.MCPServer, srv any, svcDesc grpc.ServiceDesc) error {
- serviceName := strings.ReplaceAll(svcDesc.ServiceName, ".", "_")
- handlerMap := make(map[string]grpc.MethodDesc)
- for _, _v := range svcDesc.Methods {
- v := _v
- mapK := serviceName + "_" + v.MethodName
- handlerMap[mapK] = v
- }
- d, err := protoregistry.GlobalFiles.FindFileByPath(svcDesc.Metadata.(string))
- if err != nil {
- return err
- }
- if d.Services().Len() == 0 {
- return nil
- }
- ser := d.Services().Get(0)
- for j := 0; j < ser.Methods().Len(); j++ {
- method := ser.Methods().Get(j)
- t, h := serverAddToolsByMethod(serviceName, srv, method, handlerMap)
- if t == nil || h == nil {
- continue
- }
- s.AddTool(*t, h)
- }
- return nil
- }
- func serverAddToolsByMethod(serviceName string, srv any, method protoreflect.MethodDescriptor, handlerMap map[string]grpc.MethodDesc) (*mcp2.Tool, server.ToolHandlerFunc) {
- methodMcpOpts, _ := proto.GetExtension(method.Options(), annotations2.E_Options).(*annotations2.Options)
- if methodMcpOpts == nil || methodMcpOpts.McpOptions == nil || !methodMcpOpts.GetMcpOptions().Enabled {
- return nil, nil
- }
- methodOperation, _ := proto.GetExtension(method.Options(), openapi_v3.E_Operation).(*openapi_v3.Operation)
- description := ""
- if methodOperation != nil {
- description = methodOperation.Description
- if description == "" {
- description = methodOperation.Summary
- }
- }
- toolOptions := []mcp2.ToolOption{mcp2.WithDescription(description)}
- for k := 0; k < method.Input().Fields().Len(); k++ {
- input := method.Input().Fields().Get(k)
- inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
- inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)
- inputDescription := ""
- if inputOperation != nil {
- inputDescription = inputOperation.GetDescription()
- }
- propertyOption := make([]mcp2.PropertyOption, 0)
- if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {
- propertyOption = append(propertyOption, mcp2.Required())
- }
- if input.IsList() {
- propertyOption = append(propertyOption, mcp2.Items(getInputArrayItems(input)))
- propertyOption = append(propertyOption, mcp2.Description(inputDescription))
- toolOptions = append(toolOptions, mcp2.WithArray(string(input.Name()), propertyOption...))
- } else if input.IsMap() {
- additionalProperties, descriptionSuffix := getInputMapProperties(input)
- inputDescription = inputDescription + descriptionSuffix
- propertyOption = append(propertyOption, mcp2.Description(inputDescription))
- propertyOption = append(propertyOption, mcp2.AdditionalProperties(additionalProperties))
- toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))
- } else {
- propertyOption = append(propertyOption, mcp2.Description(inputDescription))
- switch input.Kind() {
- case protoreflect.StringKind:
- toolOptions = append(toolOptions, mcp2.WithString(string(input.Name()), propertyOption...))
- case protoreflect.BoolKind:
- toolOptions = append(toolOptions, mcp2.WithBoolean(string(input.Name()), propertyOption...))
- case protoreflect.DoubleKind, protoreflect.FloatKind,
- protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
- protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
- protoreflect.Sint64Kind, protoreflect.Sint32Kind,
- protoreflect.Uint64Kind, protoreflect.Uint32Kind,
- protoreflect.Int64Kind, protoreflect.Int32Kind:
- toolOptions = append(toolOptions, mcp2.WithNumber(string(input.Name()), propertyOption...))
- case protoreflect.MessageKind:
- propertyOption = append(propertyOption, mcp2.Properties(getFiledMessageParamProperties(input.Message(), false)))
- toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))
- }
- }
- }
- toolName := serviceName + "_" + string(method.Name())
- t := mcp2.NewTool(toolName, toolOptions...)
- h := func(ctx context.Context, request mcp2.CallToolRequest) (*mcp2.CallToolResult, error) {
- if _, ok := handlerMap[toolName]; !ok {
- return nil, errors.New("没有实现")
- }
- arg := request.GetArguments()
- argJson, _ := json.Marshal(arg)
- dec := func(in any) error {
- decErr := json.Unmarshal(argJson, &in)
- if decErr != nil {
- return decErr
- }
- return nil
- }
- interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {
- return handler(ctx, req)
- }
- handler := handlerMap[toolName]
- out, outErr := handler.Handler(srv, ctx, dec, interceptor)
- if outErr != nil {
- return nil, outErr
- }
- outJson, _ := json.Marshal(out)
- callToolResult := &mcp2.CallToolResult{
- Content: []mcp2.Content{
- mcp2.TextContent{
- Type: "text",
- Text: string(outJson),
- },
- },
- }
- return callToolResult, nil
- }
- return &t, h
- }
- func getInputArrayItems(input protoreflect.FieldDescriptor) map[string]any {
- inputMap := make(map[string]any)
- switch input.Kind() {
- case protoreflect.StringKind:
- inputMap["type"] = "string"
- case protoreflect.BoolKind:
- inputMap["type"] = "boolean"
- case protoreflect.DoubleKind, protoreflect.FloatKind,
- protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
- protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
- protoreflect.Sint64Kind, protoreflect.Sint32Kind,
- protoreflect.Uint64Kind, protoreflect.Uint32Kind,
- protoreflect.Int64Kind, protoreflect.Int32Kind:
- inputMap["type"] = "number"
- case protoreflect.MessageKind:
- inputMap["type"] = "object"
- propertiesMap := getFiledMessageParamProperties(input.Message(), true)
- inputMap["properties"] = propertiesMap
- }
- return inputMap
- }
- func getInputMapProperties(input protoreflect.FieldDescriptor) (map[string]any, string) {
- messageParamMap := make(map[string]any)
- paramMap := make(map[string]any)
- switch input.MapValue().Kind() {
- case protoreflect.StringKind:
- paramMap["type"] = "string"
- case protoreflect.BoolKind:
- paramMap["type"] = "boolean"
- case protoreflect.DoubleKind, protoreflect.FloatKind,
- protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
- protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
- protoreflect.Sint64Kind, protoreflect.Sint32Kind,
- protoreflect.Uint64Kind, protoreflect.Uint32Kind,
- protoreflect.Int64Kind, protoreflect.Int32Kind:
- paramMap["type"] = "number"
- case protoreflect.MessageKind:
- paramMap["type"] = "object"
- paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false)
- default:
- break
- }
- var keyType string
- switch input.MapKey().Kind() {
- case protoreflect.StringKind:
- keyType = "string"
- case protoreflect.BoolKind:
- keyType = "boolean"
- case protoreflect.DoubleKind, protoreflect.FloatKind,
- protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
- protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
- protoreflect.Sint64Kind, protoreflect.Sint32Kind,
- protoreflect.Uint64Kind, protoreflect.Uint32Kind,
- protoreflect.Int64Kind, protoreflect.Int32Kind:
- keyType = "number"
- }
- descriptionSuffix := fmt.Sprintf("(type of key: %s, type of value: %s)", keyType, paramMap["type"])
- messageParamMap[string(input.Name())] = paramMap
- return messageParamMap, descriptionSuffix
- }
- func getFiledMessageParamProperties(message protoreflect.MessageDescriptor, needRequired bool) map[string]any {
- messageParamMap := make(map[string]any)
- for i := 0; i < message.Fields().Len(); i++ {
- input := message.Fields().Get(i)
- paramMap := make(map[string]any)
- inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)
- inputDescription := ""
- if inputOperation != nil {
- inputDescription = inputOperation.GetDescription()
- }
- paramMap["description"] = inputDescription
- if needRequired {
- inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)
- if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {
- paramMap["required"] = true
- }
- }
- if input.IsList() {
- paramMap["type"] = "array"
- paramMap["items"] = getInputArrayItems(input)
- } else if input.IsMap() {
- paramMap["type"] = "object"
- var descriptionSuffix string
- paramMap["additionalProperties"], descriptionSuffix = getInputMapProperties(input)
- paramMap["description"] = inputDescription + descriptionSuffix
- } else {
- switch input.Kind() {
- case protoreflect.StringKind:
- paramMap["type"] = "string"
- case protoreflect.BoolKind:
- paramMap["type"] = "boolean"
- case protoreflect.DoubleKind, protoreflect.FloatKind,
- protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
- protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
- protoreflect.Sint64Kind, protoreflect.Sint32Kind,
- protoreflect.Uint64Kind, protoreflect.Uint32Kind,
- protoreflect.Int64Kind, protoreflect.Int32Kind:
- paramMap["type"] = "number"
- case protoreflect.MessageKind:
- paramMap["type"] = "object"
- paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false)
- default:
- break
- }
- }
- messageParamMap[string(input.Name())] = paramMap
- }
- return messageParamMap
- }
|