generator.go 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067
  1. // Copyright 2020 Google LLC. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. package generator
  16. import (
  17. "fmt"
  18. "log"
  19. "net/url"
  20. "regexp"
  21. "sort"
  22. "strings"
  23. http2 "net/http"
  24. "google.golang.org/protobuf/types/descriptorpb"
  25. "google.golang.org/genproto/googleapis/api/annotations"
  26. status_pb "google.golang.org/genproto/googleapis/rpc/status"
  27. "google.golang.org/protobuf/compiler/protogen"
  28. "google.golang.org/protobuf/proto"
  29. "google.golang.org/protobuf/reflect/protoreflect"
  30. any_pb "google.golang.org/protobuf/types/known/anypb"
  31. "git.ikuban.com/server/kratos-utils/v2/transport/http/handler"
  32. wk "git.ikuban.com/server/swagger-api/v2/generator/wellknown"
  33. v3 "github.com/google/gnostic/openapiv3"
  34. )
  // Configuration carries the user-supplied options of the OpenAPI v3
  // generator. All options are pointers; most of them are dereferenced
  // directly by the generator, so callers are expected to populate them.
  type Configuration struct {
  	// Version, Title and Description seed the Info section of the document.
  	Version, Title, Description *string
  	// Naming selects how field names are rendered.
  	// NOTE(review): interpreted by the reflector, not visible here.
  	Naming *string
  	// FQSchemaNaming toggles fully-qualified schema names.
  	// NOTE(review): interpreted by the reflector, not visible here.
  	FQSchemaNaming *bool
  	// EnumType selects how enums are rendered.
  	// NOTE(review): interpreted by the reflector, not visible here.
  	EnumType *string
  	// CircularDepth bounds how often a sub-field may be expanded while
  	// flattening message-typed fields into query parameters.
  	CircularDepth *int
  	// DefaultResponse adds a "default" google.rpc.Status error response to
  	// every operation.
  	DefaultResponse *bool
  	// OutputMode is not read in the visible part of the generator.
  	// NOTE(review): confirm where it is consumed.
  	OutputMode *string
  	// PreserveInfo forces Info to be rebuilt from Version/Title/Description
  	// after annotations and tags have been applied.
  	PreserveInfo *bool
  }
  // infoURL identifies the project that produced the document; Run and RunV2
  // append it to the header comment of the emitted YAML.
  const infoURL = "git.ikuban.com/server/swagger-api"
  50. // In order to dynamically add google.rpc.Status responses we need
  51. // to know the message descriptors for google.rpc.Status as well
  52. // as google.protobuf.Any.
  53. var statusProtoDesc = (&status_pb.Status{}).ProtoReflect().Descriptor()
  54. var anyProtoDesc = (&any_pb.Any{}).ProtoReflect().Descriptor()
  55. // OpenAPIv3Generator holds internal state needed to generate an OpenAPIv3 document for a transcoded Protocol Buffer service.
  56. type OpenAPIv3Generator struct {
  57. conf Configuration
  58. plugin *protogen.Plugin
  59. inputFiles []*protogen.File
  60. reflect *OpenAPIv3Reflector
  61. generatedSchemas []string // Names of schemas that have already been generated.
  62. linterRulePattern *regexp.Regexp
  63. pathPattern *regexp.Regexp
  64. namedPathPattern *regexp.Regexp
  65. }
  66. // NewOpenAPIv3Generator creates a new generator for a protoc plugin invocation.
  67. func NewOpenAPIv3Generator(plugin *protogen.Plugin, conf Configuration, inputFiles []*protogen.File) *OpenAPIv3Generator {
  68. return &OpenAPIv3Generator{
  69. conf: conf,
  70. plugin: plugin,
  71. inputFiles: inputFiles,
  72. reflect: NewOpenAPIv3Reflector(conf),
  73. generatedSchemas: make([]string, 0),
  74. linterRulePattern: regexp.MustCompile(`\(-- .* --\)`),
  75. pathPattern: regexp.MustCompile("{([^=}]+)}"),
  76. namedPathPattern: regexp.MustCompile("{(.+)=(.+)}"),
  77. }
  78. }
  79. // Run runs the generator.
  80. func (g *OpenAPIv3Generator) Run(outputFile *protogen.GeneratedFile) error {
  81. d := g.buildDocumentV3()
  82. bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
  83. if err != nil {
  84. return fmt.Errorf("failed to marshal yaml: %s", err.Error())
  85. }
  86. if _, err = outputFile.Write(bytes); err != nil {
  87. return fmt.Errorf("failed to write yaml: %s", err.Error())
  88. }
  89. return nil
  90. }
  91. func (g *OpenAPIv3Generator) RunV2() ([]byte, error) {
  92. d := g.buildDocumentV3()
  93. bytes, err := d.YAMLValue("Generated with protoc-gen-openapi\n" + infoURL)
  94. if err != nil {
  95. return bytes, fmt.Errorf("failed to marshal yaml: %s", err.Error())
  96. }
  97. return bytes, nil
  98. }
  99. // buildDocumentV3 builds an OpenAPIv3 document for a plugin request.
  100. func (g *OpenAPIv3Generator) buildDocumentV3() *v3.Document {
  101. d := &v3.Document{}
  102. d.Openapi = "3.0.3"
  103. d.Info = &v3.Info{
  104. Version: *g.conf.Version,
  105. Title: *g.conf.Title,
  106. Description: *g.conf.Description,
  107. }
  108. d.Paths = &v3.Paths{}
  109. d.Components = &v3.Components{
  110. Schemas: &v3.SchemasOrReferences{
  111. AdditionalProperties: []*v3.NamedSchemaOrReference{},
  112. },
  113. }
  114. // Go through the files and add the services to the documents, keeping
  115. // track of which schemas are referenced in the response so we can
  116. // add them later.
  117. for _, file := range g.inputFiles {
  118. if file.Generate {
  119. // Merge any `Document` annotations with the current
  120. extDocument := proto.GetExtension(file.Desc.Options(), v3.E_Document)
  121. if extDocument != nil {
  122. proto.Merge(d, extDocument.(*v3.Document))
  123. }
  124. g.addPathsToDocumentV3(d, file.Services)
  125. }
  126. }
  127. // While we have required schemas left to generate, go through the files again
  128. // looking for the related message and adding them to the document if required.
  129. for len(g.reflect.requiredSchemas) > 0 {
  130. count := len(g.reflect.requiredSchemas)
  131. for _, file := range g.plugin.Files {
  132. g.addSchemasForMessagesToDocumentV3(d, file.Messages, file.Proto.GetEdition())
  133. }
  134. g.reflect.requiredSchemas = g.reflect.requiredSchemas[count:len(g.reflect.requiredSchemas)]
  135. }
  136. // If there is only 1 service, then use it's title for the
  137. // document, if the document is missing it.
  138. if len(d.Tags) == 1 {
  139. if d.Info.Title == "" && d.Tags[0].Name != "" {
  140. d.Info.Title = d.Tags[0].Name + " API"
  141. }
  142. if d.Info.Description == "" {
  143. d.Info.Description = d.Tags[0].Description
  144. }
  145. d.Tags[0].Description = ""
  146. }
  147. if g.conf.PreserveInfo != nil && *g.conf.PreserveInfo {
  148. d.Info = &v3.Info{
  149. Version: *g.conf.Version,
  150. Title: *g.conf.Title,
  151. Description: *g.conf.Description,
  152. }
  153. }
  154. allServers := []string{}
  155. // If paths methods has servers, but they're all the same, then move servers to path level
  156. for _, path := range d.Paths.Path {
  157. servers := []string{}
  158. // Only 1 server will ever be set, per method, by the generator
  159. if path.Value.Get != nil && len(path.Value.Get.Servers) == 1 {
  160. servers = appendUnique(servers, path.Value.Get.Servers[0].Url)
  161. allServers = appendUnique(allServers, path.Value.Get.Servers[0].Url)
  162. }
  163. if path.Value.Post != nil && len(path.Value.Post.Servers) == 1 {
  164. servers = appendUnique(servers, path.Value.Post.Servers[0].Url)
  165. allServers = appendUnique(allServers, path.Value.Post.Servers[0].Url)
  166. }
  167. if path.Value.Put != nil && len(path.Value.Put.Servers) == 1 {
  168. servers = appendUnique(servers, path.Value.Put.Servers[0].Url)
  169. allServers = appendUnique(allServers, path.Value.Put.Servers[0].Url)
  170. }
  171. if path.Value.Delete != nil && len(path.Value.Delete.Servers) == 1 {
  172. servers = appendUnique(servers, path.Value.Delete.Servers[0].Url)
  173. allServers = appendUnique(allServers, path.Value.Delete.Servers[0].Url)
  174. }
  175. if path.Value.Patch != nil && len(path.Value.Patch.Servers) == 1 {
  176. servers = appendUnique(servers, path.Value.Patch.Servers[0].Url)
  177. allServers = appendUnique(allServers, path.Value.Patch.Servers[0].Url)
  178. }
  179. if path.Value.Head != nil && len(path.Value.Head.Servers) == 1 {
  180. servers = appendUnique(servers, path.Value.Head.Servers[0].Url)
  181. allServers = appendUnique(allServers, path.Value.Head.Servers[0].Url)
  182. }
  183. if path.Value.Options != nil && len(path.Value.Options.Servers) == 1 {
  184. servers = appendUnique(servers, path.Value.Options.Servers[0].Url)
  185. allServers = appendUnique(allServers, path.Value.Options.Servers[0].Url)
  186. }
  187. if path.Value.Trace != nil && len(path.Value.Trace.Servers) == 1 {
  188. servers = appendUnique(servers, path.Value.Trace.Servers[0].Url)
  189. allServers = appendUnique(allServers, path.Value.Trace.Servers[0].Url)
  190. }
  191. if len(servers) == 1 {
  192. path.Value.Servers = []*v3.Server{{Url: servers[0]}}
  193. if path.Value.Get != nil {
  194. path.Value.Get.Servers = nil
  195. }
  196. if path.Value.Post != nil {
  197. path.Value.Post.Servers = nil
  198. }
  199. if path.Value.Put != nil {
  200. path.Value.Put.Servers = nil
  201. }
  202. if path.Value.Delete != nil {
  203. path.Value.Delete.Servers = nil
  204. }
  205. if path.Value.Patch != nil {
  206. path.Value.Patch.Servers = nil
  207. }
  208. if path.Value.Head != nil {
  209. path.Value.Head.Servers = nil
  210. }
  211. if path.Value.Options != nil {
  212. path.Value.Options.Servers = nil
  213. }
  214. if path.Value.Trace != nil {
  215. path.Value.Trace.Servers = nil
  216. }
  217. }
  218. }
  219. // Set all servers on API level
  220. if len(allServers) > 0 {
  221. d.Servers = []*v3.Server{}
  222. for _, server := range allServers {
  223. d.Servers = append(d.Servers, &v3.Server{Url: server})
  224. }
  225. }
  226. // If there is only 1 server, we can safely remove all path level servers
  227. if len(allServers) == 1 {
  228. for _, path := range d.Paths.Path {
  229. path.Value.Servers = nil
  230. }
  231. }
  232. // Sort the tags.
  233. {
  234. pairs := d.Tags
  235. sort.Slice(pairs, func(i, j int) bool {
  236. return pairs[i].Name < pairs[j].Name
  237. })
  238. d.Tags = pairs
  239. }
  240. // Sort the paths.
  241. {
  242. pairs := d.Paths.Path
  243. sort.Slice(pairs, func(i, j int) bool {
  244. return pairs[i].Name < pairs[j].Name
  245. })
  246. d.Paths.Path = pairs
  247. }
  248. // Sort the schemas.
  249. {
  250. pairs := d.Components.Schemas.AdditionalProperties
  251. sort.Slice(pairs, func(i, j int) bool {
  252. return pairs[i].Name < pairs[j].Name
  253. })
  254. d.Components.Schemas.AdditionalProperties = pairs
  255. }
  256. // Deduplicate and sort security schemes.
  257. if d.Components.SecuritySchemes != nil && d.Components.SecuritySchemes.AdditionalProperties != nil {
  258. seen := make(map[string]bool)
  259. unique := make([]*v3.NamedSecuritySchemeOrReference, 0)
  260. for _, scheme := range d.Components.SecuritySchemes.AdditionalProperties {
  261. if !seen[scheme.Name] {
  262. seen[scheme.Name] = true
  263. unique = append(unique, scheme)
  264. }
  265. }
  266. sort.Slice(unique, func(i, j int) bool {
  267. return unique[i].Name < unique[j].Name
  268. })
  269. d.Components.SecuritySchemes.AdditionalProperties = unique
  270. }
  271. // Deduplicate and sort responses.
  272. if d.Components.Responses != nil && d.Components.Responses.AdditionalProperties != nil {
  273. seen := make(map[string]bool)
  274. unique := make([]*v3.NamedResponseOrReference, 0)
  275. for _, resp := range d.Components.Responses.AdditionalProperties {
  276. if !seen[resp.Name] {
  277. seen[resp.Name] = true
  278. unique = append(unique, resp)
  279. }
  280. }
  281. sort.Slice(unique, func(i, j int) bool {
  282. return unique[i].Name < unique[j].Name
  283. })
  284. d.Components.Responses.AdditionalProperties = unique
  285. }
  286. // Deduplicate and sort parameters.
  287. if d.Components.Parameters != nil && d.Components.Parameters.AdditionalProperties != nil {
  288. seen := make(map[string]bool)
  289. unique := make([]*v3.NamedParameterOrReference, 0)
  290. for _, param := range d.Components.Parameters.AdditionalProperties {
  291. if !seen[param.Name] {
  292. seen[param.Name] = true
  293. unique = append(unique, param)
  294. }
  295. }
  296. sort.Slice(unique, func(i, j int) bool {
  297. return unique[i].Name < unique[j].Name
  298. })
  299. d.Components.Parameters.AdditionalProperties = unique
  300. }
  301. // Deduplicate and sort request bodies.
  302. if d.Components.RequestBodies != nil && d.Components.RequestBodies.AdditionalProperties != nil {
  303. seen := make(map[string]bool)
  304. unique := make([]*v3.NamedRequestBodyOrReference, 0)
  305. for _, body := range d.Components.RequestBodies.AdditionalProperties {
  306. if !seen[body.Name] {
  307. seen[body.Name] = true
  308. unique = append(unique, body)
  309. }
  310. }
  311. sort.Slice(unique, func(i, j int) bool {
  312. return unique[i].Name < unique[j].Name
  313. })
  314. d.Components.RequestBodies.AdditionalProperties = unique
  315. }
  316. // Deduplicate and sort headers.
  317. if d.Components.Headers != nil && d.Components.Headers.AdditionalProperties != nil {
  318. seen := make(map[string]bool)
  319. unique := make([]*v3.NamedHeaderOrReference, 0)
  320. for _, header := range d.Components.Headers.AdditionalProperties {
  321. if !seen[header.Name] {
  322. seen[header.Name] = true
  323. unique = append(unique, header)
  324. }
  325. }
  326. sort.Slice(unique, func(i, j int) bool {
  327. return unique[i].Name < unique[j].Name
  328. })
  329. d.Components.Headers.AdditionalProperties = unique
  330. }
  331. // Deduplicate servers by URL.
  332. if d.Servers != nil && len(d.Servers) > 0 {
  333. seen := make(map[string]bool)
  334. unique := make([]*v3.Server, 0)
  335. for _, server := range d.Servers {
  336. if !seen[server.Url] {
  337. seen[server.Url] = true
  338. unique = append(unique, server)
  339. }
  340. }
  341. d.Servers = unique
  342. }
  343. return d
  344. }
  345. // filterCommentString removes linter rules from comments.
  346. func (g *OpenAPIv3Generator) filterCommentString(c protogen.Comments) string {
  347. comment := g.linterRulePattern.ReplaceAllString(string(c), "")
  348. return strings.TrimSpace(comment)
  349. }
  350. func (g *OpenAPIv3Generator) findField(name string, inMessage *protogen.Message) *protogen.Field {
  351. for _, field := range inMessage.Fields {
  352. if string(field.Desc.Name()) == name || string(field.Desc.JSONName()) == name {
  353. return field
  354. }
  355. }
  356. return nil
  357. }
  358. func (g *OpenAPIv3Generator) findAndFormatFieldName(name string, inMessage *protogen.Message) string {
  359. field := g.findField(name, inMessage)
  360. if field != nil {
  361. return g.reflect.formatFieldName(field.Desc)
  362. }
  363. return name
  364. }
  365. // Note that fields which are mapped to URL query parameters must have a primitive type
  366. // or a repeated primitive type or a non-repeated message type.
  367. // In the case of a repeated type, the parameter can be repeated in the URL as ...?param=A&param=B.
  368. // In the case of a message type, each field of the message is mapped to a separate parameter,
  369. // such as ...?foo.a=A&foo.b=B&foo.c=C.
  370. // There are exceptions:
  371. // - for wrapper types it will use the same representation as the wrapped primitive type in JSON
  372. // - for google.protobuf.timestamp type it will be serialized as a string
  373. //
  374. // maps, Struct and Empty can NOT be used
  375. // messages can have any number of sub messages - including circular (e.g. sub.subsub.sub.subsub.id)
  376. // buildQueryParamsV3 extracts any valid query params, including sub and recursive messages
  377. func (g *OpenAPIv3Generator) buildQueryParamsV3(field *protogen.Field) []*v3.ParameterOrReference {
  378. depths := map[string]int{}
  379. return g._buildQueryParamsV3(field, depths)
  380. }
  381. // depths are used to keep track of how many times a message's fields has been seen
  382. func (g *OpenAPIv3Generator) _buildQueryParamsV3(field *protogen.Field, depths map[string]int) []*v3.ParameterOrReference {
  383. parameters := []*v3.ParameterOrReference{}
  384. queryFieldName := g.reflect.formatFieldName(field.Desc)
  385. fieldDescription := g.filterCommentString(field.Comments.Leading)
  386. if field.Desc.IsMap() {
  387. // Map types are not allowed in query parameteres
  388. return parameters
  389. } else if field.Desc.Kind() == protoreflect.MessageKind {
  390. typeName := g.reflect.fullMessageTypeName(field.Desc.Message())
  391. switch typeName {
  392. case ".google.protobuf.Value":
  393. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  394. parameters = append(parameters,
  395. &v3.ParameterOrReference{
  396. Oneof: &v3.ParameterOrReference_Parameter{
  397. Parameter: &v3.Parameter{
  398. Name: queryFieldName,
  399. In: "query",
  400. Description: fieldDescription,
  401. Required: false,
  402. Schema: fieldSchema,
  403. },
  404. },
  405. })
  406. return parameters
  407. case ".google.protobuf.BoolValue", ".google.protobuf.BytesValue", ".google.protobuf.Int32Value", ".google.protobuf.UInt32Value",
  408. ".google.protobuf.StringValue", ".google.protobuf.Int64Value", ".google.protobuf.UInt64Value", ".google.protobuf.FloatValue",
  409. ".google.protobuf.DoubleValue":
  410. valueField := getValueField(field.Message.Desc)
  411. fieldSchema := g.reflect.schemaOrReferenceForField(valueField)
  412. parameters = append(parameters,
  413. &v3.ParameterOrReference{
  414. Oneof: &v3.ParameterOrReference_Parameter{
  415. Parameter: &v3.Parameter{
  416. Name: queryFieldName,
  417. In: "query",
  418. Description: fieldDescription,
  419. Required: false,
  420. Schema: fieldSchema,
  421. },
  422. },
  423. })
  424. return parameters
  425. case ".google.protobuf.Timestamp":
  426. fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  427. parameters = append(parameters,
  428. &v3.ParameterOrReference{
  429. Oneof: &v3.ParameterOrReference_Parameter{
  430. Parameter: &v3.Parameter{
  431. Name: queryFieldName,
  432. In: "query",
  433. Description: fieldDescription,
  434. Required: false,
  435. Schema: fieldSchema,
  436. },
  437. },
  438. })
  439. return parameters
  440. case ".google.protobuf.Duration":
  441. fieldSchema := g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  442. parameters = append(parameters,
  443. &v3.ParameterOrReference{
  444. Oneof: &v3.ParameterOrReference_Parameter{
  445. Parameter: &v3.Parameter{
  446. Name: queryFieldName,
  447. In: "query",
  448. Description: fieldDescription,
  449. Required: false,
  450. Schema: fieldSchema,
  451. },
  452. },
  453. })
  454. return parameters
  455. }
  456. if field.Desc.IsList() {
  457. // Only non-repeated message types are valid
  458. return parameters
  459. }
  460. // Represent field masks directly as strings (don't expand them).
  461. if typeName == ".google.protobuf.FieldMask" {
  462. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  463. parameters = append(parameters,
  464. &v3.ParameterOrReference{
  465. Oneof: &v3.ParameterOrReference_Parameter{
  466. Parameter: &v3.Parameter{
  467. Name: queryFieldName,
  468. In: "query",
  469. Description: fieldDescription,
  470. Required: false,
  471. Schema: fieldSchema,
  472. },
  473. },
  474. })
  475. return parameters
  476. }
  477. // Sub messages are allowed, even circular, as long as the final type is a primitive.
  478. // Go through each of the sub message fields
  479. for _, subField := range field.Message.Fields {
  480. subFieldFullName := string(subField.Desc.FullName())
  481. seen, ok := depths[subFieldFullName]
  482. if !ok {
  483. depths[subFieldFullName] = 0
  484. }
  485. if seen < *g.conf.CircularDepth {
  486. depths[subFieldFullName]++
  487. subParams := g._buildQueryParamsV3(subField, depths)
  488. for _, subParam := range subParams {
  489. if param, ok := subParam.Oneof.(*v3.ParameterOrReference_Parameter); ok {
  490. param.Parameter.Name = queryFieldName + "." + param.Parameter.Name
  491. parameters = append(parameters, subParam)
  492. }
  493. }
  494. }
  495. }
  496. } else if field.Desc.Kind() != protoreflect.GroupKind {
  497. // schemaOrReferenceForField also handles array types
  498. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  499. parameters = append(parameters,
  500. &v3.ParameterOrReference{
  501. Oneof: &v3.ParameterOrReference_Parameter{
  502. Parameter: &v3.Parameter{
  503. Name: queryFieldName,
  504. In: "query",
  505. Description: fieldDescription,
  506. Required: false,
  507. Schema: fieldSchema,
  508. },
  509. },
  510. })
  511. }
  512. return parameters
  513. }
  514. // buildOperationV3 constructs an operation for a set of values.
  515. func (g *OpenAPIv3Generator) buildOperationV3(
  516. d *v3.Document,
  517. operationID string,
  518. tagName string,
  519. description string,
  520. defaultHost string,
  521. path string,
  522. bodyField string,
  523. inputMessage *protogen.Message,
  524. outputMessage *protogen.Message,
  525. ) (*v3.Operation, string) {
  526. // coveredParameters tracks the parameters that have been used in the body or path.
  527. coveredParameters := make([]string, 0)
  528. if bodyField != "" {
  529. coveredParameters = append(coveredParameters, bodyField)
  530. }
  531. // Initialize the list of operation parameters.
  532. parameters := []*v3.ParameterOrReference{}
  533. // Find simple path parameters like {id}
  534. if allMatches := g.pathPattern.FindAllStringSubmatch(path, -1); allMatches != nil {
  535. for _, matches := range allMatches {
  536. // Add the value to the list of covered parameters.
  537. coveredParameters = append(coveredParameters, matches[1])
  538. pathParameter := g.findAndFormatFieldName(matches[1], inputMessage)
  539. path = strings.Replace(path, matches[1], pathParameter, 1)
  540. // Add the path parameters to the operation parameters.
  541. var fieldSchema *v3.SchemaOrReference
  542. var fieldDescription string
  543. field := g.findField(pathParameter, inputMessage)
  544. if field != nil {
  545. fieldSchema = g.reflect.schemaOrReferenceForField(field.Desc)
  546. fieldDescription = g.filterCommentString(field.Comments.Leading)
  547. } else {
  548. // If field does not exist, it is safe to set it to string, as it is ignored downstream
  549. fieldSchema = &v3.SchemaOrReference{
  550. Oneof: &v3.SchemaOrReference_Schema{
  551. Schema: &v3.Schema{
  552. Type: "string",
  553. },
  554. },
  555. }
  556. }
  557. parameters = append(parameters,
  558. &v3.ParameterOrReference{
  559. Oneof: &v3.ParameterOrReference_Parameter{
  560. Parameter: &v3.Parameter{
  561. Name: pathParameter,
  562. In: "path",
  563. Description: fieldDescription,
  564. Required: true,
  565. Schema: fieldSchema,
  566. },
  567. },
  568. })
  569. }
  570. }
  571. // Find named path parameters like {name=shelves/*}
  572. if matches := g.namedPathPattern.FindStringSubmatch(path); matches != nil {
  573. // Build a list of named path parameters.
  574. namedPathParameters := make([]string, 0)
  575. // Add the "name=" "name" value to the list of covered parameters.
  576. coveredParameters = append(coveredParameters, matches[1])
  577. // Convert the path from the starred form to use named path parameters.
  578. starredPath := matches[2]
  579. parts := strings.Split(starredPath, "/")
  580. // The starred path is assumed to be in the form "things/*/otherthings/*".
  581. // We want to convert it to "things/{thingsId}/otherthings/{otherthingsId}".
  582. for i := 0; i < len(parts)-1; i += 2 {
  583. section := parts[i]
  584. namedPathParameter := g.findAndFormatFieldName(section, inputMessage)
  585. namedPathParameter = singular(namedPathParameter)
  586. parts[i+1] = "{" + namedPathParameter + "}"
  587. namedPathParameters = append(namedPathParameters, namedPathParameter)
  588. }
  589. // Rewrite the path to use the path parameters.
  590. newPath := strings.Join(parts, "/")
  591. path = strings.Replace(path, matches[0], newPath, 1)
  592. // Add the named path parameters to the operation parameters.
  593. for _, namedPathParameter := range namedPathParameters {
  594. parameters = append(parameters,
  595. &v3.ParameterOrReference{
  596. Oneof: &v3.ParameterOrReference_Parameter{
  597. Parameter: &v3.Parameter{
  598. Name: namedPathParameter,
  599. In: "path",
  600. Required: true,
  601. Description: "The " + namedPathParameter + " id.",
  602. Schema: &v3.SchemaOrReference{
  603. Oneof: &v3.SchemaOrReference_Schema{
  604. Schema: &v3.Schema{
  605. Type: "string",
  606. },
  607. },
  608. },
  609. },
  610. },
  611. })
  612. }
  613. }
  614. // Add any unhandled fields in the request message as query parameters.
  615. if bodyField != "*" && string(inputMessage.Desc.FullName()) != "google.api.HttpBody" {
  616. for _, field := range inputMessage.Fields {
  617. fieldName := string(field.Desc.Name())
  618. if !contains(coveredParameters, fieldName) && fieldName != bodyField {
  619. fieldParams := g.buildQueryParamsV3(field)
  620. parameters = append(parameters, fieldParams...)
  621. }
  622. }
  623. }
  624. // Create the response.
  625. name, content := g.reflect.responseContentForMessage(outputMessage.Desc)
  626. responses := &v3.Responses{
  627. ResponseOrReference: []*v3.NamedResponseOrReference{
  628. {
  629. Name: name,
  630. Value: &v3.ResponseOrReference{
  631. Oneof: &v3.ResponseOrReference_Response{
  632. Response: &v3.Response{
  633. Description: "OK",
  634. Content: content,
  635. },
  636. },
  637. },
  638. },
  639. },
  640. }
  641. // Add the default reponse if needed
  642. if *g.conf.DefaultResponse {
  643. anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
  644. anySchema := wk.NewGoogleProtobufAnySchema(anySchemaName)
  645. g.addSchemaToDocumentV3(d, anySchema)
  646. statusSchemaName := g.reflect.formatMessageName(statusProtoDesc)
  647. statusSchema := wk.NewGoogleRpcStatusSchema(statusSchemaName, anySchemaName)
  648. g.addSchemaToDocumentV3(d, statusSchema)
  649. defaultResponse := &v3.NamedResponseOrReference{
  650. Name: "default",
  651. Value: &v3.ResponseOrReference{
  652. Oneof: &v3.ResponseOrReference_Response{
  653. Response: &v3.Response{
  654. Description: "Default error response",
  655. Content: wk.NewApplicationJsonMediaType(&v3.SchemaOrReference{
  656. Oneof: &v3.SchemaOrReference_Reference{
  657. Reference: &v3.Reference{XRef: "#/components/schemas/" + statusSchemaName}}}),
  658. },
  659. },
  660. },
  661. }
  662. responses.ResponseOrReference = append(responses.ResponseOrReference, defaultResponse)
  663. }
  664. // Create the operation.
  665. op := &v3.Operation{
  666. Tags: []string{tagName},
  667. Description: description,
  668. OperationId: operationID,
  669. Parameters: parameters,
  670. Responses: responses,
  671. }
  672. if defaultHost != "" {
  673. hostURL, err := url.Parse(defaultHost)
  674. if err == nil {
  675. hostURL.Scheme = "https"
  676. op.Servers = append(op.Servers, &v3.Server{Url: hostURL.String()})
  677. }
  678. }
  679. // If a body field is specified, we need to pass a message as the request body.
  680. if bodyField != "" {
  681. var requestSchema *v3.SchemaOrReference
  682. if bodyField == "*" {
  683. // Pass the entire request message as the request body.
  684. requestSchema = g.reflect.schemaOrReferenceForMessage(inputMessage.Desc)
  685. } else {
  686. // If body refers to a message field, use that type.
  687. for _, field := range inputMessage.Fields {
  688. if string(field.Desc.Name()) == bodyField {
  689. switch field.Desc.Kind() {
  690. case protoreflect.StringKind:
  691. requestSchema = &v3.SchemaOrReference{
  692. Oneof: &v3.SchemaOrReference_Schema{
  693. Schema: &v3.Schema{
  694. Type: "string",
  695. },
  696. },
  697. }
  698. case protoreflect.MessageKind:
  699. requestSchema = g.reflect.schemaOrReferenceForMessage(field.Message.Desc)
  700. default:
  701. log.Printf("unsupported field type %+v", field.Desc)
  702. }
  703. break
  704. }
  705. }
  706. }
  707. op.RequestBody = &v3.RequestBodyOrReference{
  708. Oneof: &v3.RequestBodyOrReference_RequestBody{
  709. RequestBody: &v3.RequestBody{
  710. Required: true,
  711. Content: &v3.MediaTypes{
  712. AdditionalProperties: []*v3.NamedMediaType{
  713. {
  714. Name: "application/json",
  715. Value: &v3.MediaType{
  716. Schema: requestSchema,
  717. },
  718. },
  719. },
  720. },
  721. },
  722. },
  723. }
  724. }
  725. return op, path
  726. }
  727. // addOperationToDocumentV3 adds an operation to the specified path/method.
  728. func (g *OpenAPIv3Generator) addOperationToDocumentV3(d *v3.Document, op *v3.Operation, path string, methodName string) {
  729. var selectedPathItem *v3.NamedPathItem
  730. for _, namedPathItem := range d.Paths.Path {
  731. if namedPathItem.Name == path {
  732. selectedPathItem = namedPathItem
  733. break
  734. }
  735. }
  736. // If we get here, we need to create a path item.
  737. if selectedPathItem == nil {
  738. selectedPathItem = &v3.NamedPathItem{Name: path, Value: &v3.PathItem{}}
  739. d.Paths.Path = append(d.Paths.Path, selectedPathItem)
  740. }
  741. // Set the operation on the specified method.
  742. switch methodName {
  743. case "GET":
  744. selectedPathItem.Value.Get = op
  745. case "POST":
  746. selectedPathItem.Value.Post = op
  747. case "PUT":
  748. selectedPathItem.Value.Put = op
  749. case "DELETE":
  750. selectedPathItem.Value.Delete = op
  751. case "PATCH":
  752. selectedPathItem.Value.Patch = op
  753. case http2.MethodHead:
  754. selectedPathItem.Value.Head = op
  755. case http2.MethodOptions:
  756. selectedPathItem.Value.Options = op
  757. case http2.MethodTrace:
  758. selectedPathItem.Value.Trace = op
  759. }
  760. }
  761. // addPathsToDocumentV3 adds paths from a specified file descriptor.
  762. func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*protogen.Service) {
  763. for _, service := range services {
  764. annotationsCount := 0
  765. for _, method := range service.Methods {
  766. comment := g.filterCommentString(method.Comments.Leading)
  767. inputMessage := method.Input
  768. outputMessage := method.Output
  769. operationID := service.GoName + "_" + method.GoName
  770. extOperation := proto.GetExtension(method.Desc.Options(), v3.E_Operation)
  771. if extOperation == nil || extOperation == v3.E_Operation.InterfaceOf(v3.E_Operation.Zero()) {
  772. continue
  773. }
  774. var path string
  775. var httpMethod string
  776. var bodyField string
  777. httpOperation := proto.GetExtension(method.Desc.Options(), annotations.E_Http)
  778. if httpOperation != nil && httpOperation != annotations.E_Http.InterfaceOf(annotations.E_Http.Zero()) {
  779. _httpOperation := httpOperation.(*annotations.HttpRule)
  780. switch httpRule := _httpOperation.GetPattern().(type) {
  781. case *annotations.HttpRule_Post:
  782. path = httpRule.Post
  783. httpMethod = http2.MethodPost
  784. bodyField = _httpOperation.GetBody()
  785. case *annotations.HttpRule_Get:
  786. path = httpRule.Get
  787. httpMethod = http2.MethodGet
  788. bodyField = ""
  789. case *annotations.HttpRule_Delete:
  790. path = httpRule.Delete
  791. httpMethod = http2.MethodDelete
  792. bodyField = ""
  793. case *annotations.HttpRule_Put:
  794. path = httpRule.Put
  795. httpMethod = http2.MethodPut
  796. bodyField = _httpOperation.GetBody()
  797. case *annotations.HttpRule_Patch:
  798. path = httpRule.Patch
  799. httpMethod = http2.MethodPatch
  800. bodyField = _httpOperation.GetBody()
  801. case *annotations.HttpRule_Custom:
  802. path = httpRule.Custom.Path
  803. httpMethod = httpRule.Custom.Kind
  804. bodyField = _httpOperation.GetBody()
  805. }
  806. }
  807. annotationsCount++
  808. if path == "" {
  809. path = handler.PathGenerator(string(service.Desc.FullName()), method.GoName)
  810. }
  811. if httpMethod == "" {
  812. httpMethod = http2.MethodPost
  813. }
  814. if bodyField == "" && (httpMethod == http2.MethodPost || httpMethod == http2.MethodPut || httpMethod == http2.MethodPatch) {
  815. bodyField = "*"
  816. }
  817. defaultHost := proto.GetExtension(service.Desc.Options(), annotations.E_DefaultHost).(string)
  818. op, path2 := g.buildOperationV3(
  819. d, operationID, d.Info.Title, comment, defaultHost, path, bodyField, inputMessage, outputMessage)
  820. // Merge any `Operation` annotations with the current
  821. proto.Merge(op, extOperation.(*v3.Operation))
  822. g.addOperationToDocumentV3(d, op, path2, httpMethod)
  823. }
  824. if annotationsCount > 0 {
  825. d.Tags = append(d.Tags, &v3.Tag{Name: d.Info.Title, Description: d.Info.Description})
  826. }
  827. }
  828. }
  829. // addSchemaForMessageToDocumentV3 adds the schema to the document if required
  830. func (g *OpenAPIv3Generator) addSchemaToDocumentV3(d *v3.Document, schema *v3.NamedSchemaOrReference) {
  831. if contains(g.generatedSchemas, schema.Name) {
  832. return
  833. }
  834. g.generatedSchemas = append(g.generatedSchemas, schema.Name)
  835. d.Components.Schemas.AdditionalProperties = append(d.Components.Schemas.AdditionalProperties, schema)
  836. }
  837. // addSchemasForMessagesToDocumentV3 adds info from one file descriptor.
  838. func (g *OpenAPIv3Generator) addSchemasForMessagesToDocumentV3(d *v3.Document, messages []*protogen.Message, edition descriptorpb.Edition) {
  839. // For each message, generate a definition.
  840. for _, message := range messages {
  841. if message.Messages != nil {
  842. g.addSchemasForMessagesToDocumentV3(d, message.Messages, edition)
  843. }
  844. schemaName := g.reflect.formatMessageName(message.Desc)
  845. // Only generate this if we need it and haven't already generated it.
  846. if !contains(g.reflect.requiredSchemas, schemaName) ||
  847. contains(g.generatedSchemas, schemaName) {
  848. continue
  849. }
  850. typeName := g.reflect.fullMessageTypeName(message.Desc)
  851. messageDescription := g.filterCommentString(message.Comments.Leading)
  852. // `google.protobuf.Value` and `google.protobuf.Any` have special JSON transcoding
  853. // so we can't just reflect on the message descriptor.
  854. if typeName == ".google.protobuf.Value" {
  855. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufValueSchema(schemaName))
  856. continue
  857. } else if typeName == ".google.protobuf.Any" {
  858. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(schemaName))
  859. continue
  860. } else if typeName == ".google.rpc.Status" {
  861. anySchemaName := g.reflect.formatMessageName(anyProtoDesc)
  862. g.addSchemaToDocumentV3(d, wk.NewGoogleProtobufAnySchema(anySchemaName))
  863. g.addSchemaToDocumentV3(d, wk.NewGoogleRpcStatusSchema(schemaName, anySchemaName))
  864. continue
  865. }
  866. // Build an array holding the fields of the message.
  867. definitionProperties := &v3.Properties{
  868. AdditionalProperties: make([]*v3.NamedSchemaOrReference, 0),
  869. }
  870. var required []string
  871. for _, field := range message.Fields {
  872. // Get the field description from the comments.
  873. description := g.filterCommentString(field.Comments.Leading)
  874. // Check the field annotations to see if this is a readonly or writeonly field.
  875. inputOnly := false
  876. outputOnly := false
  877. isRequired := true
  878. extension := proto.GetExtension(field.Desc.Options(), annotations.E_FieldBehavior)
  879. if extension != nil {
  880. switch v := extension.(type) {
  881. case []annotations.FieldBehavior:
  882. for _, vv := range v {
  883. switch vv {
  884. case annotations.FieldBehavior_OUTPUT_ONLY:
  885. outputOnly = true
  886. case annotations.FieldBehavior_INPUT_ONLY:
  887. inputOnly = true
  888. case annotations.FieldBehavior_OPTIONAL:
  889. isRequired = false
  890. }
  891. }
  892. default:
  893. log.Printf("unsupported extension type %T", extension)
  894. }
  895. }
  896. if edition == descriptorpb.Edition_EDITION_2023 {
  897. if fieldOptions, ok := field.Desc.Options().(*descriptorpb.FieldOptions); ok {
  898. if fieldOptions.GetFeatures().GetFieldPresence() == descriptorpb.FeatureSet_EXPLICIT {
  899. isRequired = false
  900. }
  901. }
  902. }
  903. if isRequired {
  904. required = append(required, g.reflect.formatFieldName(field.Desc))
  905. }
  906. // The field is either described by a reference or a schema.
  907. fieldSchema := g.reflect.schemaOrReferenceForField(field.Desc)
  908. if fieldSchema == nil {
  909. continue
  910. }
  911. // If this field has siblings and is a $ref now, create a new schema use `allOf` to wrap it
  912. wrapperNeeded := inputOnly || outputOnly || description != ""
  913. if wrapperNeeded {
  914. if _, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Reference); ok {
  915. fieldSchema = &v3.SchemaOrReference{Oneof: &v3.SchemaOrReference_Schema{Schema: &v3.Schema{
  916. AllOf: []*v3.SchemaOrReference{fieldSchema},
  917. }}}
  918. }
  919. }
  920. if schema, ok := fieldSchema.Oneof.(*v3.SchemaOrReference_Schema); ok {
  921. schema.Schema.Description = description
  922. schema.Schema.ReadOnly = outputOnly
  923. schema.Schema.WriteOnly = inputOnly
  924. // Merge any `Property` annotations with the current
  925. extProperty := proto.GetExtension(field.Desc.Options(), v3.E_Property)
  926. if extProperty != nil {
  927. proto.Merge(schema.Schema, extProperty.(*v3.Schema))
  928. }
  929. }
  930. definitionProperties.AdditionalProperties = append(
  931. definitionProperties.AdditionalProperties,
  932. &v3.NamedSchemaOrReference{
  933. Name: g.reflect.formatFieldName(field.Desc),
  934. Value: fieldSchema,
  935. },
  936. )
  937. }
  938. schema := &v3.Schema{
  939. Type: "object",
  940. Description: messageDescription,
  941. Properties: definitionProperties,
  942. Required: required,
  943. }
  944. // Merge any `Schema` annotations with the current
  945. extSchema := proto.GetExtension(message.Desc.Options(), v3.E_Schema)
  946. if extSchema != nil {
  947. proto.Merge(schema, extSchema.(*v3.Schema))
  948. }
  949. // Add the schema to the components.schema list.
  950. g.addSchemaToDocumentV3(d, &v3.NamedSchemaOrReference{
  951. Name: schemaName,
  952. Value: &v3.SchemaOrReference{
  953. Oneof: &v3.SchemaOrReference_Schema{
  954. Schema: schema,
  955. },
  956. },
  957. })
  958. }
  959. }