Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,8 @@ private MethodBodyStatements BuildMessage(
}
else
{
statements.AddRange(AppendHeaderParameters(request, operation, paramMap));
var contentParam = signature.Parameters.FirstOrDefault(p => p.Name == "content" && p.Location == ParameterLocation.Body);
statements.AddRange(AppendHeaderParameters(request, operation, paramMap, contentParam: contentParam));
statements.AddRange(GetSetContent(request, signature.Parameters));
}

Expand Down Expand Up @@ -377,7 +378,7 @@ private PropertyProvider GetClassifier(InputOperation operation)
throw new InvalidOperationException($"Unexpected status codes for operation {operation.Name}");
}

private IEnumerable<MethodBodyStatement> AppendHeaderParameters(HttpRequestApi request, InputOperation operation, Dictionary<string, ParameterProvider> paramMap, bool isNextLink = false)
private IEnumerable<MethodBodyStatement> AppendHeaderParameters(HttpRequestApi request, InputOperation operation, Dictionary<string, ParameterProvider> paramMap, bool isNextLink = false, ParameterProvider? contentParam = null)
{
List<MethodBodyStatement> statements = new(operation.Parameters.Count);

Expand Down Expand Up @@ -425,6 +426,18 @@ private IEnumerable<MethodBodyStatement> AppendHeaderParameters(HttpRequestApi r
{
statement = BuildQueryOrHeaderOrPathParameterNullCheck(type, valueExpression, statement);
}
// If this is a Content-Type header and there's an optional content parameter, wrap in content null check
else if (inputHeaderParameter.IsContentType && contentParam != null)
{
// Check if any body parameter in the operation is optional
var hasOptionalBody = operation.Parameters.Any(p =>
p is InputBodyParameter bodyParam && !bodyParam.IsRequired);

if (hasOptionalBody)
{
statement = new IfStatement(contentParam.NotEqual(Null)) { statement };
}
}

statements.Add(statement);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1492,5 +1492,97 @@ public void TestApiVersionParameterReinjectedInCreateNextRequestMethod()
Assert.That(file.Content, Contains.Substring("api-version"));
Assert.That(file.Content, Contains.Substring("maxpagesize"));
}

[Test]
public void ContentTypeHeaderWrappedInNullCheckWhenContentIsOptional()
{
// Test that when there's an optional body parameter with a Content-Type header,
// the Content-Type header setting is wrapped in a null check for the content parameter
var contentTypeParam = InputFactory.HeaderParameter(
"Content-Type",
InputFactory.Literal.String("application/json"),
isRequired: true,
isContentType: true,
scope: InputParameterScope.Constant);
var bodyParam = InputFactory.BodyParameter(
"body",
InputPrimitiveType.String,
isRequired: false);
var operation = InputFactory.Operation(
"TestOperation",
requestMediaTypes: ["application/json"],
parameters: [contentTypeParam, bodyParam]);
var inputServiceMethod = InputFactory.BasicServiceMethod("Test", operation);
var inputClient = InputFactory.Client("TestClient", methods: [inputServiceMethod]);
MockHelpers.LoadMockGenerator(clients: () => [inputClient]);

var client = ScmCodeModelGenerator.Instance.TypeFactory.CreateClient(inputClient);
Assert.IsNotNull(client);

var restClient = client!.RestClient;
Assert.IsNotNull(restClient);

var createMethod = restClient.Methods.FirstOrDefault(m => m.Signature.Name == "CreateTestOperationRequest");
Assert.IsNotNull(createMethod, "CreateTestOperationRequest method not found");

var statements = createMethod!.BodyStatements as MethodBodyStatements;
Assert.IsNotNull(statements);

var expectedStatement = @"if ((content != null))
{
request.Headers.Set(""Content-Type"", ""application/json"");
}
";
var statementsString = string.Join("\n", statements!.Select(s => s.ToDisplayString()));
Assert.IsTrue(statements!.Any(s => s.ToDisplayString() == expectedStatement),
$"Expected to find statement:\n{expectedStatement}\nBut got statements:\n{statementsString}");
}

[Test]
public void ContentTypeHeaderNotWrappedInNullCheckWhenContentIsRequired()
{
// Test that when there's a required body parameter with a Content-Type header,
// the Content-Type header setting is NOT wrapped in a null check
var contentTypeParam = InputFactory.HeaderParameter(
"Content-Type",
InputFactory.Literal.String("application/json"),
isRequired: true,
isContentType: true,
scope: InputParameterScope.Constant);
var bodyParam = InputFactory.BodyParameter(
"body",
InputPrimitiveType.String,
isRequired: true);
var operation = InputFactory.Operation(
"TestOperation",
requestMediaTypes: ["application/json"],
parameters: [contentTypeParam, bodyParam]);
var inputServiceMethod = InputFactory.BasicServiceMethod("Test", operation);
var inputClient = InputFactory.Client("TestClient", methods: [inputServiceMethod]);
MockHelpers.LoadMockGenerator(clients: () => [inputClient]);

var client = ScmCodeModelGenerator.Instance.TypeFactory.CreateClient(inputClient);
Assert.IsNotNull(client);

var restClient = client!.RestClient;
Assert.IsNotNull(restClient);

var createMethod = restClient.Methods.FirstOrDefault(m => m.Signature.Name == "CreateTestOperationRequest");
Assert.IsNotNull(createMethod, "CreateTestOperationRequest method not found");

var statements = createMethod!.BodyStatements as MethodBodyStatements;
Assert.IsNotNull(statements);

// Verify there's no if statement wrapping the Content-Type header
var wrappedStatement = @"if ((content != null))
{
request.Headers.Set(""Content-Type"", ""application/json"");
}
";
var statementsString = string.Join("\n", statements!.Select(s => s.ToDisplayString()));
var hasIfWrappedContentType = statements!.Any(s => s.ToDisplayString().Contains(wrappedStatement));
Assert.IsFalse(hasIfWrappedContentType,
$"Content-Type should NOT be wrapped in an if statement for required content, but found:\n{statementsString}");
}
}
}
Loading