Prechádzať zdrojové kódy

feat(mcp): 优化参数处理逻辑

- 增加对数组和字典类型参数的处理
- 优化参数描述和必填项标记的逻辑
- 更新 Kratos 依赖到 v2.8.4 版本
- 添加 golang.org/x/sync依赖
dcsunny 4 mesiacov pred
rodič
commit
c58ab57292
3 zmenil súbory, kde vykonal 130 pridanie a 31 odobranie
  1. 2 1
      go.mod
  2. 2 0
      go.sum
  3. 126 30
      mcp/tools.go

+ 2 - 1
go.mod

@@ -11,7 +11,7 @@ require (
 	github.com/dcsunny/gocrypt v0.0.0-20200828060317-4dec5212cc15
 	github.com/dcsunny/mwt v0.0.0-20210128034911-2f50006077f5
 	github.com/dgrijalva/jwt-go v3.2.0+incompatible
-	github.com/go-kratos/kratos/v2 v2.8.3
+	github.com/go-kratos/kratos/v2 v2.8.4
 	github.com/go-resty/resty/v2 v2.7.0
 	github.com/google/gnostic v0.7.0
 	github.com/google/uuid v1.6.0
@@ -56,6 +56,7 @@ require (
 	go.uber.org/zap v1.21.0 // indirect
 	golang.org/x/crypto v0.32.0 // indirect
 	golang.org/x/net v0.34.0 // indirect
+	golang.org/x/sync v0.10.0 // indirect
 	golang.org/x/sys v0.29.0 // indirect
 	golang.org/x/text v0.21.0 // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20240102182953-50ed04b92917 // indirect

+ 2 - 0
go.sum

@@ -699,6 +699,8 @@ github.com/go-kratos/aegis v0.2.0 h1:dObzCDWn3XVjUkgxyBp6ZeWtx/do0DPZ7LY3yNSJLUQ
 github.com/go-kratos/aegis v0.2.0/go.mod h1:v0R2m73WgEEYB3XYu6aE2WcMwsZkJ/Rzuf5eVccm7bI=
 github.com/go-kratos/kratos/v2 v2.8.3 h1:kkNBq0gvdX+b8cbaN+p6Sdh95DgMhx7GimefXb4o7Ss=
 github.com/go-kratos/kratos/v2 v2.8.3/go.mod h1:+Vfe3FzF0d+BfMdajA11jT0rAyJWublRE/seZQNZVxE=
+github.com/go-kratos/kratos/v2 v2.8.4 h1:eIJLE9Qq9WSoKx+Buy2uPyrahtF/lPh+Xf4MTpxhmjs=
+github.com/go-kratos/kratos/v2 v2.8.4/go.mod h1:mq62W2101a5uYyRxe+7IdWubu7gZCGYqSNKwGFiiRcw=
 github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
 github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk=
 github.com/go-pdf/fpdf v0.5.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M=

+ 126 - 30
mcp/tools.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"strings"
 
 	annotations2 "git.ikuban.com/server/kubanapis/kuban/api/annotations"
@@ -70,25 +71,39 @@ func serverAddToolsByMethod(serviceName string, srv any, method protoreflect.Met
 		if inputOperation != nil {
 			inputDescription = inputOperation.GetDescription()
 		}
-		propertyOption := []mcp2.PropertyOption{mcp2.Description(inputDescription)}
+		propertyOption := make([]mcp2.PropertyOption, 0)
 		if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {
 			propertyOption = append(propertyOption, mcp2.Required())
 		}
-		switch input.Kind() {
-		case protoreflect.StringKind:
-			toolOptions = append(toolOptions, mcp2.WithString(string(input.Name()), propertyOption...))
-		case protoreflect.BoolKind:
-			toolOptions = append(toolOptions, mcp2.WithBoolean(string(input.Name()), propertyOption...))
-		case protoreflect.DoubleKind, protoreflect.FloatKind,
-			protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
-			protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
-			protoreflect.Sint64Kind, protoreflect.Sint32Kind,
-			protoreflect.Uint64Kind, protoreflect.Uint32Kind,
-			protoreflect.Int64Kind, protoreflect.Int32Kind:
-			toolOptions = append(toolOptions, mcp2.WithNumber(string(input.Name()), propertyOption...))
-		case protoreflect.MessageKind:
-			propertyOption = append(propertyOption, mcp2.Properties(getFiledMessageParamProperties(input.Message())))
+
+		if input.IsList() {
+			propertyOption = append(propertyOption, mcp2.Items(getInputArrayItems(input)))
+			propertyOption = append(propertyOption, mcp2.Description(inputDescription))
+			toolOptions = append(toolOptions, mcp2.WithArray(string(input.Name()), propertyOption...))
+		} else if input.IsMap() {
+			additionalProperties, descriptionSuffix := getInputMapProperties(input)
+			inputDescription = inputDescription + descriptionSuffix
+			propertyOption = append(propertyOption, mcp2.Description(inputDescription))
+			propertyOption = append(propertyOption, mcp2.AdditionalProperties(additionalProperties))
 			toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))
+		} else {
+			propertyOption = append(propertyOption, mcp2.Description(inputDescription))
+			switch input.Kind() {
+			case protoreflect.StringKind:
+				toolOptions = append(toolOptions, mcp2.WithString(string(input.Name()), propertyOption...))
+			case protoreflect.BoolKind:
+				toolOptions = append(toolOptions, mcp2.WithBoolean(string(input.Name()), propertyOption...))
+			case protoreflect.DoubleKind, protoreflect.FloatKind,
+				protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
+				protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
+				protoreflect.Sint64Kind, protoreflect.Sint32Kind,
+				protoreflect.Uint64Kind, protoreflect.Uint32Kind,
+				protoreflect.Int64Kind, protoreflect.Int32Kind:
+				toolOptions = append(toolOptions, mcp2.WithNumber(string(input.Name()), propertyOption...))
+			case protoreflect.MessageKind:
+				propertyOption = append(propertyOption, mcp2.Properties(getFiledMessageParamProperties(input.Message(), false)))
+				toolOptions = append(toolOptions, mcp2.WithObject(string(input.Name()), propertyOption...))
+			}
 		}
 	}
 	toolName := serviceName + "_" + string(method.Name())
@@ -133,7 +148,69 @@ func serverAddToolsByMethod(serviceName string, srv any, method protoreflect.Met
 	return &t, h
 }
 
-func getFiledMessageParamProperties(message protoreflect.MessageDescriptor) map[string]any {
+func getInputArrayItems(input protoreflect.FieldDescriptor) map[string]any {
+	inputMap := make(map[string]any)
+	switch input.Kind() {
+	case protoreflect.StringKind:
+		inputMap["type"] = "string"
+	case protoreflect.BoolKind:
+		inputMap["type"] = "boolean"
+	case protoreflect.DoubleKind, protoreflect.FloatKind,
+		protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
+		protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
+		protoreflect.Sint64Kind, protoreflect.Sint32Kind,
+		protoreflect.Uint64Kind, protoreflect.Uint32Kind,
+		protoreflect.Int64Kind, protoreflect.Int32Kind:
+		inputMap["type"] = "number"
+	case protoreflect.MessageKind:
+		inputMap["type"] = "object"
+		propertiesMap := getFiledMessageParamProperties(input.Message(), true)
+		inputMap["properties"] = propertiesMap
+	}
+	return inputMap
+}
+
+func getInputMapProperties(input protoreflect.FieldDescriptor) (map[string]any, string) {
+	messageParamMap := make(map[string]any)
+	paramMap := make(map[string]any)
+	switch input.MapValue().Kind() {
+	case protoreflect.StringKind:
+		paramMap["type"] = "string"
+	case protoreflect.BoolKind:
+		paramMap["type"] = "boolean"
+	case protoreflect.DoubleKind, protoreflect.FloatKind,
+		protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
+		protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
+		protoreflect.Sint64Kind, protoreflect.Sint32Kind,
+		protoreflect.Uint64Kind, protoreflect.Uint32Kind,
+		protoreflect.Int64Kind, protoreflect.Int32Kind:
+		paramMap["type"] = "number"
+	case protoreflect.MessageKind:
+		paramMap["type"] = "object"
+		paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false)
+	default:
+		break
+	}
+	var keyType string
+	switch input.MapKey().Kind() {
+	case protoreflect.StringKind:
+		keyType = "string"
+	case protoreflect.BoolKind:
+		keyType = "boolean"
+	case protoreflect.DoubleKind, protoreflect.FloatKind,
+		protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
+		protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
+		protoreflect.Sint64Kind, protoreflect.Sint32Kind,
+		protoreflect.Uint64Kind, protoreflect.Uint32Kind,
+		protoreflect.Int64Kind, protoreflect.Int32Kind:
+		keyType = "number"
+	}
+	descriptionSuffix := fmt.Sprintf("(type of key: %s, type of value: %s)", keyType, paramMap["type"])
+	messageParamMap[string(input.Name())] = paramMap
+	return messageParamMap, descriptionSuffix
+}
+
+func getFiledMessageParamProperties(message protoreflect.MessageDescriptor, needRequired bool) map[string]any {
 
 	messageParamMap := make(map[string]any)
 
@@ -146,21 +223,40 @@ func getFiledMessageParamProperties(message protoreflect.MessageDescriptor) map[
 			inputDescription = inputOperation.GetDescription()
 		}
 		paramMap["description"] = inputDescription
+		if needRequired {
+			inputOperation2, _ := proto.GetExtension(input.Options(), annotations.E_FieldBehavior).([]annotations.FieldBehavior)
+			if inputOperation2 != nil && len(inputOperation2) > 0 && inputOperation2[0] == annotations.FieldBehavior_REQUIRED {
+				paramMap["required"] = true
+			}
+		}
 
-		switch input.Kind() {
-		case protoreflect.StringKind:
-			paramMap["type"] = "string"
-		case protoreflect.BoolKind:
-			paramMap["type"] = "boolean"
-		case protoreflect.DoubleKind, protoreflect.FloatKind,
-			protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
-			protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
-			protoreflect.Sint64Kind, protoreflect.Sint32Kind,
-			protoreflect.Uint64Kind, protoreflect.Uint32Kind,
-			protoreflect.Int64Kind, protoreflect.Int32Kind:
-			paramMap["type"] = "number"
-		default:
-			break
+		if input.IsList() {
+			paramMap["type"] = "array"
+			paramMap["items"] = getInputArrayItems(input)
+		} else if input.IsMap() {
+			paramMap["type"] = "object"
+			var descriptionSuffix string
+			paramMap["additionalProperties"], descriptionSuffix = getInputMapProperties(input)
+			paramMap["description"] = inputDescription + descriptionSuffix
+		} else {
+			switch input.Kind() {
+			case protoreflect.StringKind:
+				paramMap["type"] = "string"
+			case protoreflect.BoolKind:
+				paramMap["type"] = "boolean"
+			case protoreflect.DoubleKind, protoreflect.FloatKind,
+				protoreflect.Sfixed64Kind, protoreflect.Sfixed32Kind,
+				protoreflect.Fixed64Kind, protoreflect.Fixed32Kind,
+				protoreflect.Sint64Kind, protoreflect.Sint32Kind,
+				protoreflect.Uint64Kind, protoreflect.Uint32Kind,
+				protoreflect.Int64Kind, protoreflect.Int32Kind:
+				paramMap["type"] = "number"
+			case protoreflect.MessageKind:
+				paramMap["type"] = "object"
+				paramMap["properties"] = getFiledMessageParamProperties(input.Message(), false)
+			default:
+				break
+			}
 		}
 		messageParamMap[string(input.Name())] = paramMap
 	}