| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 | package mcpimport (	"context"	"encoding/json"	"errors"	"fmt"	"strings"	annotations2 "git.ikuban.com/server/kubanapis/kuban/api/annotations"	openapi_v3 "github.com/google/gnostic/openapiv3"	mcp2 "github.com/mark3labs/mcp-go/mcp"	"github.com/mark3labs/mcp-go/server"	"google.golang.org/genproto/googleapis/api/annotations"	"google.golang.org/grpc"	"google.golang.org/protobuf/proto"	"google.golang.org/protobuf/reflect/protoreflect"	"google.golang.org/protobuf/reflect/protoregistry")func ServerAddTools(s *server.MCPServer, srv any, svcDesc grpc.ServiceDesc) error {	serviceName := strings.ReplaceAll(svcDesc.ServiceName, ".", "_")	handlerMap := make(map[string]grpc.MethodDesc)	for _, _v := range svcDesc.Methods {		v := _v		mapK := serviceName + "_" + v.MethodName		handlerMap[mapK] = v	}	d, err := protoregistry.GlobalFiles.FindFileByPath(svcDesc.Metadata.(string))	if err != nil {		return err	}	if d.Services().Len() == 0 {		return nil	}	ser := d.Services().Get(0)	for j := 0; j < ser.Methods().Len(); j++ {		method := ser.Methods().Get(j)		t, h := serverAddToolsByMethod(serviceName, srv, method, handlerMap)		if t == nil || h == nil {			continue		}		s.AddTool(*t, h)	}	return nil}func serverAddToolsByMethod(serviceName string, srv any, method protoreflect.MethodDescriptor, handlerMap map[string]grpc.MethodDesc) (*mcp2.Tool, server.ToolHandlerFunc) {	methodMcpOpts, _ := proto.GetExtension(method.Options(), annotations2.E_Options).(*annotations2.Options)	if methodMcpOpts == nil || methodMcpOpts.McpOptions == nil || !methodMcpOpts.GetMcpOptions().Enabled {		return nil, nil	}	methodOperation, _ := proto.GetExtension(method.Options(), openapi_v3.E_Operation).(*openapi_v3.Operation)	description := ""	if methodOperation != nil {		description = methodOperation.Description		if description == "" {			description = methodOperation.Summary		}	}	toolOptions := []mcp2.ToolOption{mcp2.WithDescription(description)}	for k := 0; k < method.Input().Fields().Len(); k++ {		input := method.Input().Fields().Get(k)		inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)		inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)		inputDescription := ""		if inputOperation != nil {			inputDescription = inputOperation.GetDescription()		}		propertyOption := make([]mcp2.PropertyOption, 0)		if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {			propertyOption = append(propertyOption, mcp2.Required())		}		if input.IsList() {			propertyOption = append(propertyOption, mcp2.Items(getInputArrayItems(input)))			propertyOption = append(propertyOption, mcp2.Description(inputDescription))			toolOptions = append(toolOptions, mcp2.WithArray(string(input.Name()), propertyOption...))		} else if input.IsMap() {			additionalProperties, descriptionSuffix := getInputMapProperties(input)			inputDescription = inputDescription + descriptionSuffix			propertyOption = append(propertyOption, mcp2.Description(inputDescription))			propertyOption = append(propertyOption, mcp2.AdditionalProperties(additionalProperties))			toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))		} else {			propertyOption = append(propertyOption, mcp2.Description(inputDescription))			switch input.Kind() {			case protoreflect.StringKind:				toolOptions = append(toolOptions, mcp2.WithString(string(input.Name()), propertyOption...))			case protoreflect.BoolKind:				toolOptions = append(toolOptions, mcp2.WithBoolean(string(input.Name()), propertyOption...))			case protoreflect.DoubleKind, protoreflect.FloatKind,				protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,				protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,				protoreflect.Sint64Kind, protoreflect.Sint32Kind,				protoreflect.Uint64Kind, protoreflect.Uint32Kind,				protoreflect.Int64Kind, protoreflect.Int32Kind:				toolOptions = append(toolOptions, mcp2.WithNumber(string(input.Name()), propertyOption...))			case protoreflect.MessageKind:				propertyOption = append(propertyOption, mcp2.Properties(getFiledMessageParamProperties(input.Message(), false)))				toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))			}		}	}	toolName := serviceName + "_" + string(method.Name())	t := mcp2.NewTool(toolName, toolOptions...)	h := func(ctx context.Context, request mcp2.CallToolRequest) (*mcp2.CallToolResult, error) {		if _, ok := handlerMap[toolName]; !ok {			return nil, errors.New("没有实现")		}		arg := request.GetArguments()		argJson, _ := json.Marshal(arg)		dec := func(in any) error {			decErr := json.Unmarshal(argJson, &in)			if decErr != nil {				return decErr			}			return nil		}		interceptor := func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) {			return handler(ctx, req)		}		handler := handlerMap[toolName]		out, outErr := handler.Handler(srv, ctx, dec, interceptor)		if outErr != nil {			return nil, outErr		}		outJson, _ := json.Marshal(out)		callToolResult := &mcp2.CallToolResult{			Content: []mcp2.Content{				mcp2.TextContent{					Type: "text",					Text: string(outJson),				},			},		}		return callToolResult, nil	}	return &t, h}func getInputArrayItems(input protoreflect.FieldDescriptor) map[string]any {	inputMap := make(map[string]any)	switch input.Kind() {	case protoreflect.StringKind:		inputMap["type"] = "string"	case protoreflect.BoolKind:		inputMap["type"] = "boolean"	case protoreflect.DoubleKind, protoreflect.FloatKind,		protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,		protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,		protoreflect.Sint64Kind, protoreflect.Sint32Kind,		protoreflect.Uint64Kind, protoreflect.Uint32Kind,		protoreflect.Int64Kind, protoreflect.Int32Kind:		inputMap["type"] = "number"	case protoreflect.MessageKind:		inputMap["type"] = "object"		propertiesMap := getFiledMessageParamProperties(input.Message(), true)		inputMap["properties"] = propertiesMap	}	return inputMap}func getInputMapProperties(input protoreflect.FieldDescriptor) (map[string]any, string) {	messageParamMap := make(map[string]any)	paramMap := make(map[string]any)	switch input.MapValue().Kind() {	case protoreflect.StringKind:		paramMap["type"] = "string"	case protoreflect.BoolKind:		paramMap["type"] = "boolean"	case protoreflect.DoubleKind, protoreflect.FloatKind,		protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,		protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,		protoreflect.Sint64Kind, protoreflect.Sint32Kind,		protoreflect.Uint64Kind, protoreflect.Uint32Kind,		protoreflect.Int64Kind, protoreflect.Int32Kind:		paramMap["type"] = "number"	case protoreflect.MessageKind:		paramMap["type"] = "object"		paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false)	default:		break	}	var keyType string	switch input.MapKey().Kind() {	case protoreflect.StringKind:		keyType = "string"	case protoreflect.BoolKind:		keyType = "boolean"	case protoreflect.DoubleKind, protoreflect.FloatKind,		protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,		protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,		protoreflect.Sint64Kind, protoreflect.Sint32Kind,		protoreflect.Uint64Kind, protoreflect.Uint32Kind,		protoreflect.Int64Kind, protoreflect.Int32Kind:		keyType = "number"	}	descriptionSuffix := fmt.Sprintf("(type of key: %s, type of value: %s)", keyType, paramMap["type"])	messageParamMap[string(input.Name())] = paramMap	return messageParamMap, descriptionSuffix}func getFiledMessageParamProperties(message protoreflect.MessageDescriptor, needRequired bool) map[string]any {	messageParamMap := make(map[string]any)	for i := 0; i < message.Fields().Len(); i++ {		input := message.Fields().Get(i)		paramMap := make(map[string]any)		inputOperation, _ := proto.GetExtension(input.Options(), openapi_v3.E_Property).(*openapi_v3.Schema)		inputDescription := ""		if inputOperation != nil {			inputDescription = inputOperation.GetDescription()		}		paramMap["description"] = inputDescription		if needRequired {			inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)			if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {				paramMap["required"] = true			}		}		if input.IsList() {			paramMap["type"] = "array"			paramMap["items"] = getInputArrayItems(input)		} else if input.IsMap() {			paramMap["type"] = "object"			var descriptionSuffix string			paramMap["additionalProperties"], descriptionSuffix = getInputMapProperties(input)			paramMap["description"] = inputDescription + descriptionSuffix		} else {			switch input.Kind() {			case protoreflect.StringKind:				paramMap["type"] = "string"			case protoreflect.BoolKind:				paramMap["type"] = "boolean"			case protoreflect.DoubleKind, protoreflect.FloatKind,				protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,				protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,				protoreflect.Sint64Kind, protoreflect.Sint32Kind,				protoreflect.Uint64Kind, protoreflect.Uint32Kind,				protoreflect.Int64Kind, protoreflect.Int32Kind:				paramMap["type"] = "number"			case protoreflect.MessageKind:				paramMap["type"] = "object"				paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false)			default:				break			}		}		messageParamMap[string(input.Name())] = paramMap	}	return messageParamMap}
 |