||
- // Copyright 2020 Google LLC. All Rights Reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- //
- package generator
- import (
- "fmt"
- "log"
- "net/url"
- "regexp"
- "sort"
- "strings"
- http2 "net/http"
- "google.golang.org/protobuf/types/descriptorpb"
- "google.golang.org/genproto/googleapis/api/annotations"
- status_pb "google.golang.org/genproto/googleapis/rpc/status"
- "google.golang.org/protobuf/compiler/protogen"
- "google.golang.org/protobuf/proto"
- "google.golang.org/protobuf/reflect/protoreflect"
- any_pb "google.golang.org/protobuf/types/known/anypb"
- "git.ikuban.com/server/kratos-utils/v2/transport/http/handler"
- wk "git.ikuban.com/server/swagger-api/v2/generator/wellknown"
- v3 "github.com/google/gnostic/openapiv3"
- )
// Configuration carries the generator options. All fields are pointers so
// callers (e.g. flag parsing) can hand them over directly.
type Configuration struct {
	// Version, Title and Description seed the document's Info section.
	Version, Title, Description *string
	// Naming is not read in the code shown here; presumably it selects the
	// field-naming style used by the reflector (NewOpenAPIv3Reflector receives
	// the whole Configuration) — confirm.
	Naming *string
	// FQSchemaNaming is not read in the code shown here; presumably it toggles
	// fully-qualified schema names in the reflector — confirm.
	FQSchemaNaming *bool
	// EnumType is not read in the code shown here; presumably it chooses how
	// enums are rendered by the reflector — confirm.
	EnumType *string
	// CircularDepth bounds how many times a sub-field is expanded when message
	// fields are flattened into query parameters.
	CircularDepth *int
	// DefaultResponse adds a `default` response referencing google.rpc.Status
	// to every operation when true.
	DefaultResponse *bool
	// OutputMode is not read in the code shown here — confirm its consumer.
	OutputMode *string
	// PreserveInfo, when true, resets Info to the configured Version, Title
	// and Description after annotations and single-tag defaults are applied.
	PreserveInfo *bool
}
// infoURL identifies this generator; it is appended to the header comment of
// every generated YAML document.
const infoURL = "git.ikuban.com/server/swagger-api"
- // In order to dynamically add google.rpc.Status responses we need
- // to know the message descriptors for google.rpc.Status as well
- // as google.protobuf.Any.
- var statusProtoDesc = (&status_pb.Status{}).ProtoReflect().Descriptor()
- var anyProtoDesc = (&any_pb.Any{}).ProtoReflect().Descriptor()
- // OpenAPIv3Generator holds internal state needed to generate an OpenAPIv3 document for a transcoded Protocol Buffer service.
- type OpenAPIv3Generator struct {
- conf Configuration
- plugin *protogen.Plugin
- inputFiles []*protogen.File
- reflect *OpenAPIv3Reflector
- generatedSchemas []string // Names of schemas that have already been generated.
- linterRulePattern *regexp.Regexp
- pathPattern *regexp.Regexp
- namedPathPattern *regexp.Regexp
- }
- // NewOpenAPIv3Generator creates a new generator for a protoc plugin invocation.
- func NewOpenAPIv3Generator(plugin *protogen.Plugin, conf Configuration, inputFiles []*protogen.File) *OpenAPIv3Generator {
- return &OpenAPIv3Generator{
- conf: conf,
- plugin: plugin,
- inputFiles: inputFiles,
- reflect: NewOpenAPIv3Reflector(conf),
- generatedSchemas: make([]string, 0),
- linterRulePattern: regexp.MustCompile(`\(-- .* --\)`),
- pathPattern: regexp.MustCompile("{([^=}]+)}"),
- namedPathPattern: regexp.MustCompile("{(.+)=(.+)}"),
- }
- }
- // Run runs the generator.
- func (g *OpenAPIv3Generator) Run(outputFile *protogen.GeneratedFile) error {
- d := g.buildDocumentV3()
- bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
- if err != nil {
- return fmt.Errorf("failed to marshal yaml: %s", err.Error())
- }
- if _, err = outputFile.Write(bytes); err != nil {
- return fmt.Errorf("failed to write yaml: %s", err.Error())
- }
- return nil
- }
- func (g *OpenAPIv3Generator) RunV2() ([]byte, error) {
- d := g.buildDocumentV3()
- bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
- if err != nil {
- return bytes, fmt.Errorf("failed to marshal yaml: %s", err.Error())
- }
- return bytes, nil
- }
- // buildDocumentV3 builds an OpenAPIv3 document for a plugin request.
- func (g *OpenAPIv3Generator) buildDocumentV3() *v3.Document {
- d := &v3.Document{}
- d.Openapi = "3.0.3"
- d.Info = &v3.Info{
- Version: *g.conf.Version,
- Title: *g.conf.Title,
- Description: *g.conf.Description,
- }
- d.Paths = &v3.Paths{}
- d.Components = &v3.Components{
- Schemas: &v3.SchemasOrReferences{
- AdditionalProperties: []*v3.NamedSchemaOrReference{},
- },
- }
- // Go through the files and add the services to the documents, keeping
- // track of which schemas are referenced in the response so we can
- // add them later.
- for _, file := range g.inputFiles {
- if file.Generate {
- // Merge any `Document` annotations with the current
- extDocument := proto.GetExtension(file.Desc.Options(), v3.E_Document)
- if extDocument != nil {
- proto.Merge(d, extDocument.(*v3.Document))
- }
- g.addPathsToDocumentV3(d, file.Services)
- }
- }
- // While we have required schemas left to generate, go through the files again
- // looking for the related message and adding them to the document if required.
- for len(g.reflect.requiredSchemas) > 0 {
- count := len(g.reflect.requiredSchemas)
- for _, file := range g.plugin.Files {
- g.addSchemasForMessagesToDocumentV3(d, file.Messages, file.Proto.GetEdition())
- }
- g.reflect.requiredSchemas = g.reflect.requiredSchemas[count:len(g.reflect.requiredSchemas)]
- }
- // If there is only 1 service, then use it's title for the
- // document, if the document is missing it.
- if len(d.Tags) == 1 {
- if d.Info.Title == "" && d.Tags[0].Name != "" {
- d.Info.Title = d.Tags[0].Name + " API"
- }
- if d.Info.Description == "" {
- d.Info.Description = d.Tags[0].Description
- }
- d.Tags[0].Description = ""
- }
- if g.conf.PreserveInfo != nil && *g.conf.PreserveInfo {
- d.Info = &v3.Info{
- Version: *g.conf.Version,
- Title: *g.conf.Title,
- Description: *g.conf.Description,
- }
- }
- allServers := []string{}
- // If paths methods has servers, but they're all the same, then move servers to path level
- for _, path := range d.Paths.Path {
- servers := []string{}
- // Only 1 server will ever be set, per method, by the generator
- if path.Value.Get != nil && len(path.Value.Get.Servers) == 1 {
- servers = appendUnique(servers, path.Value.Get.Servers[0].Url)
- allServers = appendUnique(allServers, path.Value.Get.Servers[0].Url)
- }
- if path.Value.Post != nil && len(path.Value.Post.Servers) == 1 {
- servers = appendUnique(servers, path.Value.Post.Servers[0].Url)
- allServers = appendUnique(allServers, path.Value.Post.Servers[0].Url)
- }
- if path.Value.Put != nil && len(path.Value.Put.Servers) == 1 {
- servers = appendUnique(servers, path.Value.Put.Servers[0].Url)
- allServers = appendUnique(allServers, path.Value.Put.Servers[0].Url)
- }
- if path.Value.Delete != nil && len(path.Value.Delete.Servers) == 1 {
- servers = appendUnique(servers, path.Value.Delete.Servers[0].Url)
- allServers = appendUnique(allServers, path.Value.Delete.Servers[0].Url)
- }
- if path.Value.Patch != nil && len(path.Value.Patch.Servers) == 1 {
- servers = appendUnique(servers, path.Value.Patch.Servers[0].Url)
- allServers = appendUnique(allServers, path.Value.Patch.Servers[0].Url)
- }
- if path.Value.Head != nil && len(path.Value.Head.Servers) == 1 {
- servers = appendUnique(servers, path.Value.Head.Servers[0].Url)
- allServers = appendUnique(allServers, path.Value.Head.Servers[0].Url)
- }
- if path.Value.Options != nil && len(path.Value.Options.Servers) == 1 {
- servers = appendUnique(servers, path.Value.Options.Servers[0].Url)
- allServers = appendUnique(allServers, path.Value.Options.Servers[0].Url)
- }
- if path.Value.Trace != nil && len(path.Value.Trace.Servers) == 1 {
- servers = appendUnique(servers, path.Value.Trace.Servers[0].Url)
- allServers = appendUnique(allServers, path.Value.Trace.Servers[0].Url)
- }
- if len(servers) == 1 {
- path.Value.Servers = []*v3.Server{{Url: servers[0]}}
- if path.Value.Get != nil {
- path.Value.Get.Servers = nil
- }
- if path.Value.Post != nil {
- path.Value.Post.Servers = nil
- }
- if path.Value.Put != nil {
- path.Value.Put.Servers = nil
- }
- if path.Value.Delete != nil {
- path.Value.Delete.Servers = nil
- }
- if path.Value.Patch != nil {
- path.Value.Patch.Servers = nil
- }
- if path.Value.Head != nil {
- path.Value.Head.Servers = nil
- }
- if path.Value.Options != nil {
- path.Value.Options.Servers = nil
- }
- if path.Value.Trace != nil {
- path.Value.Trace.Servers = nil
- }
- }
- }
- // Set all servers on API level
- if len(allServers) > 0 {
- d.Servers = []*v3.Server{}
- for _, server := range allServers {
- d.Servers = append(d.Servers, &v3.Server{Url: server})
- }
- }
- // If there is only 1 server, we can safely remove all path level servers
- if len(allServers) == 1 {
- for _, path := range d.Paths.Path {
- path.Value.Servers = nil
- }
- }
- // Sort the tags.
- {
- pairs := d.Tags
- sort.Slice(pairs, func(i, j int) bool {
- return pairs[i].Name < pairs[j].Name
- })
- d.Tags = pairs
- }
- // Sort the paths.
- {
- pairs := d.Paths.Path
- sort.Slice(pairs, func(i, j int) bool {
- return pairs[i].Name < pairs[j].Name
- })
- d.Paths.Path = pairs
- }
- // Sort the schemas.
- {
- pairs := d.Components.Schemas.AdditionalProperties
- sort.Slice(pairs, func(i, j int) bool {
- return pairs[i].Name < pairs[j].Name
- })
- d.Components.Schemas.AdditionalProperties = pairs
- }
- // Deduplicate and sort security schemes.
- if d.Components.SecuritySchemes != nil && d.Components.SecuritySchemes.AdditionalProperties != nil {
- seen := make(map[string]bool)
- unique := make([]*v3.NamedSecuritySchemeOrReference, 0)
- for _, scheme := range d.Components.SecuritySchemes.AdditionalProperties {
- if !seen[scheme.Name] {
- seen[scheme.Name] = true
- unique = append(unique, scheme)
- }
- }
- sort.Slice(unique, func(i, j int) bool {
- return unique[i].Name < unique[j].Name
- })
- d.Components.SecuritySchemes.AdditionalProperties = unique
- }
- // Deduplicate and sort responses.
- if d.Components.Responses != nil && d.Components.Responses.AdditionalProperties != nil {
- seen := make(map[string]bool)
- unique := make([]*v3.NamedResponseOrReference, 0)
- for _, resp := range d.Components.Responses.AdditionalProperties {
- if !seen[resp.Name] {
- seen[resp.Name] = true
- unique = append(unique, resp)
- }
- }
- sort.Slice(unique, func(i, j int) bool {
- return unique[i].Name < unique[j].Name
- })
- d.Components.Responses.AdditionalProperties = unique
- }
- // Deduplicate and sort parameters.
- if d.Components.Parameters != nil && d.Components.Parameters.AdditionalProperties != nil {
- seen := make(map[string]bool)
- unique := make([]*v3.NamedParameterOrReference, 0)
- for _, param := range d.Components.Parameters.AdditionalProperties {
- if !seen[param.Name] {
- seen[param.Name] = true
- unique = append(unique, param)
- }
- }
- sort.Slice(unique, func(i, j int) bool {
- return unique[i].Name < unique[j].Name
- })
- d.Components.Parameters.AdditionalProperties = unique
- }
- // Deduplicate and sort request bodies.
- if d.Components.RequestBodies != nil && d.Components.RequestBodies.AdditionalProperties != nil {
- seen := make(map[string]bool)
- unique := make([]*v3.NamedRequestBodyOrReference, 0)
- for _, body := range d.Components.RequestBodies.AdditionalProperties {
- if !seen[body.Name] {
- seen[body.Name] = true
- unique = append(unique, body)
- }
- }
- sort.Slice(unique, func(i, j int) bool {
- return unique[i].Name < unique[j].Name
- })
- d.Components.RequestBodies.AdditionalProperties = unique
- }
- // Deduplicate and sort headers.
- if d.Components.Headers != nil && d.Components.Headers.AdditionalProperties != nil {
- seen := make(map[string]bool)
- unique := make([]*v3.NamedHeaderOrReference, 0)
- for _, header := range d.Components.Headers.AdditionalProperties {
- if !seen[header.Name] {
- seen[header.Name] = true
- unique = append(unique, header)
- }
- }
- sort.Slice(unique, func(i, j int) bool {
- return unique[i].Name < unique[j].Name
- })
- d.Components.Headers.AdditionalProperties = unique
- }
- // Deduplicate servers by URL.
- if d.Servers != nil && len(d.Servers) > 0 {
- seen := make(map[string]bool)
- unique := make([]*v3.Server, 0)
- for _, server := range d.Servers {
- if !seen[server.Url] {
- seen[server.Url] = true
- unique = append(unique, server)
- }
- }
- d.Servers = unique
- }
- return d
- }
- // filterCommentString removes linter rules from comments.
- func (g *OpenAPIv3Generator) filterCommentString(c protogen.Comments) string {
- comment := g.linterRulePattern.ReplaceAllString(string(c), "")
- return strings.TrimSpace(comment)
- }
- func (g *OpenAPIv3Generator) findField(name string, inMessage *protogen.Message) *protogen.Field {
- for _, field := range inMessage.Fields {
- if string(field.Desc.Name()) == name || string(field.Desc.JSONName()) == name {
- return field
- }
- }
- return nil
- }
- func (g *OpenAPIv3Generator) findAndFormatFieldName(name string, inMessage *protogen.Message) string {
- field := g.findField(name, inMessage)
- if field != nil {
- return g.reflect.formatFieldName(field.Desc)
- }
- return name
- }
- // Note that fields which are mapped to URL query parameters must have a primitive type
- // or a repeated primitive type or a non-repeated message type.
- // In the case of a repeated type, the parameter can be repeated in the URL as ...?param=A¶m=B.
- // In the case of a message type, each field of the message is mapped to a separate parameter,
- // such as ...?foo.a=A&foo.b=B&foo.c=C.
- // There are exceptions:
- // - for wrapper types it will use the same representation as the wrapped primitive type in JSON
- // - for google.protobuf.timestamp type it will be serialized as a string
- //
- // maps, Struct and Empty can NOT be used
- // messages can have any number of sub messages - including circular (e.g. sub.subsub.sub.subsub.id)
- // buildQueryParamsV3 extracts any valid query params, including sub and recursive messages
- func (g *OpenAPIv3Generator) buildQueryParamsV3(field *protogen.Field) []*v3.ParameterOrReference {
- depths := map[string]int{}
- return g._buildQueryParamsV3(field, depths)
- }
- // depths are used to keep track of how many times a message's fields has been seen
- func (g *OpenAPIv3Generator) _buildQueryParamsV3(field *protogen.Field, depths map[string]int) []*v3.ParameterOrReference {
- parameters := []*v3.ParameterOrReference{}
- queryFieldName := g.reflect.formatFieldName(field.Desc)
- fieldDescription := g.filterCommentString(field.Comments.Leading)
- if field.Desc.IsMap() {
- // Map types are not allowed in query parameteres
- return parameters
- } else if field.Desc.Kind() == protoreflect.MessageKind {
- typeName := g.reflect.fullMessageTypeName(field.Desc.Message())
- switch typeName {
- case ".google.protobuf.Value":
- fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
- parameters = append(parameters,
- &v3.ParameterOrReference{
- Oneof: &v3.ParameterOrReference_Parameter{
- Parameter: &v3.Parameter{
- Name: queryFieldName,
- In: "query",
- Description: fieldDescription,
- Required: false,
- Schema: fieldSchema,
- },
- },
- })
- return parameters
- case ".google.protobuf.BoolValue", ".google.protobuf.BytesValue", ".google.protobuf.Int32Value", ".google.protobuf.UInt32Value",
- ".google.protobuf.StringValue", ".google.protobuf.Int64Value", ".google.protobuf.UInt64Value", ".google.protobuf.FloatValue",
- ".google.protobuf.DoubleValue":
- valueField := getValueField(field.Message.Desc)
- fieldSchema := g.reflect.schemaOrReferenceForField(valueField)
- parameters = append(parameters,
- &v3.ParameterOrReference{
- Oneof: &v3.ParameterOrReference_Parameter{
- Parameter: &v3.Parameter{
- Name: queryFieldName,
- In: "query",
- Description: fieldDescription,
- Required: false,
- Schema: fieldSchema,
- },
- },
- })
- return parameters
- case ".google.protobuf.Timestamp":
- fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
- parameters = append(parameters,
- &v3.ParameterOrReference{
- Oneof: &v3.ParameterOrReference_Parameter{
- Parameter: &v3.Parameter{
- Name: queryFieldName,
- In: "query",
- Description: fieldDescription,
- Required: false,
- Schema: fieldSchema,
- },
- },
- })
- return parameters
- case ".google.protobuf.Duration":
- fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
- parameters = append(parameters,
- &v3.ParameterOrReference{
- Oneof: &v3.ParameterOrReference_Parameter{
- Parameter: &v3.Parameter{
- Name: queryFieldName,
- In: "query",
- Description: fieldDescription,
- Required: false,
- Schema: fieldSchema,
- },
- },
- })
- return parameters
- }
- if field.Desc.IsList() {
- // Only non-repeated message types are valid
- return parameters
- }
- // Represent field masks directly as strings (don't expand them).
- if typeName == ".google.protobuf.FieldMask" {
- fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
- parameters = append(parameters,
- &v3.ParameterOrReference{
- Oneof: &v3.ParameterOrReference_Parameter{
- Parameter: &v3.Parameter{
- Name: queryFieldName,
- In: "query",
- Description: fieldDescription,
- Required: false,
- Schema: fieldSchema,
- },
- },
- })
- return parameters
- }
- // Sub messages are allowed, even circular, as long as the final type is a primitive.
- // Go through each of the sub message fields
- for _, subField := range field.Message.Fields {
- subFieldFullName := string(subField.Desc.FullName())
- seen, ok := depths[subFieldFullName]
- if !ok {
- depths[subFieldFullName] = 0
- }
- if seen < *g.conf.CircularDepth {
- depths[subFieldFullName]++
- subParams := g._buildQueryParamsV3(subField, depths)
- for _, subParam := range subParams {
- if param, ok := subParam.Oneof.(*v3.ParameterOrReference_Parameter); ok {
- param.Parameter.Name = queryFieldName + "." + param.Parameter.Name
- parameters = append(parameters, subParam)
- }
- }
- }
- }
- } else if field.Desc.Kind() != protoreflect.GroupKind {
- // schemaOrReferenceForField also handles array types
- fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
- parameters = append(parameters,
- &v3.ParameterOrReference{
- Oneof: &v3.ParameterOrReference_Parameter{
- Parameter: &v3.Parameter{
- Name: queryFieldName,
- In: "query",
- Description: fieldDescription,
- Required: false,
- Schema: fieldSchema,
- },
- },
- })
- }
- return parameters
- }
- // buildOperationV3 constructs an operation for a set of values.
- func (g *OpenAPIv3Generator) buildOperationV3(
- d *v3.Document,
- operationID string,
- tagName string,
- description string,
- defaultHost string,
- path string,
- bodyField string,
- inputMessage *protogen.Message,
- outputMessage *protogen.Message,
- ) (*v3.Operation, string) {
- // coveredParameters tracks the parameters that have been used in the body or path.
- coveredParameters := make([]string, 0)
- if bodyField != "" {
- coveredParameters = append(coveredParameters, bodyField)
- }
- // Initialize the list of operation parameters.
- parameters := []*v3.ParameterOrReference{}
- // Find simple path parameters like {id}
- if allMatches := g.pathPattern.FindAllStringSubmatch(path, -1); allMatches != nil {
- for _, matches := range allMatches {
- // Add the value to the list of covered parameters.
- coveredParameters = append(coveredParameters, matches[1])
- pathParameter := g.findAndFormatFieldName(matches[1], inputMessage)
- path = strings.Replace(path, matches[1], pathParameter, 1)
- // Add the path parameters to the operation parameters.
- var fieldSchema *v3.SchemaOrReference
- var fieldDescription string
- field := g.findField(pathParameter, inputMessage)
- if field != nil {
- fieldSchema = g.reflect.schemaOrReferenceForField(field.Desc)
- fieldDescription = g.filterCommentString(field.Comments.Leading)
- } else {
- // If field does not exist, it is safe to set it to string, as it is ignored downstream
- fieldSchema = &v3.SchemaOrReference{
- Oneof: &v3.SchemaOrReference_Schema{
- Schema: &v3.Schema{
- Type: "string",
- },
- },
- }
- }
- parameters = append(parameters,
- &v3.ParameterOrReference{
- Oneof: &v3.ParameterOrReference_Parameter{
- Parameter: &v3.Parameter{
- Name: pathParameter,
- In: "path",
- Description: fieldDescription,
- Required: true,
- Schema: fieldSchema,
- },
- },
- })
- }
- }
- // Find named path parameters like {name=shelves/*}
- if matches := g.namedPathPattern.FindStringSubmatch(path); matches != nil {
- // Build a list of named path parameters.
- namedPathParameters := make([]string, 0)
- // Add the "name=" "name" value to the list of covered parameters.
- coveredParameters = append(coveredParameters, matches[1])
- // Convert the path from the starred form to use named path parameters.
- starredPath := matches[2]
- parts := strings.Split(starredPath, "/")
- // The starred path is assumed to be in the form "things/*/otherthings/*".
- // We want to convert it to "things/{thingsId}/otherthings/{otherthingsId}".
- for i := 0; i < len(parts)-1; i += 2 {
- section := parts[i]
- namedPathParameter := g.findAndFormatFieldName(section, inputMessage)
- namedPathParameter = singular(namedPathParameter)
- parts[i+1] = "{" + namedPathParameter + "}"
- namedPathParameters = append(namedPathParameters, namedPathParameter)
- }
- // Rewrite the path to use the path parameters.
- newPath := strings.Join(parts, "/")
- path = strings.Replace(path, matches[0], newPath, 1)
- // Add the named path parameters to the operation parameters.
- for _, namedPathParameter := range namedPathParameters {
- parameters = append(parameters,
- &v3.ParameterOrReference{
- Oneof: &v3.ParameterOrReference_Parameter{
- Parameter: &v3.Parameter{
- Name: namedPathParameter,
- In: "path",
- Required: true,
- Description: "The " + namedPathParameter + " id.",
- Schema: &v3.SchemaOrReference{
- Oneof: &v3.SchemaOrReference_Schema{
- Schema: &v3.Schema{
- Type: "string",
- },
- },
- },
- },
- },
- })
- }
- }
- // Add any unhandled fields in the request message as query parameters.
- if bodyField != "*" && string(inputMessage.Desc.FullName()) != "google.api.HttpBody" {
- for _, field := range inputMessage.Fields {
- fieldName := string(field.Desc.Name())
- if !contains(coveredParameters, fieldName) && fieldName != bodyField {
- fieldParams := g.buildQueryParamsV3(field)
- parameters = append(parameters, fieldParams...)
- }
- }
- }
- // Create the response.
- name, content := g.reflect.responseContentForMessage(outputMessage.Desc)
- responses := &v3.Responses{
- ResponseOrReference: []*v3.NamedResponseOrReference{
- {
- Name: name,
- Value: &v3.ResponseOrReference{
- Oneof: &v3.ResponseOrReference_Response{
- Response: &v3.Response{
- Description: "OK",
- Content: content,
- },
- },
- },
- },
- },
- }
- // Add the default reponse if needed
- if *g.conf.DefaultResponse {
- anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
- anySchema := wk.NewGoogleProtobufAnySchema(anySchemaName)
- g.addSchemaToDocumentV3(d, anySchema)
- statusSchemaName := g.reflect.formatMessageName(statusProtoDesc)
- statusSchema := wk.NewGoogleRpcStatusSchema(statusSchemaName, anySchemaName)
- g.addSchemaToDocumentV3(d, statusSchema)
- defaultResponse := &v3.NamedResponseOrReference{
- Name: "default",
- Value: &v3.ResponseOrReference{
- Oneof: &v3.ResponseOrReference_Response{
- Response: &v3.Response{
- Description: "Default error response",
- Content: wk.NewApplicationJsonMediaType(&v3.SchemaOrReference{
- Oneof: &v3.SchemaOrReference_Reference{
- Reference: &v3.Reference{XRef: "#/components/schemas/" + statusSchemaName}}}),
- },
- },
- },
- }
- responses.ResponseOrReference = append(responses.ResponseOrReference, defaultResponse)
- }
- // Create the operation.
- op := &v3.Operation{
- Tags: []string{tagName},
- Description: description,
- OperationId: operationID,
- Parameters: parameters,
- Responses: responses,
- }
- if defaultHost != "" {
- hostURL, err := url.Parse(defaultHost)
- if err == nil {
- hostURL.Scheme = "https"
- op.Servers = append(op.Servers, &v3.Server{Url: hostURL.String()})
- }
- }
- // If a body field is specified, we need to pass a message as the request body.
- if bodyField != "" {
- var requestSchema *v3.SchemaOrReference
- if bodyField == "*" {
- // Pass the entire request message as the request body.
- requestSchema = g.reflect.schemaOrReferenceForMessage(inputMessage.Desc)
- } else {
- // If body refers to a message field, use that type.
- for _, field := range inputMessage.Fields {
- if string(field.Desc.Name()) == bodyField {
- switch field.Desc.Kind() {
- case protoreflect.StringKind:
- requestSchema = &v3.SchemaOrReference{
- Oneof: &v3.SchemaOrReference_Schema{
- Schema: &v3.Schema{
- Type: "string",
- },
- },
- }
- case protoreflect.MessageKind:
- requestSchema = g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
- default:
- log.Printf("unsupported field type %+v", field.Desc)
- }
- break
- }
- }
- }
- op.RequestBody = &v3.RequestBodyOrReference{
- Oneof: &v3.RequestBodyOrReference_RequestBody{
- RequestBody: &v3.RequestBody{
- Required: true,
- Content: &v3.MediaTypes{
- AdditionalProperties: []*v3.NamedMediaType{
- {
- Name: "application/json",
- Value: &v3.MediaType{
- Schema: requestSchema,
- },
- },
- },
- },
- },
- },
- }
- }
- return op, path
- }
- // addOperationToDocumentV3 adds an operation to the specified path/method.
- func (g *OpenAPIv3Generator) addOperationToDocumentV3(d *v3.Document, op *v3.Operation, path string, methodName string) {
- var selectedPathItem *v3.NamedPathItem
- for _, namedPathItem := range d.Paths.Path {
- if namedPathItem.Name == path {
- selectedPathItem = namedPathItem
- break
- }
- }
- // If we get here, we need to create a path item.
- if selectedPathItem == nil {
- selectedPathItem = &v3.NamedPathItem{Name: path, Value: &v3.PathItem{}}
- d.Paths.Path = append(d.Paths.Path, selectedPathItem)
- }
- // Set the operation on the specified method.
- switch methodName {
- case "GET":
- selectedPathItem.Value.Get = op
- case "POST":
- selectedPathItem.Value.Post = op
- case "PUT":
- selectedPathItem.Value.Put = op
- case "DELETE":
- selectedPathItem.Value.Delete = op
- case "PATCH":
- selectedPathItem.Value.Patch = op
- case http2.MethodHead:
- selectedPathItem.Value.Head = op
- case http2.MethodOptions:
- selectedPathItem.Value.Options = op
- case http2.MethodTrace:
- selectedPathItem.Value.Trace = op
- }
- }
- // addPathsToDocumentV3 adds paths from a specified file descriptor.
- func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*protogen.Service) {
- for _, service := range services {
- annotationsCount := 0
- for _, method := range service.Methods {
- comment := g.filterCommentString(method.Comments.Leading)
- inputMessage := method.Input
- outputMessage := method.Output
- operationID := service.GoName + "_" + method.GoName
- extOperation := proto.GetExtension(method.Desc.Options(), v3.E_Operation)
- if extOperation == nil || extOperation == v3.E_Operation.InterfaceOf(v3.E_Operation.Zero()) {
- continue
- }
- var path string
- var httpMethod string
- var bodyField string
- httpOperation := proto.GetExtension(method.Desc.Options(), annotations.E_Http)
- if httpOperation != nil && httpOperation != annotations.E_Http.InterfaceOf(annotations.E_Http.Zero()) {
- _httpOperation := httpOperation.(*annotations.HttpRule)
- switch httpRule := _httpOperation.GetPattern().(type) {
- case *annotations.HttpRule_Post:
- path = httpRule.Post
- httpMethod = http2.MethodPost
- bodyField = _httpOperation.GetBody()
- case *annotations.HttpRule_Get:
- path = httpRule.Get
- httpMethod = http2.MethodGet
- bodyField = ""
- case *annotations.HttpRule_Delete:
- path = httpRule.Delete
- httpMethod = http2.MethodDelete
- bodyField = ""
- case *annotations.HttpRule_Put:
- path = httpRule.Put
- httpMethod = http2.MethodPut
- bodyField = _httpOperation.GetBody()
- case *annotations.HttpRule_Patch:
- path = httpRule.Patch
- httpMethod = http2.MethodPatch
- bodyField = _httpOperation.GetBody()
- case *annotations.HttpRule_Custom:
- path = httpRule.Custom.Path
- httpMethod = httpRule.Custom.Kind
- bodyField = _httpOperation.GetBody()
- }
- }
- annotationsCount++
- if path == "" {
- path = handler.PathGenerator(string(service.Desc.FullName()), method.GoName)
- }
- if httpMethod == "" {
- httpMethod = http2.MethodPost
- }
- if bodyField == "" && (httpMethod == http2.MethodPost || httpMethod == http2.MethodPut || httpMethod == http2.MethodPatch) {
- bodyField = "*"
- }
- defaultHost := proto.GetExtension(service.Desc.Options(), annotations.E_DefaultHost).(string)
- op, path2 := g.buildOperationV3(
- d, operationID, d.Info.Title, comment, defaultHost, path, bodyField, inputMessage, outputMessage)
- // Merge any `Operation` annotations with the current
- proto.Merge(op, extOperation.(*v3.Operation))
- g.addOperationToDocumentV3(d, op, path2, httpMethod)
- }
- if annotationsCount > 0 {
- d.Tags = append(d.Tags, &v3.Tag{Name: d.Info.Title, Description: d.Info.Description})
- }
- }
- }
- // addSchemaForMessageToDocumentV3 adds the schema to the document if required
- func (g *OpenAPIv3Generator) addSchemaToDocumentV3(d *v3.Document, schema *v3.NamedSchemaOrReference) {
- if contains(g.generatedSchemas, schema.Name) {
- return
- }
- g.generatedSchemas = append(g.generatedSchemas, schema.Name)
- d.Components.Schemas.AdditionalProperties = append(d.Components.Schemas.AdditionalProperties, schema)
- }
- // addSchemasForMessagesToDocumentV3 adds info from one file descriptor.
- func (g *OpenAPIv3Generator) addSchemasForMessagesToDocumentV3(d *v3.Document, messages []*protogen.Message, edition descriptorpb.Edition) {
- // For each message, generate a definition.
- for _, message := range messages {
- if message.Messages != nil {
- g.addSchemasForMessagesToDocumentV3(d, message.Messages, edition)
- }
- schemaName := g.reflect.formatMessageName(message.Desc)
- // Only generate this if we need it and haven't already generated it.
- if !contains(g.reflect.requiredSchemas, schemaName) ||
- contains(g.generatedSchemas, schemaName) {
- continue
- }
- typeName := g.reflect.fullMessageTypeName(message.Desc)
- messageDescription := g.filterCommentString(message.Comments.Leading)
- // `google.protobuf.Value` and `google.protobuf.Any` have special JSON transcoding
- // so we can't just reflect on the message descriptor.
- if typeName == ".google.protobuf.Value" {
- g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufValueSchema(schemaName))
- continue
- } else if typeName == ".google.protobuf.Any" {
- g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(schemaName))
- continue
- } else if typeName == ".google.rpc.Status" {
- anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
- g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(anySchemaName))
- g.addSchemaToDocumentV3(d, wk.NewGoogleRpcStatusSchema(schemaName, anySchemaName))
- continue
- }
- // Build an array holding the fields of the message.
- definitionProperties := &v3.Properties{
- AdditionalProperties: make([]*v3.NamedSchemaOrReference, 0),
- }
- var required []string
- for _, field := range message.Fields {
- // Get the field description from the comments.
- description := g.filterCommentString(field.Comments.Leading)
- // Check the field annotations to see if this is a readonly or writeonly field.
- inputOnly := false
- outputOnly := false
- isRequired := true
- extension := proto.GetExtension(field.Desc.Options(), annotations.E_FieldBehavior)
- if extension != nil {
- switch v := extension.(type) {
- case []annotations.FieldBehavior:
- for _, vv := range v {
- switch vv {
- case annotations.FieldBehavior_OUTPUT_ONLY:
- outputOnly = true
- case annotations.FieldBehavior_INPUT_ONLY:
- inputOnly = true
- case annotations.FieldBehavior_OPTIONAL:
- isRequired = false
- }
- }
- default:
- log.Printf("unsupported extension type %T", extension)
- }
- }
- if edition == descriptorpb.Edition_EDITION_2023 {
- if fieldOptions, ok := field.Desc.Options().(*descriptorpb.FieldOptions); ok {
- if fieldOptions.GetFeatures().GetFieldPresence() == descriptorpb.FeatureSet_EXPLICIT {
- isRequired = false
- }
- }
- }
- if isRequired {
- required = append(required, g.reflect.formatFieldName(field.Desc))
- }
- // The field is either described by a reference or a schema.
- fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
- if fieldSchema == nil {
- continue
- }
- // If this field has siblings and is a $ref now, create a new schema use `allOf` to wrap it
- wrapperNeeded := inputOnly || outputOnly || description != ""
- if wrapperNeeded {
- if _, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Reference); ok {
- fieldSchema = &v3.SchemaOrReference{Oneof: &v3.SchemaOrReference_Schema{Schema: &v3.Schema{
- AllOf: []*v3.SchemaOrReference{fieldSchema},
- }}}
- }
- }
- if schema, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Schema); ok {
- schema.Schema.Description = description
- schema.Schema.ReadOnly = outputOnly
- schema.Schema.WriteOnly = inputOnly
- // Merge any `Property` annotations with the current
- extProperty := proto.GetExtension(field.Desc.Options(), v3.E_Property)
- if extProperty != nil {
- proto.Merge(schema.Schema, extProperty.(*v3.Schema))
- }
- }
- definitionProperties.AdditionalProperties = append(
- definitionProperties.AdditionalProperties,
- &v3.NamedSchemaOrReference{
- Name: g.reflect.formatFieldName(field.Desc),
- Value: fieldSchema,
- },
- )
- }
- schema := &v3.Schema{
- Type: "object",
- Description: messageDescription,
- Properties: definitionProperties,
- Required: required,
- }
- // Merge any `Schema` annotations with the current
- extSchema := proto.GetExtension(message.Desc.Options(), v3.E_Schema)
- if extSchema != nil {
- proto.Merge(schema, extSchema.(*v3.Schema))
- }
- // Add the schema to the components.schema list.
- g.addSchemaToDocumentV3(d, &v3.NamedSchemaOrReference{
- Name: schemaName,
- Value: &v3.SchemaOrReference{
- Oneof: &v3.SchemaOrReference_Schema{
- Schema: schema,
- },
- },
- })
- }
- }
|