Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 38 additions & 12 deletions codegen/internal/generator/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ type sdkModel struct {

// clientModel encapsulates the data necessary to render a Java API client.
type clientModel struct {
TagName string
TagName string
TagDescriptionLines []string
ClassName string
AccessorName string
FieldName string
Package string
Methods []operationModel
ClassName string
AccessorName string
FieldName string
Package string
Methods []operationModel
}

// operationModel stores the derived metadata for one OpenAPI operation.
Expand Down Expand Up @@ -242,9 +242,9 @@ func convertOperation(method, path string, item *v3.PathItem, op *v3.Operation,
}

params := collectParameters(item, op)
model.PathParams = filterParams(params, parameterInPath, resolver)
model.PathParams = filterParams(params, parameterInPath, resolver, sanitizedID)

queryParams := filterParams(params, parameterInQuery, resolver)
queryParams := filterParams(params, parameterInQuery, resolver, sanitizedID)
model.RequiredQueryParams, model.OptionalQueryParams = splitParams(queryParams)
if len(model.OptionalQueryParams) > 0 {
model.QueryStruct = &parameterGroupModel{
Expand All @@ -260,7 +260,7 @@ func convertOperation(method, path string, item *v3.PathItem, op *v3.Operation,
}
model.HasQueryParams = len(model.RequiredQueryParams) > 0 || len(model.OptionalQueryParams) > 0

headerParams := filterParams(params, parameterInHeader, resolver)
headerParams := filterParams(params, parameterInHeader, resolver, sanitizedID)
model.RequiredHeaderParams, model.OptionalHeaderParams = splitParams(headerParams)
if len(model.OptionalHeaderParams) > 0 {
model.HeaderStruct = &parameterGroupModel{
Expand Down Expand Up @@ -356,7 +356,7 @@ func collectParameters(item *v3.PathItem, op *v3.Operation) []*v3.Parameter {

// filterParams keeps parameters that match the requested location and converts
// them into parameterModel values.
func filterParams(params []*v3.Parameter, location string, resolver *typeResolver) []parameterModel {
func filterParams(params []*v3.Parameter, location string, resolver *typeResolver, operationContext string) []parameterModel {
var filtered []parameterModel
seen := map[string]struct{}{}
for _, param := range params {
Expand All @@ -375,7 +375,7 @@ func filterParams(params []*v3.Parameter, location string, resolver *typeResolve
}
seen[name] = struct{}{}
schemaRef := parameterSchema(param)
javaType := resolver.javaType(schemaRef, name)
javaType := resolver.parameterJavaType(schemaRef, operationContext, name)
required := param.Required != nil && *param.Required
filtered = append(filtered, parameterModel{
Name: name,
Expand Down Expand Up @@ -568,6 +568,7 @@ func buildSchemas(doc *v3.Document, params Params, resolver *typeResolver) []sch
imports := sortedImports(map[string]struct{}{
"com.fasterxml.jackson.annotation.JsonCreator": {},
"com.fasterxml.jackson.annotation.JsonValue": {},
"java.util.Objects": {},
})
result = append(result, schemaModel{
Name: name,
Expand Down Expand Up @@ -630,7 +631,7 @@ func buildSchemaFields(name string, ref *base.SchemaProxy, resolver *typeResolve
if schema == nil {
return []schemaField{{Name: "value", Type: "Object"}}, nil, nil, false
}
if schemaHasType(schema, "object") {
if schemaDefinesModelFields(schema) {
props := collectProperties(schema)
if len(props) == 0 {
return []schemaField{{Name: "value", Type: "java.util.Map<String, Object>"}}, nil, []string{"java.util.Map"}, false
Expand Down Expand Up @@ -695,6 +696,31 @@ func buildSchemaFields(name string, ref *base.SchemaProxy, resolver *typeResolve
return fields, additionalProps, sortedImports(imports), hasRequired
}

func schemaDefinesModelFields(schema *base.Schema) bool {
if schema == nil {
return false
}
return schemaHasType(schema, "object") || schemaHasFlattenableProperties(schema)
}

func schemaHasFlattenableProperties(schema *base.Schema) bool {
if schema == nil {
return false
}
if schema.Properties != nil && schema.Properties.Len() > 0 {
return true
}
for _, item := range schema.AllOf {
if item == nil {
continue
}
if schemaHasFlattenableProperties(item.Schema()) {
return true
}
}
return false
}

// resolveAdditionalProperties returns metadata for schemas that allow
// additional fields alongside declared properties.
func resolveAdditionalProperties(schema *base.Schema, resolver *typeResolver, context ...string) *additionalPropertiesModel {
Expand Down
177 changes: 177 additions & 0 deletions codegen/internal/generator/model_naming_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package generator

import (
"context"
"os"
"path/filepath"
"testing"
)

func TestGenerateModelFlattensAllOfIntoComponent(t *testing.T) {
t.Parallel()

tmp := t.TempDir()
specPath := filepath.Join(tmp, "openapi.json")
outputDir := filepath.Join(tmp, "src", "main", "java")
resourceDir := filepath.Join(tmp, "src", "main", "resources")

spec := `{
"openapi": "3.0.3",
"info": {
"title": "test",
"version": "1.0.0"
},
"paths": {},
"components": {
"schemas": {
"Merchant": {
"allOf": [
{
"type": "object",
"required": ["merchant_code"],
"properties": {
"merchant_code": { "type": "string" }
}
},
{
"$ref": "#/components/schemas/Timestamps"
}
],
"title": "Merchant"
},
"Timestamps": {
"type": "object",
"properties": {
"created_at": { "type": "string", "format": "date-time", "readOnly": true }
}
}
}
}
}`
if err := os.WriteFile(specPath, []byte(spec), 0o644); err != nil {
t.Fatalf("write spec: %v", err)
}

params := Params{
SpecPath: specPath,
OutputDir: outputDir,
ResourceDir: resourceDir,
BasePackage: "com.test.sdk",
}
if err := Run(context.Background(), params); err != nil {
t.Fatalf("run generator: %v", err)
}

merchantPath := filepath.Join(outputDir, "com", "test", "sdk", "models", "Merchant.java")
content, err := os.ReadFile(merchantPath)
if err != nil {
t.Fatalf("read generated Merchant model: %v", err)
}
generated := string(content)

assertContains(t, generated, "public record Merchant(")
assertContains(t, generated, "String merchantCode")
assertContains(t, generated, "java.time.OffsetDateTime createdAt")
assertNotContains(t, generated, "Merchant2")
assertFileDoesNotExist(t, filepath.Join(outputDir, "com", "test", "sdk", "models", "Merchant2.java"))
}

func TestGenerateClientUsesOpenEnumsForInlineParameterEnums(t *testing.T) {
t.Parallel()

tmp := t.TempDir()
specPath := filepath.Join(tmp, "openapi.json")
outputDir := filepath.Join(tmp, "src", "main", "java")
resourceDir := filepath.Join(tmp, "src", "main", "resources")

spec := `{
"openapi": "3.0.3",
"info": {
"title": "test",
"version": "1.0.0"
},
"paths": {
"/v1/transactions": {
"get": {
"operationId": "ListTransactions",
"parameters": [
{
"name": "order",
"in": "query",
"schema": {
"type": "string",
"enum": ["ascending", "descending"]
}
},
{
"name": "types",
"in": "query",
"schema": {
"type": "array",
"items": {
"type": "string",
"enum": ["PAYMENT", "REFUND"]
}
}
}
],
"responses": {
"204": {
"description": "No content"
}
},
"tags": ["Transactions"]
}
}
}
}`
if err := os.WriteFile(specPath, []byte(spec), 0o644); err != nil {
t.Fatalf("write spec: %v", err)
}

params := Params{
SpecPath: specPath,
OutputDir: outputDir,
ResourceDir: resourceDir,
BasePackage: "com.test.sdk",
}
if err := Run(context.Background(), params); err != nil {
t.Fatalf("run generator: %v", err)
}

clientPath := filepath.Join(outputDir, "com", "test", "sdk", "clients", "TransactionsClient.java")
content, err := os.ReadFile(clientPath)
if err != nil {
t.Fatalf("read generated Transactions client: %v", err)
}
generated := string(content)

assertContains(t, generated, "public ListTransactionsQueryParams order(com.test.sdk.models.ListTransactionsOrder value)")
assertContains(t, generated, "public ListTransactionsQueryParams types(java.util.List<com.test.sdk.models.ListTransactionsTypesItem> value)")
assertContains(t, generated, "com.test.sdk.models.ListTransactionsOrder")
assertContains(t, generated, "com.test.sdk.models.ListTransactionsTypesItem")
assertFileDoesNotExist(t, filepath.Join(outputDir, "com", "test", "sdk", "models", "Order.java"))
assertFileDoesNotExist(t, filepath.Join(outputDir, "com", "test", "sdk", "models", "TypesItem.java"))

orderPath := filepath.Join(outputDir, "com", "test", "sdk", "models", "ListTransactionsOrder.java")
orderContent, err := os.ReadFile(orderPath)
if err != nil {
t.Fatalf("read generated ListTransactionsOrder model: %v", err)
}
orderGenerated := string(orderContent)

assertContains(t, orderGenerated, "public final class ListTransactionsOrder")
assertContains(t, orderGenerated, `public static final ListTransactionsOrder ASCENDING = new ListTransactionsOrder("ascending");`)
assertContains(t, orderGenerated, "public static ListTransactionsOrder of(String value)")
assertContains(t, orderGenerated, "return value == null ? null : new ListTransactionsOrder(value);")
assertNotContains(t, orderGenerated, "public enum ListTransactionsOrder")
}

func assertFileDoesNotExist(t *testing.T, path string) {
t.Helper()
if _, err := os.Stat(path); err == nil {
t.Fatalf("expected %s to not exist", path)
} else if !os.IsNotExist(err) {
t.Fatalf("stat %s: %v", path, err)
}
}
41 changes: 27 additions & 14 deletions codegen/internal/generator/templates/model.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,25 @@ package {{ .Package }};
*/
{{- end }}
{{- if .IsEnum }}
public enum {{ .ClassName }} {
{{- range $index, $value := .EnumValues }}{{ if $index }},
{{ end }} {{ $value.Name }}({{ quote $value.WireValue }}){{ end }};
public final class {{ .ClassName }} {
{{- range .EnumValues }}
public static final {{ $.ClassName }} {{ .Name }} = new {{ $.ClassName }}({{ quote .WireValue }});
{{- end }}

private final String value;

{{ .ClassName }}(String value) {
this.value = value;
private {{ .ClassName }}(String value) {
this.value = Objects.requireNonNull(value, "value");
}

/**
* Creates a {{ .ClassName }} for a value not yet known to this SDK version.
*
* @param value Wire value sent to or received from the API.
* @return Open enum value wrapping {@code value}.
*/
public static {{ .ClassName }} of(String value) {
return new {{ .ClassName }}(value);
}

@JsonValue
Expand All @@ -35,15 +46,17 @@ public enum {{ .ClassName }} {

@JsonCreator
public static {{ .ClassName }} fromValue(String value) {
if (value == null) {
return null;
}
for ({{ .ClassName }} entry : values()) {
if (entry.value.equals(value)) {
return entry;
}
}
throw new IllegalArgumentException("Unknown {{ .ClassName }} value: " + value);
return value == null ? null : new {{ .ClassName }}(value);
}

@Override
public boolean equals(Object other) {
return this == other || (other instanceof {{ .ClassName }} that && this.value.equals(that.value));
}

@Override
public int hashCode() {
return value.hashCode();
}
}
{{- else }}
Expand Down
14 changes: 7 additions & 7 deletions codegen/internal/generator/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,8 @@ func (r *typeResolver) javaType(ref *base.SchemaProxy, context ...string) javaTy
if schema.Properties != nil && schema.Properties.Len() > 0 {
return r.inlineObjectType(schema, context)
}
if len(schema.AllOf) > 0 {
for _, item := range schema.AllOf {
if item == nil {
continue
}
return r.javaType(item, context...)
}
if schemaHasFlattenableProperties(schema) {
return r.inlineObjectType(schema, context)
}
if len(schema.OneOf) > 0 {
for _, item := range schema.OneOf {
Expand All @@ -173,6 +168,10 @@ func (r *typeResolver) javaType(ref *base.SchemaProxy, context ...string) javaTy
return r.genericMap()
}

func (r *typeResolver) parameterJavaType(ref *base.SchemaProxy, context ...string) javaType {
return r.javaType(ref, context...)
}

// objectType handles schemas that look like objects by either emitting inline
// models or falling back to generic map types.
func (r *typeResolver) objectType(schema *base.Schema, context []string) javaType {
Expand Down Expand Up @@ -331,6 +330,7 @@ func (r *typeResolver) inlineSchemaModels(params Params) []schemaModel {
imports := sortedImports(map[string]struct{}{
"com.fasterxml.jackson.annotation.JsonCreator": {},
"com.fasterxml.jackson.annotation.JsonValue": {},
"java.util.Objects": {},
})
models = append(models, schemaModel{
Name: info.className,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ public ListMembershipsQueryParams resourceParentId(String value) {
* Pass explicit null to filter for resources without a parent.
* @return This ListMembershipsQueryParams instance.
*/
public ListMembershipsQueryParams resourceParentType(com.sumup.sdk.models.ResourceType value) {
public ListMembershipsQueryParams resourceParentType(java.util.Map<String, Object> value) {
this.values.put("resource.parent.type", Objects.requireNonNull(value, "resourceParentType"));
return this;
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/sumup/sdk/clients/MembershipsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ public ListMembershipsQueryParams resourceParentId(String value) {
* Pass explicit null to filter for resources without a parent.
* @return This ListMembershipsQueryParams instance.
*/
public ListMembershipsQueryParams resourceParentType(com.sumup.sdk.models.ResourceType value) {
public ListMembershipsQueryParams resourceParentType(java.util.Map<String, Object> value) {
this.values.put("resource.parent.type", Objects.requireNonNull(value, "resourceParentType"));
return this;
}
Expand Down
Loading
Loading