Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/http-client-csharp/docs/emitter.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ Set to `false` to skip generation of convenience methods. The default value is `

### `unreferenced-types-handling`

**Type:** `"removeOrInternalize" | "internalize" | "keepAll"`
**Type:** `"internalize" | "keepAll"`

Defines the strategy on how to handle unreferenced types. The default value is `removeOrInternalize`.
Defines the strategy on how to handle unreferenced types. The default value is `internalize`.

### `new-project`

Expand Down
6 changes: 3 additions & 3 deletions packages/http-client-csharp/emitter/src/options.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type ApiVersionSelection = string | Record<string, string>;

export interface CSharpEmitterOptions {
"api-version"?: ApiVersionSelection;
"unreferenced-types-handling"?: "removeOrInternalize" | "internalize" | "keepAll";
"unreferenced-types-handling"?: "internalize" | "keepAll";
"new-project"?: boolean;
"save-inputs"?: boolean;
debug?: boolean;
Expand Down Expand Up @@ -61,10 +61,10 @@ export const CSharpEmitterOptionsSchema: JSONSchemaType<CSharpEmitterOptions> =
},
"unreferenced-types-handling": {
type: "string",
enum: ["removeOrInternalize", "internalize", "keepAll"],
enum: ["internalize", "keepAll"],
nullable: true,
description:
"Defines the strategy on how to handle unreferenced types. The default value is `removeOrInternalize`.",
"Defines the strategy on how to handle unreferenced types. The default value is `internalize`.",
},
"new-project": {
type: "boolean",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

export interface Configuration {
"package-name": string | null;
"unreferenced-types-handling"?: "removeOrInternalize" | "internalize" | "keepAll";
"unreferenced-types-handling"?: "internalize" | "keepAll";
"disable-xml-docs"?: boolean;
"disable-roslyn-reduce"?: boolean;
license?: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ describe("Configuration tests", async () => {
}
const customOptions: TestEmitterOptions = {
"package-name": "custom-package",
"unreferenced-types-handling": "removeOrInternalize",
"unreferenced-types-handling": "internalize",
"disable-xml-docs": true,
"disable-roslyn-reduce": true,
license: {
Expand All @@ -141,7 +141,7 @@ describe("Configuration tests", async () => {
const config = createConfiguration(customOptions, "rootNamespace", sdkContext);

expect(config["package-name"]).toBe("custom-package");
expect(config["unreferenced-types-handling"]).toBe("removeOrInternalize");
expect(config["unreferenced-types-handling"]).toBe("internalize");
expect(config["disable-xml-docs"]).toBe(true);
expect(config["disable-roslyn-reduce"]).toBe(true);
expect(config.license).toEqual({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,8 @@ public virtual MethodBodyStatement SerializeXmlValue(
SerializationFormat format)
=> MrwSerializationTypeDefinition.SerializeXmlValueCore(valueType, value, xmlWriter, mrwOptionsParameter, format);

protected override ModelProvider? CreateModelCore(InputModelType model) => new ScmModelProvider(model);
protected override ModelProvider? CreateModelCore(InputModelType model)
=> model.IsFileType ? null : new ScmModelProvider(model);

protected override ModelFactoryProvider CreateModelFactoryCore(IEnumerable<InputModelType> models)
=> new ScmModelFactoryProvider(models);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Microsoft.TypeSpec.Generator.ClientModel.Providers;
using Microsoft.TypeSpec.Generator.Input;
using Microsoft.TypeSpec.Generator.Primitives;
Expand Down Expand Up @@ -146,6 +147,18 @@ public void TestCreateSerializations_ReturnsBothMrwAndMultipart_WhenJsonAndMpfdU
"Expected a multipart serialization provider for a model with MultipartFormData usage.");
}

[Test]
public void FileTypeDoesNotCreateModelProvider()
{
var file = InputFactory.Model("File", @namespace: "TypeSpec.Http");
typeof(InputModelType).GetProperty(nameof(InputModelType.IsFileType))!.SetValue(file, true);

MockHelpers.LoadMockGenerator(inputModels: () => [file]);

var provider = ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(file);
Assert.IsNull(provider);
}

// ScmTypeFactory overrides CreateModelCore to return ScmModelProvider. External-type
// handling lives in the (non-overridable) base TypeFactory.CreateModel, so it must still
// apply here. This guards against regressing the fix by re-introducing external handling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ await GeneratedCodeWorkspace.LoadBaselineContract(),
{
// Ensure back-compatibility processing is done after all visitors have run
outputType.ProcessTypeForBackCompatibility();
}

PostProcessTypeProviders(output.TypeProviders);

LoggingHelpers.LogElapsedTime("All generated type providers post-processed");

var modelFactory = output.ModelFactory.Value;
foreach (var outputType in output.TypeProviders)
{
if (ReferenceEquals(outputType, modelFactory) && outputType.Methods.Count == 0)
{
continue;
}

var writer = CodeModelGenerator.Instance.GetWriter(outputType);
generateFilesTasks.Add(generatedCodeWorkspace.AddGeneratedFile(writer.Write()));
Expand All @@ -111,8 +124,6 @@ await GeneratedCodeWorkspace.LoadBaselineContract(),

LoggingHelpers.LogElapsedTime("All old generated files have been deleted");

await generatedCodeWorkspace.PostProcessAsync();

// Write the generated files to the output directory
await foreach (var file in generatedCodeWorkspace.GetGeneratedFilesAsync())
{
Expand All @@ -138,6 +149,21 @@ await GeneratedCodeWorkspace.LoadBaselineContract(),
LoggingHelpers.LogElapsedTime("All files have been written to disk");
}

private static void PostProcessTypeProviders(IReadOnlyList<TypeProvider> typeProviders)
{
if (Configuration.UnreferencedTypesHandling == Configuration.UnreferencedTypesHandlingOption.KeepAll)
{
return;
}

var modelFactory = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value;
var postProcessor = new PostProcessor(
[.. CodeModelGenerator.Instance.TypeFactory.UnionVariantTypesToKeep, .. CodeModelGenerator.Instance.AdditionalRootTypes],
modelFactoryFullName: modelFactory.Type.FullyQualifiedName,
additionalNonRootTypeNames: CodeModelGenerator.Instance.NonRootTypes);
postProcessor.Internalize(typeProviders);
}

internal static void FilterAllCustomizedMembers(OutputLibrary output)
{
foreach (var typeProvider in output.TypeProviders)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ public class Configuration
{
public enum UnreferencedTypesHandlingOption
{
RemoveOrInternalize = 0,
Internalize = 1,
KeepAll = 2
Internalize = 0,
KeepAll = 1
}

private const string GeneratedFolderName = "Generated";
Expand Down Expand Up @@ -83,7 +82,7 @@ private static class Options
/// </summary>
public LicenseInfo? LicenseInfo { get; }

internal static UnreferencedTypesHandlingOption UnreferencedTypesHandling { get; private set; } = UnreferencedTypesHandlingOption.RemoveOrInternalize;
internal static UnreferencedTypesHandlingOption UnreferencedTypesHandling { get; private set; } = UnreferencedTypesHandlingOption.Internalize;

private string? _projectDirectory;
internal string ProjectDirectory => _projectDirectory ??= Path.Combine(OutputDirectory, "src");
Expand Down Expand Up @@ -253,7 +252,7 @@ private static T ReadEnumOption<T>(JsonElement root, string option) where T : st

public static Enum? GetDefaultEnumOptionValue(string option) => option switch
{
Options.UnreferencedTypesHandling => UnreferencedTypesHandlingOption.RemoveOrInternalize,
Options.UnreferencedTypesHandling => UnreferencedTypesHandlingOption.Internalize,
_ => null
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,33 +260,6 @@ internal static Project AddDirectory(Project project, string directory, Func<str
return project;
}

/// <summary>
/// This method invokes the postProcessor to do some post processing work
/// Depending on the configuration, it will either remove + internalize, just internalize or do nothing
/// </summary>
public async Task PostProcessAsync()
{
var modelFactory = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value;
var nonRootTypes = CodeModelGenerator.Instance.NonRootTypes;
var postProcessor = new PostProcessor(
[.. CodeModelGenerator.Instance.TypeFactory.UnionVariantTypesToKeep, .. CodeModelGenerator.Instance.AdditionalRootTypes],
modelFactoryFullName: modelFactory.Type.FullyQualifiedName,
additionalNonRootTypeNames: nonRootTypes);

switch (Configuration.UnreferencedTypesHandling)
{
case Configuration.UnreferencedTypesHandlingOption.KeepAll:
break;
case Configuration.UnreferencedTypesHandlingOption.Internalize:
_project = await postProcessor.InternalizeAsync(_project);
break;
case Configuration.UnreferencedTypesHandlingOption.RemoveOrInternalize:
_project = await postProcessor.InternalizeAsync(_project);
_project = await postProcessor.RemoveAsync(_project);
break;
}
}

/// <summary>
/// Resolves PackageReference items from the project's .csproj file and adds their assemblies
/// as metadata references so that custom code referencing external NuGet types compiles correctly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Simplification;
using Microsoft.TypeSpec.Generator.Primitives;
using Microsoft.TypeSpec.Generator.Providers;

namespace Microsoft.TypeSpec.Generator
{
Expand Down Expand Up @@ -113,6 +115,95 @@ private async Task<TypeSymbols> GetTypeSymbolsAsync(Compilation compilation,
protected virtual bool ShouldIncludeDocument(Document document) =>
!GeneratedCodeWorkspace.IsGeneratedTestDocument(document);

public void Internalize(IReadOnlyList<TypeProvider> typeProviders)
{
var allProviders = ProviderReferenceMapBuilder.GetAllProviders(typeProviders).ToArray();
var candidateProviders = allProviders
.Where(IsPublicType)
.Where(provider => !IsExcludedProvider(provider))
.ToArray();
var rootProviders = candidateProviders.Where(IsRootProvider).ToArray();
var referenceMap = new ProviderReferenceMapBuilder(typeProviders).BuildPublicReferenceMap(rootProviders);
var referencedProviders = VisitProvidersFromRoot(rootProviders, referenceMap).ToHashSet();
var providersToInternalize = candidateProviders
.Where(provider => !referencedProviders.Contains(provider))
.ToArray();

foreach (var provider in providersToInternalize)
{
provider.Update(modifiers: MakeInternal(provider.DeclarationModifiers));
}

RemoveMethodsFromModelFactory(providersToInternalize.Select(provider => provider.Name).ToHashSet());
}

private bool IsRootProvider(TypeProvider provider)
=> IsClientProvider(provider) || provider.CustomCodeView != null || ShouldKeepProvider(provider, _typesToKeep);

private bool IsExcludedProvider(TypeProvider provider)
=> IsModelFactoryProvider(provider) || ShouldKeepProvider(provider, _additionalNonRootTypeNames);

private bool IsModelFactoryProvider(TypeProvider provider)
=> _modelFactoryFullName != null && provider.Type.FullyQualifiedName == _modelFactoryFullName;

private static bool IsClientProvider(TypeProvider provider)
=> provider.Name.EndsWith("Client", StringComparison.Ordinal);

private static bool ShouldKeepProvider(TypeProvider provider, HashSet<string> typesToKeep)
=> typesToKeep.Contains(provider.Name) || typesToKeep.Contains(provider.Type.FullyQualifiedName);

private static bool IsPublicType(TypeProvider provider)
=> provider.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Public);

private static TypeSignatureModifiers MakeInternal(TypeSignatureModifiers modifiers)
=> (modifiers & ~(TypeSignatureModifiers.Public | TypeSignatureModifiers.Private | TypeSignatureModifiers.Protected)) | TypeSignatureModifiers.Internal;

private static IEnumerable<TypeProvider> VisitProvidersFromRoot(
IEnumerable<TypeProvider> rootProviders,
IReadOnlyDictionary<TypeProvider, IReadOnlyList<TypeProvider>> referenceMap)
{
var queue = new Queue<TypeProvider>(rootProviders);
var visited = new HashSet<TypeProvider>();
while (queue.Count > 0)
{
var provider = queue.Dequeue();
if (!visited.Add(provider))
{
continue;
}

yield return provider;
if (!referenceMap.TryGetValue(provider, out var references))
{
continue;
}

foreach (var reference in references)
{
queue.Enqueue(reference);
}
}
}

private void RemoveMethodsFromModelFactory(HashSet<string> namesToRemove)
{
if (_modelFactoryFullName == null || namesToRemove.Count == 0)
{
return;
}

var modelFactory = CodeModelGenerator.Instance.OutputLibrary.ModelFactory.Value;
if (modelFactory.Type.FullyQualifiedName != _modelFactoryFullName)
{
return;
}

var methodsToKeep = modelFactory.Methods
.Where(method => !namesToRemove.Contains(method.Signature.Name))
.ToArray();
modelFactory.Update(methods: methodsToKeep);
}

/// <summary>
/// This method marks the "not publicly" referenced types as internal if they are previously defined as public. It will do this job in the following steps:
/// 1. This method will read all the public types defined in the given <paramref name="project"/>, and build a cache for those symbols
Expand All @@ -133,12 +224,12 @@ public async Task<Project> InternalizeAsync(Project project)

// first get all the declared symbols
var definitions = await GetTypeSymbolsAsync(compilation, project, true);
// get the root symbols
var rootSymbols = await GetRootSymbolsAsync(project, definitions);
// build the reference map
var referenceMap =
await new ReferenceMapBuilder(compilation, project).BuildPublicReferenceMapAsync(
definitions.DeclaredSymbols, definitions.DeclaredNodesCache);
// get the root symbols
var rootSymbols = await GetRootSymbolsAsync(project, definitions);
// traverse all the root and recursively add all the things we met
var publicSymbols = VisitSymbolsFromRootAsync(rootSymbols, referenceMap);

Expand Down
Loading
Loading