Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,41 @@ public virtual void AddSharedSourceDirectory(string sharedSourceDirectory)
_sharedSourceDirectories.Add(sharedSourceDirectory);
}

internal HashSet<string> AdditionalRootTypes { get; } = [];
private record KeptTypesInfo(HashSet<string> TypeNames, HashSet<TypeProvider> TypeProviders);

internal HashSet<string> NonRootTypes { get; } = [];
private readonly KeptTypesInfo _additionalRootTypeInfo = new([], []);
private readonly KeptTypesInfo _nonRootTypeInfo = new([], []);

private HashSet<string>? _additionalRootTypes;
private HashSet<string>? _nonRootTypes;

/// <summary>
/// The set of fully qualified type names to keep as roots. Resolved lazily so that
/// <see cref="TypeProvider"/> entries added via <see cref="AddTypeToKeep(TypeProvider, bool)"/>
/// are not forced to materialize their <see cref="TypeProvider.Type"/> at registration time
/// (which would dispatch virtual <c>Build*</c> methods on partially constructed providers).
/// </summary>
internal HashSet<string> AdditionalRootTypes => _additionalRootTypes ??= MaterializeKeepSet(_additionalRootTypeInfo);

/// <summary>
/// The set of fully qualified type names to keep as non-roots. Resolved lazily; see
/// <see cref="AdditionalRootTypes"/> for rationale.
/// </summary>
internal HashSet<string> NonRootTypes => _nonRootTypes ??= MaterializeKeepSet(_nonRootTypeInfo);

private static HashSet<string> MaterializeKeepSet(KeptTypesInfo info)
{
if (info.TypeProviders.Count == 0)
{
return info.TypeNames;
}
var result = new HashSet<string>(info.TypeNames);
foreach (var provider in info.TypeProviders)
{
result.Add(provider.Type.FullyQualifiedName);
}
return result;
}

/// <summary>
/// Adds a type to the list of types to keep.
Expand All @@ -174,21 +206,50 @@ public void AddTypeToKeep(string typeName, bool isRoot = true)
{
if (isRoot)
{
AdditionalRootTypes.Add(typeName);
if (_additionalRootTypeInfo.TypeNames.Add(typeName))
{
_additionalRootTypes = null;
}
}
else
{
NonRootTypes.Add(typeName);
if (_nonRootTypeInfo.TypeNames.Add(typeName))
{
_nonRootTypes = null;
}
}
}

/// <summary>
/// Adds a type to the list of types to keep.
/// </summary>
/// <remarks>
/// The provider's fully qualified name is resolved lazily, when the keep list is consumed during
/// post-processing. This makes it safe to call this method from a <see cref="TypeProvider"/>
/// constructor (including base constructors that run before the derived constructor body), since
/// it does not force evaluation of <see cref="TypeProvider.Type"/> — which would dispatch virtual
/// <c>Build*</c> methods on a not-yet-fully-constructed instance.
/// </remarks>
/// <param name="type">The type provider representing the type.</param>
/// <param name="isRoot">Whether to treat the type as a root type. Any dependencies of root types will
/// not have their accessibility changed regardless of the 'unreferenced-types-handling' value.</param>
public void AddTypeToKeep(TypeProvider type, bool isRoot = true) => AddTypeToKeep(type.Type.FullyQualifiedName, isRoot);
public void AddTypeToKeep(TypeProvider type, bool isRoot = true)
{
if (isRoot)
{
if (_additionalRootTypeInfo.TypeProviders.Add(type))
{
_additionalRootTypes = null;
}
}
else
{
if (_nonRootTypeInfo.TypeProviders.Add(type))
{
_nonRootTypes = null;
}
}
}

/// <summary>
/// Writes additional output files (e.g. configuration schemas) after the main code generation is complete.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,17 @@ public ModelProvider(InputModelType inputModel) : base(inputModel)
_inputModel = inputModel;
_isMultiLevelDiscriminator = ComputeIsMultiLevelDiscriminator();
_useObjectAdditionalProperties = new Lazy<bool>(ShouldUseObjectAdditionalProperties);

if (_inputModel.BaseModel is not null)
{
DiscriminatorValueExpression = EnsureDiscriminatorValueExpression();
}

if (_inputModel.Access == "public")
{
CodeModelGenerator.Instance.AddTypeToKeep(this);
}
}

public bool IsUnknownDiscriminatorModel => _inputModel.IsUnknownDiscriminatorModel;

public string? DiscriminatorValue => _inputModel.DiscriminatorValue;
public ValueExpression? DiscriminatorValueExpression { get; init; }

private ValueExpression? _discriminatorValueExpression;
public ValueExpression? DiscriminatorValueExpression =>
_inputModel.BaseModel is not null
? _discriminatorValueExpression ??= EnsureDiscriminatorValueExpression()
: null;

private IReadOnlyList<ModelProvider>? _derivedModels;
public IReadOnlyList<ModelProvider> DerivedModels => _derivedModels ??= BuildDerivedModels();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,11 @@ protected internal TypeFactory()

if (modelProvider != null)
{
if (model.Access == "public")
{
CodeModelGenerator.Instance.AddTypeToKeep(modelProvider);
}

CSharpTypeMap[modelProvider.Type] = modelProvider;
TypeProvidersByName[modelProvider.Type.Name] = modelProvider;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,106 @@ public void InternalModelsAreNotIncludedInAdditionalRootTypes()
Assert.IsFalse(rootTypes.Contains("Sample.Models.MockInputModel"));
}

[Test]
public void KeepSetsReflectTypeProvidersAddedAfterFirstAccess()
{
var inputModel = InputFactory.Model("MockInputModel", access: "public");
MockHelpers.LoadMockGenerator(inputModelTypes: [inputModel]);
var provider = new DerivedModelProviderReadingOwnField(inputModel);

_ = CodeModelGenerator.Instance.AdditionalRootTypes;
_ = CodeModelGenerator.Instance.NonRootTypes;

CodeModelGenerator.Instance.AddTypeToKeep(provider);
CodeModelGenerator.Instance.AddTypeToKeep(provider, isRoot: false);

var fullyQualifiedName = provider.Type.FullyQualifiedName;
Assert.IsTrue(CodeModelGenerator.Instance.AdditionalRootTypes.Contains(fullyQualifiedName));
Assert.IsTrue(CodeModelGenerator.Instance.NonRootTypes.Contains(fullyQualifiedName));
}

// Regression test for two complementary fixes:
//
// 1. ModelProvider no longer registers itself with AddTypeToKeep from its constructor;
// registration is performed by TypeFactory.CreateModel after construction completes.
// This mirrors the EnumProvider lifecycle and prevents a virtual call chain
// (AddTypeToKeep -> TypeProvider.Type -> BaseType -> virtual BuildBaseType()) from
// being dispatched on a partially-constructed derived ModelProvider whose override
// reads derived-class fields that are still uninitialized.
//
// 2. AddTypeToKeep(TypeProvider) defers FQN resolution until the keep set is consumed,
// so even ctor-time callers cannot force premature TypeProvider.Type evaluation.
[Test]
public void DerivedModelProviderConstructionDoesNotForceTypeEvaluation()
{
var inputModel = InputFactory.Model("MockInputModel", access: "public");
MockHelpers.LoadMockGenerator(inputModelTypes: [inputModel]);

// (1) Constructing a derived ModelProvider whose BuildBaseType reads a derived field
// must not throw.
DerivedModelProviderReadingOwnField? provider = null;
Assert.DoesNotThrow(() => provider = new DerivedModelProviderReadingOwnField(inputModel));

// (2) AddTypeToKeep(TypeProvider) must not throw and the provider's FQN must appear
// once the keep set is materialized.
Assert.DoesNotThrow(() => CodeModelGenerator.Instance.AddTypeToKeep(provider!));
var rootTypes = CodeModelGenerator.Instance.AdditionalRootTypes;
Assert.IsTrue(rootTypes.Contains(provider!.Type.FullyQualifiedName));
}

private sealed class DerivedModelProviderReadingOwnField : ModelProvider
{
private readonly InputModelType _derivedInputModel;

public DerivedModelProviderReadingOwnField(InputModelType inputModel) : base(inputModel)
{
_derivedInputModel = inputModel;
}

protected override CSharpType? BuildBaseType()
{
// Reading a derived-class field that base(...) cannot have populated yet.
// If the framework forces Type evaluation during base ctor, this NREs.
_ = _derivedInputModel.DiscriminatorValue;
return base.BuildBaseType();
}
}

// Regression for the second virtual-call-in-ctor offender: ModelProvider..ctor used to
// eagerly compute DiscriminatorValueExpression, which read BaseModelProvider and thus
// virtually dispatched BuildBaseType()/BuildBaseModel() onto a partially-constructed
// derived class. Surfaced while validating the Cdn provisioning migration (the keep-set
// fix alone was not sufficient when the model has a base + discriminator value).
[Test]
public void DerivedModelProviderConstructionDoesNotForceDiscriminatorEvaluation()
{
var discriminatorEnum = InputFactory.StringEnum("kindEnum", [("One", "one"), ("Two", "two")]);
var baseInputModel = InputFactory.Model(
"BaseModel",
properties:
[
InputFactory.Property("kind", discriminatorEnum, isRequired: false, isDiscriminator: true),
]);
var derivedInputModel = InputFactory.Model(
"DerivedModel",
baseModel: baseInputModel,
discriminatedKind: "one",
properties:
[
InputFactory.Property("kind", InputFactory.EnumMember.String("One", "one", discriminatorEnum), isRequired: true, isDiscriminator: true),
]);
MockHelpers.LoadMockGenerator(inputModelTypes: [baseInputModel, derivedInputModel]);

// Constructing a derived ModelProvider whose BuildBaseType reads a derived field
// must not throw, even when the input model has a base + discriminator value.
DerivedModelProviderReadingOwnField? provider = null;
Assert.DoesNotThrow(() => provider = new DerivedModelProviderReadingOwnField(derivedInputModel));

// The discriminator expression must still be available once consumed lazily
// (callers under emission/serialization rely on it).
Assert.DoesNotThrow(() => { _ = provider!.DiscriminatorValueExpression; });
}

[TestCase(true, true, InputModelTypeUsage.Output, true, false)]
[TestCase(true, false, InputModelTypeUsage.Output, true, false)]
[TestCase(false, true, InputModelTypeUsage.Output, true, false)]
Expand Down
Loading