diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs index 2b50fff372e..5d76b8f44a1 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.Xml.cs @@ -67,7 +67,7 @@ private MethodProvider BuildXmlModelWriteCoreMethod() MethodSignatureModifiers modifiers = _isStruct ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Internal | MethodSignatureModifiers.Virtual; - if (_shouldOverrideMethods) + if (_shouldOverrideXmlMethods) { modifiers = MethodSignatureModifiers.Internal | MethodSignatureModifiers.Override; } @@ -81,7 +81,7 @@ private MethodProvider BuildXmlModelWriteCoreMethod() private MethodBodyStatement[] BuildXmlModelWriteCoreMethodBody() { - var categorizedProperties = _shouldOverrideMethods + var categorizedProperties = _shouldOverrideXmlMethods ? CategorizedXmlProperties : AllCategorizedXmlProperties; var statements = new List @@ -90,7 +90,7 @@ private MethodBodyStatement[] BuildXmlModelWriteCoreMethodBody() MethodBodyStatement.EmptyLine }; - if (_shouldOverrideMethods) + if (_shouldOverrideXmlMethods) { statements.Add(Base.Invoke(XmlModelWriteCoreMethodName, _xmlWriterParameter, _serializationOptionsParameter).Terminate()); } diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs index 8c28b4b6fca..b911501c2a9 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/src/Providers/MrwSerializationTypeDefinition.cs @@ -53,10 +53,14 @@ public partial class MrwSerializationTypeDefinition : TypeProvider private readonly ScopedApi _mrwOptionsParameterSnippet; private readonly ScopedApi _jsonElementParameterSnippet; private readonly ScopedApi _isNotEqualToWireConditionSnippet; - private readonly CSharpType _jsonModelTInterface; - private readonly CSharpType? _jsonModelObjectInterface; - private readonly CSharpType _persistableModelTInterface; - private readonly CSharpType? _persistableModelObjectInterface; + // These interface types depend on _model.Type. Build them lazily so we do not cache a + // CSharpType before delayed base model resolution has updated the model's inheritance. + private CSharpType? _jsonModelTInterfaceValue; + private CSharpType _jsonModelTInterface => _jsonModelTInterfaceValue ??= new CSharpType(typeof(IJsonModel<>), SerializationInterfaceType.Type); + private CSharpType? _jsonModelObjectInterface => _isStruct ? (CSharpType)typeof(IJsonModel) : null; + private CSharpType? _persistableModelTInterfaceValue; + private CSharpType _persistableModelTInterface => _persistableModelTInterfaceValue ??= new CSharpType(typeof(IPersistableModel<>), SerializationInterfaceType.Type); + private CSharpType? _persistableModelObjectInterface => _isStruct ? (CSharpType)typeof(IPersistableModel) : null; private readonly ModelProvider _model; private readonly InputModelType _inputModel; private readonly FieldProvider? _rawDataField; @@ -67,10 +71,18 @@ public partial class MrwSerializationTypeDefinition : TypeProvider private readonly bool _supportsXml; private ConstructorProvider? _serializationConstructor; // Flag to determine if the model should override the serialization methods - private readonly bool _shouldOverrideMethods; - private readonly bool _shouldSkipDerivedSerializationMethodOverrides; + private bool ShouldOverrideMethods => _model.BaseModelProvider != null && !_isStruct; + private bool ShouldSkipSerializationMethodOverrides => ShouldSkipDerivedSerializationMethodOverrides(_model.BaseModelProvider); + private readonly bool _shouldOverrideXmlMethods; private readonly Lazy _additionalProperties; + // Unknown discriminator models use their base model as the serialization interface type. + // This can also touch model.Type, so defer it until serialization method/interface emission. + private TypeProvider SerializationInterfaceType => _serializationInterfaceType ??= _inputModel.IsUnknownDiscriminatorModel + ? ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(_inputModel.BaseModel!)! + : _model; + private TypeProvider? _serializationInterfaceType; + private CSharpType RootType => _rootType ??= GetRootModelType(); private CSharpType? _rootType; @@ -84,17 +96,10 @@ public MrwSerializationTypeDefinition(InputModelType inputModel, ModelProvider m _isStruct = _model.DeclarationModifiers.HasFlag(TypeSignatureModifiers.Struct); _supportsXml = inputModel.Usage.HasFlag(InputModelTypeUsage.Xml); _supportsJson = inputModel.Usage.HasFlag(InputModelTypeUsage.Json) || !_supportsXml; - // Initialize the serialization interfaces - var interfaceType = inputModel.IsUnknownDiscriminatorModel ? ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(inputModel.BaseModel!)! : _model; - _jsonModelTInterface = new CSharpType(typeof(IJsonModel<>), interfaceType.Type); - _jsonModelObjectInterface = _isStruct ? (CSharpType)typeof(IJsonModel) : null; - _persistableModelTInterface = new CSharpType(typeof(IPersistableModel<>), interfaceType.Type); - _persistableModelObjectInterface = _isStruct ? (CSharpType)typeof(IPersistableModel) : null; + _shouldOverrideXmlMethods = _model.BaseModelProvider != null && !_isStruct; _rawDataField = _model.Fields.FirstOrDefault(f => f.Name == AdditionalPropertiesHelper.AdditionalBinaryDataPropsFieldName); _additionalBinaryDataProperty = new(GetAdditionalBinaryDataPropertiesProp); _additionalProperties = new(() => [.. _model.Properties.Where(p => p.IsAdditionalProperties)]); - _shouldOverrideMethods = _model.BaseModelProvider != null && !_isStruct; - _shouldSkipDerivedSerializationMethodOverrides = ShouldSkipDerivedSerializationMethodOverrides(_model.BaseModelProvider); _utf8JsonWriterSnippet = _utf8JsonWriterParameter.As(); _mrwOptionsParameterSnippet = _serializationOptionsParameter.As(); _jsonElementParameterSnippet = _jsonElementDeserializationParam.As(); @@ -530,7 +535,7 @@ internal MethodProvider BuildJsonModelWriteCoreMethod() MethodSignatureModifiers modifiers = _isStruct ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (_shouldOverrideMethods) + if (ShouldOverrideMethods) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -552,7 +557,7 @@ internal MethodProvider BuildPersistableModelWriteCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) + if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -576,7 +581,7 @@ internal MethodProvider BuildPersistableModelCreateCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) + if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -624,7 +629,7 @@ internal MethodProvider BuildJsonModelCreateCoreMethod() ? MethodSignatureModifiers.Private : MethodSignatureModifiers.Protected | MethodSignatureModifiers.Virtual; - if (_shouldOverrideMethods && !_shouldSkipDerivedSerializationMethodOverrides) + if (ShouldOverrideMethods && !ShouldSkipSerializationMethodOverrides) { modifiers = MethodSignatureModifiers.Protected | MethodSignatureModifiers.Override; } @@ -1055,7 +1060,7 @@ private MethodBodyStatement[] BuildPersistableModelCreateCoreMethodBody() private MethodBodyStatement CallBaseJsonModelWriteCore(bool isDynamicModelWithNonDynamicBase) { // base.() - bool callBaseWriteMethod = _shouldOverrideMethods + bool callBaseWriteMethod = ShouldOverrideMethods && (_jsonPatchProperty is null || !isDynamicModelWithNonDynamicBase); return callBaseWriteMethod ? Base.Invoke(JsonModelWriteCoreMethodName, [_utf8JsonWriterParameter, _serializationOptionsParameter]).Terminate() diff --git a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs index 503ed1ad68f..fcc90582416 100644 --- a/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs +++ b/packages/http-client-csharp/generator/Microsoft.TypeSpec.Generator.ClientModel/test/Providers/MrwSerializationTypeDefinitions/SystemObjectModelSerializationTests.cs @@ -121,6 +121,31 @@ public void JsonModelWriteCore_IsOverride_WhenBaseIsRegularModel() "JsonModelWriteCore should be 'override' with regular base too"); } + [Test] + public void JsonModelWriteCore_IsOverride_WhenBaseProviderIsResolvedAfterSerialization() + { + var baseInputModel = InputFactory.Model("Resource"); + var derivedInputModel = InputFactory.Model("TrackedResource", properties: [InputFactory.Property("Location", InputPrimitiveType.String)]); + MockHelpers.LoadMockGenerator(inputModels: () => [baseInputModel, derivedInputModel]); + + var derived = new DelayedBaseModelProvider(derivedInputModel); + var serialization = new MrwSerializationTypeDefinition(derivedInputModel, derived); + + // The serialization provider can be constructed before later visitors/customization + // resolution make the base model provider available. + derived.BaseModel = new SystemObjectModelProvider(new CSharpType(typeof(object)), baseInputModel); + + var method = serialization.BuildJsonModelWriteCoreMethod(); + + Assert.AreEqual(derived.BaseModel.Type, derived.Type.BaseType, + "The generated model type should inherit the base resolved after serialization construction."); + Assert.AreEqual(derived.BaseModel.Type, serialization.Type.BaseType, + "The serialization type should inherit the same resolved base."); + Assert.IsTrue(method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Override), + "JsonModelWriteCore should evaluate BaseModelProvider when the method is built, not when serialization is constructed"); + Assert.IsFalse(method.Signature.Modifiers.HasFlag(MethodSignatureModifiers.Virtual)); + } + // ------------------------------------------------------------------- // PersistableModelWriteCore: 'virtual' with system base, 'override' with regular // (the framework base already implements this; derived model re-introduces it) @@ -328,5 +353,14 @@ FakeMrwBase IPersistableModel.Create(BinaryData data, ModelReaderWr string IPersistableModel.GetFormatFromOptions(ModelReaderWriterOptions options) => "J"; } + + private class DelayedBaseModelProvider(InputModelType inputModel) : ModelProvider(inputModel) + { + public ModelProvider? BaseModel { get; set; } + + protected override ModelProvider? BuildBaseModelProvider() => BaseModel; + + protected override CSharpType? BuildBaseType() => BaseModel?.Type; + } } }