generator.go 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894
  1. // Copyright 2020 Google LLC. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. package generator
  16. import (
  17. "fmt"
  18. "google.golang.org/protobuf/types/descriptorpb"
  19. "log"
  20. "net/url"
  21. "regexp"
  22. "sort"
  23. "strings"
  24. "google.golang.org/genproto/googleapis/api/annotations"
  25. status_pb "google.golang.org/genproto/googleapis/rpc/status"
  26. "google.golang.org/protobuf/compiler/protogen"
  27. "google.golang.org/protobuf/proto"
  28. "google.golang.org/protobuf/reflect/protoreflect"
  29. any_pb "google.golang.org/protobuf/types/known/anypb"
  30. wk "git.ikuban.com/server/swagger-api/v2/generator/wellknown"
  31. v3 "github.com/google/gnostic/openapiv3"
  32. )
  33. type Configuration struct {
  34. Version *string
  35. Title *string
  36. Description *string
  37. Naming *string
  38. FQSchemaNaming *bool
  39. EnumType *string
  40. CircularDepth *int
  41. DefaultResponse *bool
  42. OutputMode *string
  43. }
  44. const (
  45. infoURL = "git.ikuban.com/server/swagger-api"
  46. )
  47. // In order to dynamically add google.rpc.Status responses we need
  48. // to know the message descriptors for google.rpc.Status as well
  49. // as google.protobuf.Any.
  50. var statusProtoDesc = (&status_pb.Status{}).ProtoReflect().Descriptor()
  51. var anyProtoDesc = (&any_pb.Any{}).ProtoReflect().Descriptor()
  52. // OpenAPIv3Generator holds internal state needed to generate an OpenAPIv3 document for a transcoded Protocol Buffer service.
  53. type OpenAPIv3Generator struct {
  54. conf Configuration
  55. plugin *protogen.Plugin
  56. inputFiles []*protogen.File
  57. reflect *OpenAPIv3Reflector
  58. generatedSchemas []string // Names of schemas that have already been generated.
  59. linterRulePattern *regexp.Regexp
  60. pathPattern *regexp.Regexp
  61. namedPathPattern *regexp.Regexp
  62. }
  63. // NewOpenAPIv3Generator creates a new generator for a protoc plugin invocation.
  64. func NewOpenAPIv3Generator(plugin *protogen.Plugin, conf Configuration, inputFiles []*protogen.File) *OpenAPIv3Generator {
  65. return &OpenAPIv3Generator{
  66. conf: conf,
  67. plugin: plugin,
  68. inputFiles: inputFiles,
  69. reflect: NewOpenAPIv3Reflector(conf),
  70. generatedSchemas: make([]string, 0),
  71. linterRulePattern: regexp.MustCompile(`\(-- .* --\)`),
  72. pathPattern: regexp.MustCompile("{([^=}]+)}"),
  73. namedPathPattern: regexp.MustCompile("{(.+)=(.+)}"),
  74. }
  75. }
  76. // Run runs the generator.
  77. func (g *OpenAPIv3Generator) Run(outputFile *protogen.GeneratedFile) error {
  78. d := g.buildDocumentV3()
  79. bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
  80. if err != nil {
  81. return fmt.Errorf("failed to marshal yaml: %s", err.Error())
  82. }
  83. if _, err = outputFile.Write(bytes); err != nil {
  84. return fmt.Errorf("failed to write yaml: %s", err.Error())
  85. }
  86. return nil
  87. }
  88. func (g *OpenAPIv3Generator) RunV2() ([]byte, error) {
  89. d := g.buildDocumentV3()
  90. bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
  91. if err != nil {
  92. return bytes, fmt.Errorf("failed to marshal yaml: %s", err.Error())
  93. }
  94. return bytes, nil
  95. }
  96. // buildDocumentV3 builds an OpenAPIv3 document for a plugin request.
  97. func (g *OpenAPIv3Generator) buildDocumentV3() *v3.Document {
  98. d := &v3.Document{}
  99. d.Openapi = "3.0.3"
  100. d.Info = &v3.Info{
  101. Version: *g.conf.Version,
  102. Title: *g.conf.Title,
  103. Description: *g.conf.Description,
  104. }
  105. d.Paths = &v3.Paths{}
  106. d.Components = &v3.Components{
  107. Schemas: &v3.SchemasOrReferences{
  108. AdditionalProperties: []*v3.NamedSchemaOrReference{},
  109. },
  110. }
  111. // Go through the files and add the services to the documents, keeping
  112. // track of which schemas are referenced in the response so we can
  113. // add them later.
  114. for _, file := range g.inputFiles {
  115. if file.Generate {
  116. // Merge any `Document` annotations with the current
  117. extDocument := proto.GetExtension(file.Desc.Options(), v3.E_Document)
  118. if extDocument != nil {
  119. proto.Merge(d, extDocument.(*v3.Document))
  120. }
  121. g.addPathsToDocumentV3(d, file.Services)
  122. }
  123. }
  124. // While we have required schemas left to generate, go through the files again
  125. // looking for the related message and adding them to the document if required.
  126. for len(g.reflect.requiredSchemas) > 0 {
  127. count := len(g.reflect.requiredSchemas)
  128. for _, file := range g.plugin.Files {
  129. g.addSchemasForMessagesToDocumentV3(d, file.Messages, file.Proto.GetEdition())
  130. }
  131. g.reflect.requiredSchemas = g.reflect.requiredSchemas[count:len(g.reflect.requiredSchemas)]
  132. }
  133. // If there is only 1 service, then use it's title for the
  134. // document, if the document is missing it.
  135. if len(d.Tags) == 1 {
  136. if d.Info.Title == "" && d.Tags[0].Name != "" {
  137. d.Info.Title = d.Tags[0].Name + " API"
  138. }
  139. if d.Info.Description == "" {
  140. d.Info.Description = d.Tags[0].Description
  141. }
  142. d.Tags[0].Description = ""
  143. }
  144. allServers := []string{}
  145. // If paths methods has servers, but they're all the same, then move servers to path level
  146. for _, path := range d.Paths.Path {
  147. servers := []string{}
  148. // Only 1 server will ever be set, per method, by the generator
  149. if path.Value.Get != nil && len(path.Value.Get.Servers) == 1 {
  150. servers = appendUnique(servers, path.Value.Get.Servers[0].Url)
  151. allServers = appendUnique(servers, path.Value.Get.Servers[0].Url)
  152. }
  153. if path.Value.Post != nil && len(path.Value.Post.Servers) == 1 {
  154. servers = appendUnique(servers, path.Value.Post.Servers[0].Url)
  155. allServers = appendUnique(servers, path.Value.Post.Servers[0].Url)
  156. }
  157. if path.Value.Put != nil && len(path.Value.Put.Servers) == 1 {
  158. servers = appendUnique(servers, path.Value.Put.Servers[0].Url)
  159. allServers = appendUnique(servers, path.Value.Put.Servers[0].Url)
  160. }
  161. if path.Value.Delete != nil && len(path.Value.Delete.Servers) == 1 {
  162. servers = appendUnique(servers, path.Value.Delete.Servers[0].Url)
  163. allServers = appendUnique(servers, path.Value.Delete.Servers[0].Url)
  164. }
  165. if path.Value.Patch != nil && len(path.Value.Patch.Servers) == 1 {
  166. servers = appendUnique(servers, path.Value.Patch.Servers[0].Url)
  167. allServers = appendUnique(servers, path.Value.Patch.Servers[0].Url)
  168. }
  169. if len(servers) == 1 {
  170. path.Value.Servers = []*v3.Server{{Url: servers[0]}}
  171. if path.Value.Get != nil {
  172. path.Value.Get.Servers = nil
  173. }
  174. if path.Value.Post != nil {
  175. path.Value.Post.Servers = nil
  176. }
  177. if path.Value.Put != nil {
  178. path.Value.Put.Servers = nil
  179. }
  180. if path.Value.Delete != nil {
  181. path.Value.Delete.Servers = nil
  182. }
  183. if path.Value.Patch != nil {
  184. path.Value.Patch.Servers = nil
  185. }
  186. }
  187. }
  188. // Set all servers on API level
  189. if len(allServers) > 0 {
  190. d.Servers = []*v3.Server{}
  191. for _, server := range allServers {
  192. d.Servers = append(d.Servers, &v3.Server{Url: server})
  193. }
  194. }
  195. // If there is only 1 server, we can safely remove all path level servers
  196. if len(allServers) == 1 {
  197. for _, path := range d.Paths.Path {
  198. path.Value.Servers = nil
  199. }
  200. }
  201. // Sort the tags.
  202. {
  203. pairs := d.Tags
  204. sort.Slice(pairs, func(i, j int) bool {
  205. return pairs[i].Name < pairs[j].Name
  206. })
  207. d.Tags = pairs
  208. }
  209. // Sort the paths.
  210. {
  211. pairs := d.Paths.Path
  212. sort.Slice(pairs, func(i, j int) bool {
  213. return pairs[i].Name < pairs[j].Name
  214. })
  215. d.Paths.Path = pairs
  216. }
  217. // Sort the schemas.
  218. {
  219. pairs := d.Components.Schemas.AdditionalProperties
  220. sort.Slice(pairs, func(i, j int) bool {
  221. return pairs[i].Name < pairs[j].Name
  222. })
  223. d.Components.Schemas.AdditionalProperties = pairs
  224. }
  225. return d
  226. }
  227. // filterCommentString removes linter rules from comments.
  228. func (g *OpenAPIv3Generator) filterCommentString(c protogen.Comments) string {
  229. comment := g.linterRulePattern.ReplaceAllString(string(c), "")
  230. return strings.TrimSpace(comment)
  231. }
  232. func (g *OpenAPIv3Generator) findField(name string, inMessage *protogen.Message) *protogen.Field {
  233. for _, field := range inMessage.Fields {
  234. if string(field.Desc.Name()) == name || string(field.Desc.JSONName()) == name {
  235. return field
  236. }
  237. }
  238. return nil
  239. }
  240. func (g *OpenAPIv3Generator) findAndFormatFieldName(name string, inMessage *protogen.Message) string {
  241. field := g.findField(name, inMessage)
  242. if field != nil {
  243. return g.reflect.formatFieldName(field.Desc)
  244. }
  245. return name
  246. }
  247. // Note that fields which are mapped to URL query parameters must have a primitive type
  248. // or a repeated primitive type or a non-repeated message type.
  249. // In the case of a repeated type, the parameter can be repeated in the URL as ...?param=A&param=B.
  250. // In the case of a message type, each field of the message is mapped to a separate parameter,
  251. // such as ...?foo.a=A&foo.b=B&foo.c=C.
  252. // There are exceptions:
  253. // - for wrapper types it will use the same representation as the wrapped primitive type in JSON
  254. // - for google.protobuf.timestamp type it will be serialized as a string
  255. //
  256. // maps, Struct and Empty can NOT be used
  257. // messages can have any number of sub messages - including circular (e.g. sub.subsub.sub.subsub.id)
  258. // buildQueryParamsV3 extracts any valid query params, including sub and recursive messages
  259. func (g *OpenAPIv3Generator) buildQueryParamsV3(field *protogen.Field) []*v3.ParameterOrReference {
  260. depths := map[string]int{}
  261. return g._buildQueryParamsV3(field, depths)
  262. }
  263. // depths are used to keep track of how many times a message's fields has been seen
  264. func (g *OpenAPIv3Generator) _buildQueryParamsV3(field *protogen.Field, depths map[string]int) []*v3.ParameterOrReference {
  265. parameters := []*v3.ParameterOrReference{}
  266. queryFieldName := g.reflect.formatFieldName(field.Desc)
  267. fieldDescription := g.filterCommentString(field.Comments.Leading)
  268. if field.Desc.IsMap() {
  269. // Map types are not allowed in query parameteres
  270. return parameters
  271. } else if field.Desc.Kind() == protoreflect.MessageKind {
  272. typeName := g.reflect.fullMessageTypeName(field.Desc.Message())
  273. switch typeName {
  274. case ".google.protobuf.Value":
  275. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  276. parameters = append(parameters,
  277. &v3.ParameterOrReference{
  278. Oneof: &v3.ParameterOrReference_Parameter{
  279. Parameter: &v3.Parameter{
  280. Name: queryFieldName,
  281. In: "query",
  282. Description: fieldDescription,
  283. Required: false,
  284. Schema: fieldSchema,
  285. },
  286. },
  287. })
  288. return parameters
  289. case ".google.protobuf.BoolValue", ".google.protobuf.BytesValue", ".google.protobuf.Int32Value", ".google.protobuf.UInt32Value",
  290. ".google.protobuf.StringValue", ".google.protobuf.Int64Value", ".google.protobuf.UInt64Value", ".google.protobuf.FloatValue",
  291. ".google.protobuf.DoubleValue":
  292. valueField := getValueField(field.Message.Desc)
  293. fieldSchema := g.reflect.schemaOrReferenceForField(valueField)
  294. parameters = append(parameters,
  295. &v3.ParameterOrReference{
  296. Oneof: &v3.ParameterOrReference_Parameter{
  297. Parameter: &v3.Parameter{
  298. Name: queryFieldName,
  299. In: "query",
  300. Description: fieldDescription,
  301. Required: false,
  302. Schema: fieldSchema,
  303. },
  304. },
  305. })
  306. return parameters
  307. case ".google.protobuf.Timestamp":
  308. fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  309. parameters = append(parameters,
  310. &v3.ParameterOrReference{
  311. Oneof: &v3.ParameterOrReference_Parameter{
  312. Parameter: &v3.Parameter{
  313. Name: queryFieldName,
  314. In: "query",
  315. Description: fieldDescription,
  316. Required: false,
  317. Schema: fieldSchema,
  318. },
  319. },
  320. })
  321. return parameters
  322. case ".google.protobuf.Duration":
  323. fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  324. parameters = append(parameters,
  325. &v3.ParameterOrReference{
  326. Oneof: &v3.ParameterOrReference_Parameter{
  327. Parameter: &v3.Parameter{
  328. Name: queryFieldName,
  329. In: "query",
  330. Description: fieldDescription,
  331. Required: false,
  332. Schema: fieldSchema,
  333. },
  334. },
  335. })
  336. return parameters
  337. }
  338. if field.Desc.IsList() {
  339. // Only non-repeated message types are valid
  340. return parameters
  341. }
  342. // Represent field masks directly as strings (don't expand them).
  343. if typeName == ".google.protobuf.FieldMask" {
  344. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  345. parameters = append(parameters,
  346. &v3.ParameterOrReference{
  347. Oneof: &v3.ParameterOrReference_Parameter{
  348. Parameter: &v3.Parameter{
  349. Name: queryFieldName,
  350. In: "query",
  351. Description: fieldDescription,
  352. Required: false,
  353. Schema: fieldSchema,
  354. },
  355. },
  356. })
  357. return parameters
  358. }
  359. // Sub messages are allowed, even circular, as long as the final type is a primitive.
  360. // Go through each of the sub message fields
  361. for _, subField := range field.Message.Fields {
  362. subFieldFullName := string(subField.Desc.FullName())
  363. seen, ok := depths[subFieldFullName]
  364. if !ok {
  365. depths[subFieldFullName] = 0
  366. }
  367. if seen < *g.conf.CircularDepth {
  368. depths[subFieldFullName]++
  369. subParams := g._buildQueryParamsV3(subField, depths)
  370. for _, subParam := range subParams {
  371. if param, ok := subParam.Oneof.(*v3.ParameterOrReference_Parameter); ok {
  372. param.Parameter.Name = queryFieldName + "." + param.Parameter.Name
  373. parameters = append(parameters, subParam)
  374. }
  375. }
  376. }
  377. }
  378. } else if field.Desc.Kind() != protoreflect.GroupKind {
  379. // schemaOrReferenceForField also handles array types
  380. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  381. parameters = append(parameters,
  382. &v3.ParameterOrReference{
  383. Oneof: &v3.ParameterOrReference_Parameter{
  384. Parameter: &v3.Parameter{
  385. Name: queryFieldName,
  386. In: "query",
  387. Description: fieldDescription,
  388. Required: false,
  389. Schema: fieldSchema,
  390. },
  391. },
  392. })
  393. }
  394. return parameters
  395. }
  396. // buildOperationV3 constructs an operation for a set of values.
  397. func (g *OpenAPIv3Generator) buildOperationV3(
  398. d *v3.Document,
  399. operationID string,
  400. tagName string,
  401. description string,
  402. defaultHost string,
  403. path string,
  404. bodyField string,
  405. inputMessage *protogen.Message,
  406. outputMessage *protogen.Message,
  407. ) (*v3.Operation, string) {
  408. // coveredParameters tracks the parameters that have been used in the body or path.
  409. coveredParameters := make([]string, 0)
  410. if bodyField != "" {
  411. coveredParameters = append(coveredParameters, bodyField)
  412. }
  413. // Initialize the list of operation parameters.
  414. parameters := []*v3.ParameterOrReference{}
  415. // Find simple path parameters like {id}
  416. if allMatches := g.pathPattern.FindAllStringSubmatch(path, -1); allMatches != nil {
  417. for _, matches := range allMatches {
  418. // Add the value to the list of covered parameters.
  419. coveredParameters = append(coveredParameters, matches[1])
  420. pathParameter := g.findAndFormatFieldName(matches[1], inputMessage)
  421. path = strings.Replace(path, matches[1], pathParameter, 1)
  422. // Add the path parameters to the operation parameters.
  423. var fieldSchema *v3.SchemaOrReference
  424. var fieldDescription string
  425. field := g.findField(pathParameter, inputMessage)
  426. if field != nil {
  427. fieldSchema = g.reflect.schemaOrReferenceForField(field.Desc)
  428. fieldDescription = g.filterCommentString(field.Comments.Leading)
  429. } else {
  430. // If field does not exist, it is safe to set it to string, as it is ignored downstream
  431. fieldSchema = &v3.SchemaOrReference{
  432. Oneof: &v3.SchemaOrReference_Schema{
  433. Schema: &v3.Schema{
  434. Type: "string",
  435. },
  436. },
  437. }
  438. }
  439. parameters = append(parameters,
  440. &v3.ParameterOrReference{
  441. Oneof: &v3.ParameterOrReference_Parameter{
  442. Parameter: &v3.Parameter{
  443. Name: pathParameter,
  444. In: "path",
  445. Description: fieldDescription,
  446. Required: true,
  447. Schema: fieldSchema,
  448. },
  449. },
  450. })
  451. }
  452. }
  453. // Find named path parameters like {name=shelves/*}
  454. if matches := g.namedPathPattern.FindStringSubmatch(path); matches != nil {
  455. // Build a list of named path parameters.
  456. namedPathParameters := make([]string, 0)
  457. // Add the "name=" "name" value to the list of covered parameters.
  458. coveredParameters = append(coveredParameters, matches[1])
  459. // Convert the path from the starred form to use named path parameters.
  460. starredPath := matches[2]
  461. parts := strings.Split(starredPath, "/")
  462. // The starred path is assumed to be in the form "things/*/otherthings/*".
  463. // We want to convert it to "things/{thingsId}/otherthings/{otherthingsId}".
  464. for i := 0; i < len(parts)-1; i += 2 {
  465. section := parts[i]
  466. namedPathParameter := g.findAndFormatFieldName(section, inputMessage)
  467. namedPathParameter = singular(namedPathParameter)
  468. parts[i+1] = "{" + namedPathParameter + "}"
  469. namedPathParameters = append(namedPathParameters, namedPathParameter)
  470. }
  471. // Rewrite the path to use the path parameters.
  472. newPath := strings.Join(parts, "/")
  473. path = strings.Replace(path, matches[0], newPath, 1)
  474. // Add the named path parameters to the operation parameters.
  475. for _, namedPathParameter := range namedPathParameters {
  476. parameters = append(parameters,
  477. &v3.ParameterOrReference{
  478. Oneof: &v3.ParameterOrReference_Parameter{
  479. Parameter: &v3.Parameter{
  480. Name: namedPathParameter,
  481. In: "path",
  482. Required: true,
  483. Description: "The " + namedPathParameter + " id.",
  484. Schema: &v3.SchemaOrReference{
  485. Oneof: &v3.SchemaOrReference_Schema{
  486. Schema: &v3.Schema{
  487. Type: "string",
  488. },
  489. },
  490. },
  491. },
  492. },
  493. })
  494. }
  495. }
  496. // Add any unhandled fields in the request message as query parameters.
  497. if bodyField != "*" && string(inputMessage.Desc.FullName()) != "google.api.HttpBody" {
  498. for _, field := range inputMessage.Fields {
  499. fieldName := string(field.Desc.Name())
  500. if !contains(coveredParameters, fieldName) && fieldName != bodyField {
  501. fieldParams := g.buildQueryParamsV3(field)
  502. parameters = append(parameters, fieldParams...)
  503. }
  504. }
  505. }
  506. // Create the response.
  507. name, content := g.reflect.responseContentForMessage(outputMessage.Desc)
  508. responses := &v3.Responses{
  509. ResponseOrReference: []*v3.NamedResponseOrReference{
  510. {
  511. Name: name,
  512. Value: &v3.ResponseOrReference{
  513. Oneof: &v3.ResponseOrReference_Response{
  514. Response: &v3.Response{
  515. Description: "OK",
  516. Content: content,
  517. },
  518. },
  519. },
  520. },
  521. },
  522. }
  523. // Add the default reponse if needed
  524. if *g.conf.DefaultResponse {
  525. anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
  526. anySchema := wk.NewGoogleProtobufAnySchema(anySchemaName)
  527. g.addSchemaToDocumentV3(d, anySchema)
  528. statusSchemaName := g.reflect.formatMessageName(statusProtoDesc)
  529. statusSchema := wk.NewGoogleRpcStatusSchema(statusSchemaName, anySchemaName)
  530. g.addSchemaToDocumentV3(d, statusSchema)
  531. defaultResponse := &v3.NamedResponseOrReference{
  532. Name: "default",
  533. Value: &v3.ResponseOrReference{
  534. Oneof: &v3.ResponseOrReference_Response{
  535. Response: &v3.Response{
  536. Description: "Default error response",
  537. Content: wk.NewApplicationJsonMediaType(&v3.SchemaOrReference{
  538. Oneof: &v3.SchemaOrReference_Reference{
  539. Reference: &v3.Reference{XRef: "#/components/schemas/" + statusSchemaName}}}),
  540. },
  541. },
  542. },
  543. }
  544. responses.ResponseOrReference = append(responses.ResponseOrReference, defaultResponse)
  545. }
  546. // Create the operation.
  547. op := &v3.Operation{
  548. Tags: []string{tagName},
  549. Description: description,
  550. OperationId: operationID,
  551. Parameters: parameters,
  552. Responses: responses,
  553. }
  554. if defaultHost != "" {
  555. hostURL, err := url.Parse(defaultHost)
  556. if err == nil {
  557. hostURL.Scheme = "https"
  558. op.Servers = append(op.Servers, &v3.Server{Url: hostURL.String()})
  559. }
  560. }
  561. // If a body field is specified, we need to pass a message as the request body.
  562. if bodyField != "" {
  563. var requestSchema *v3.SchemaOrReference
  564. if bodyField == "*" {
  565. // Pass the entire request message as the request body.
  566. requestSchema = g.reflect.schemaOrReferenceForMessage(inputMessage.Desc)
  567. } else {
  568. // If body refers to a message field, use that type.
  569. for _, field := range inputMessage.Fields {
  570. if string(field.Desc.Name()) == bodyField {
  571. switch field.Desc.Kind() {
  572. case protoreflect.StringKind:
  573. requestSchema = &v3.SchemaOrReference{
  574. Oneof: &v3.SchemaOrReference_Schema{
  575. Schema: &v3.Schema{
  576. Type: "string",
  577. },
  578. },
  579. }
  580. case protoreflect.MessageKind:
  581. requestSchema = g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  582. default:
  583. log.Printf("unsupported field type %+v", field.Desc)
  584. }
  585. break
  586. }
  587. }
  588. }
  589. op.RequestBody = &v3.RequestBodyOrReference{
  590. Oneof: &v3.RequestBodyOrReference_RequestBody{
  591. RequestBody: &v3.RequestBody{
  592. Required: true,
  593. Content: &v3.MediaTypes{
  594. AdditionalProperties: []*v3.NamedMediaType{
  595. {
  596. Name: "application/json",
  597. Value: &v3.MediaType{
  598. Schema: requestSchema,
  599. },
  600. },
  601. },
  602. },
  603. },
  604. },
  605. }
  606. }
  607. return op, path
  608. }
  609. // addOperationToDocumentV3 adds an operation to the specified path/method.
  610. func (g *OpenAPIv3Generator) addOperationToDocumentV3(d *v3.Document, op *v3.Operation, path string, methodName string) {
  611. var selectedPathItem *v3.NamedPathItem
  612. for _, namedPathItem := range d.Paths.Path {
  613. if namedPathItem.Name == path {
  614. selectedPathItem = namedPathItem
  615. break
  616. }
  617. }
  618. // If we get here, we need to create a path item.
  619. if selectedPathItem == nil {
  620. selectedPathItem = &v3.NamedPathItem{Name: path, Value: &v3.PathItem{}}
  621. d.Paths.Path = append(d.Paths.Path, selectedPathItem)
  622. }
  623. // Set the operation on the specified method.
  624. switch methodName {
  625. case "GET":
  626. selectedPathItem.Value.Get = op
  627. case "POST":
  628. selectedPathItem.Value.Post = op
  629. case "PUT":
  630. selectedPathItem.Value.Put = op
  631. case "DELETE":
  632. selectedPathItem.Value.Delete = op
  633. case "PATCH":
  634. selectedPathItem.Value.Patch = op
  635. }
  636. }
  637. // addPathsToDocumentV3 adds paths from a specified file descriptor.
  638. func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*protogen.Service) {
  639. for _, service := range services {
  640. annotationsCount := 0
  641. for _, method := range service.Methods {
  642. comment := g.filterCommentString(method.Comments.Leading)
  643. inputMessage := method.Input
  644. outputMessage := method.Output
  645. operationID := service.GoName + "_" + method.GoName
  646. extOperation := proto.GetExtension(method.Desc.Options(), v3.E_Operation)
  647. if extOperation == nil || extOperation == v3.E_Operation.InterfaceOf(v3.E_Operation.Zero()) {
  648. continue
  649. }
  650. annotationsCount++
  651. path := fmt.Sprintf("/api/%s/%s", service.Desc.FullName(), method.GoName)
  652. defaultHost := proto.GetExtension(service.Desc.Options(), annotations.E_DefaultHost).(string)
  653. op, path2 := g.buildOperationV3(
  654. d, operationID, service.GoName, comment, defaultHost, path, "*", inputMessage, outputMessage)
  655. // Merge any `Operation` annotations with the current
  656. proto.Merge(op, extOperation.(*v3.Operation))
  657. g.addOperationToDocumentV3(d, op, path2, "POST")
  658. }
  659. if annotationsCount > 0 {
  660. comment := g.filterCommentString(service.Comments.Leading)
  661. d.Tags = append(d.Tags, &v3.Tag{Name: service.GoName, Description: comment})
  662. }
  663. }
  664. }
  665. // addSchemaForMessageToDocumentV3 adds the schema to the document if required
  666. func (g *OpenAPIv3Generator) addSchemaToDocumentV3(d *v3.Document, schema *v3.NamedSchemaOrReference) {
  667. if contains(g.generatedSchemas, schema.Name) {
  668. return
  669. }
  670. g.generatedSchemas = append(g.generatedSchemas, schema.Name)
  671. d.Components.Schemas.AdditionalProperties = append(d.Components.Schemas.AdditionalProperties, schema)
  672. }
  673. // addSchemasForMessagesToDocumentV3 adds info from one file descriptor.
  674. func (g *OpenAPIv3Generator) addSchemasForMessagesToDocumentV3(d *v3.Document, messages []*protogen.Message, edition descriptorpb.Edition) {
  675. // For each message, generate a definition.
  676. for _, message := range messages {
  677. if message.Messages != nil {
  678. g.addSchemasForMessagesToDocumentV3(d, message.Messages, edition)
  679. }
  680. schemaName := g.reflect.formatMessageName(message.Desc)
  681. // Only generate this if we need it and haven't already generated it.
  682. if !contains(g.reflect.requiredSchemas, schemaName) ||
  683. contains(g.generatedSchemas, schemaName) {
  684. continue
  685. }
  686. typeName := g.reflect.fullMessageTypeName(message.Desc)
  687. messageDescription := g.filterCommentString(message.Comments.Leading)
  688. // `google.protobuf.Value` and `google.protobuf.Any` have special JSON transcoding
  689. // so we can't just reflect on the message descriptor.
  690. if typeName == ".google.protobuf.Value" {
  691. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufValueSchema(schemaName))
  692. continue
  693. } else if typeName == ".google.protobuf.Any" {
  694. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(schemaName))
  695. continue
  696. } else if typeName == ".google.rpc.Status" {
  697. anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
  698. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(anySchemaName))
  699. g.addSchemaToDocumentV3(d, wk.NewGoogleRpcStatusSchema(schemaName, anySchemaName))
  700. continue
  701. }
  702. // Build an array holding the fields of the message.
  703. definitionProperties := &v3.Properties{
  704. AdditionalProperties: make([]*v3.NamedSchemaOrReference, 0),
  705. }
  706. var required []string
  707. for _, field := range message.Fields {
  708. // Get the field description from the comments.
  709. description := g.filterCommentString(field.Comments.Leading)
  710. // Check the field annotations to see if this is a readonly or writeonly field.
  711. inputOnly := false
  712. outputOnly := false
  713. isRequired := true
  714. extension := proto.GetExtension(field.Desc.Options(), annotations.E_FieldBehavior)
  715. if extension != nil {
  716. switch v := extension.(type) {
  717. case []annotations.FieldBehavior:
  718. for _, vv := range v {
  719. switch vv {
  720. case annotations.FieldBehavior_OUTPUT_ONLY:
  721. outputOnly = true
  722. case annotations.FieldBehavior_INPUT_ONLY:
  723. inputOnly = true
  724. case annotations.FieldBehavior_OPTIONAL:
  725. isRequired = false
  726. }
  727. }
  728. default:
  729. log.Printf("unsupported extension type %T", extension)
  730. }
  731. }
  732. if edition == descriptorpb.Edition_EDITION_2023 {
  733. if fieldOptions, ok := field.Desc.Options().(*descriptorpb.FieldOptions); ok {
  734. if fieldOptions.GetFeatures().GetFieldPresence() == descriptorpb.FeatureSet_EXPLICIT {
  735. isRequired = false
  736. }
  737. }
  738. }
  739. if isRequired {
  740. required = append(required, g.reflect.formatFieldName(field.Desc))
  741. }
  742. // The field is either described by a reference or a schema.
  743. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  744. if fieldSchema == nil {
  745. continue
  746. }
  747. // If this field has siblings and is a $ref now, create a new schema use `allOf` to wrap it
  748. wrapperNeeded := inputOnly || outputOnly || description != ""
  749. if wrapperNeeded {
  750. if _, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Reference); ok {
  751. fieldSchema = &v3.SchemaOrReference{Oneof: &v3.SchemaOrReference_Schema{Schema: &v3.Schema{
  752. AllOf: []*v3.SchemaOrReference{fieldSchema},
  753. }}}
  754. }
  755. }
  756. if schema, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Schema); ok {
  757. schema.Schema.Description = description
  758. schema.Schema.ReadOnly = outputOnly
  759. schema.Schema.WriteOnly = inputOnly
  760. // Merge any `Property` annotations with the current
  761. extProperty := proto.GetExtension(field.Desc.Options(), v3.E_Property)
  762. if extProperty != nil {
  763. proto.Merge(schema.Schema, extProperty.(*v3.Schema))
  764. }
  765. }
  766. definitionProperties.AdditionalProperties = append(
  767. definitionProperties.AdditionalProperties,
  768. &v3.NamedSchemaOrReference{
  769. Name: g.reflect.formatFieldName(field.Desc),
  770. Value: fieldSchema,
  771. },
  772. )
  773. }
  774. schema := &v3.Schema{
  775. Type: "object",
  776. Description: messageDescription,
  777. Properties: definitionProperties,
  778. Required: required,
  779. }
  780. // Merge any `Schema` annotations with the current
  781. extSchema := proto.GetExtension(message.Desc.Options(), v3.E_Schema)
  782. if extSchema != nil {
  783. proto.Merge(schema, extSchema.(*v3.Schema))
  784. }
  785. // Add the schema to the components.schema list.
  786. g.addSchemaToDocumentV3(d, &v3.NamedSchemaOrReference{
  787. Name: schemaName,
  788. Value: &v3.SchemaOrReference{
  789. Oneof: &v3.SchemaOrReference_Schema{
  790. Schema: schema,
  791. },
  792. },
  793. })
  794. }
  795. }