Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using Microsoft.TypeSpec.Generator.Providers;
using Microsoft.TypeSpec.Generator.Snippets;
using Microsoft.TypeSpec.Generator.Statements;
using Microsoft.TypeSpec.Generator.Shared;
using Microsoft.TypeSpec.Generator.Utilities;
using static Microsoft.TypeSpec.Generator.Snippets.Snippet;

Expand Down Expand Up @@ -133,11 +134,21 @@ private static bool UseSingletonInstance(InputClient inputClient)
var properties = new Dictionary<EnumProvider, PropertyProvider>(_serviceVersionsEnums.Count);
foreach (var (inputEnum, enumProvider) in _serviceVersionsEnums)
{
// For multi-service clients, use the full namespace to guarantee uniqueness
// (the last segment alone can collide when services share a namespace).
var versionPropertyName = _inputClient.IsMultiServiceClient
? $"{inputEnum.Namespace.ToIdentifierName()}{ApiVersionSuffix}"
: VersionSuffix;
string versionPropertyName;
if (!_inputClient.IsMultiServiceClient)
{
versionPropertyName = VersionSuffix;
}
else
{
var serviceNamespace = inputEnum.Namespace;
// Use the full namespace only when the last segment collides;
// otherwise, use BuildNameForService for shorter names.
versionPropertyName = !string.IsNullOrEmpty(serviceNamespace) &&
ClientHelper.HasLastSegmentCollision(serviceNamespace, inputEnum, _serviceVersionsEnums.Keys)
? $"{serviceNamespace.ToIdentifierName()}{ApiVersionSuffix}"
: ClientHelper.BuildNameForService(serviceNamespace ?? string.Empty, string.Empty, ApiVersionSuffix);
}

var versionProperty = new PropertyProvider(
null,
Expand All @@ -161,11 +172,23 @@ private static bool UseSingletonInstance(InputClient inputClient)
}

Dictionary<FieldProvider, EnumProvider> latestVersionFields = new(_serviceVersionsEnums.Count);
foreach (var enumProvider in _serviceVersionsEnums.Values)
foreach (var (inputEnum, enumProvider) in _serviceVersionsEnums)
{
var fieldName = _inputClient.IsMultiServiceClient
? $"{LatestPrefix}{enumProvider.Name.ToIdentifierName()}"
: LatestVersionFieldName;
string fieldName;
if (!_inputClient.IsMultiServiceClient)
{
fieldName = LatestVersionFieldName;
}
else
{
var serviceNamespace = inputEnum.Namespace;
// Use the full namespace only when the last segment collides;
// otherwise, use BuildNameForService for shorter names.
fieldName = !string.IsNullOrEmpty(serviceNamespace) &&
ClientHelper.HasLastSegmentCollision(serviceNamespace, inputEnum, _serviceVersionsEnums.Keys)
? $"{LatestPrefix}{serviceNamespace.ToIdentifierName()}{VersionSuffix}"
: ClientHelper.BuildNameForService(serviceNamespace ?? string.Empty, LatestPrefix, VersionSuffix);
}
var field = new FieldProvider(
modifiers: FieldModifiers.Private | FieldModifiers.Const,
type: enumProvider.Type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,11 @@ public void MultiServiceClient_SameLastSegment_ProducesUniqueVersionEnums()
var nestedTypes = clientOptionsProvider!.NestedTypes;
Assert.AreEqual(2, nestedTypes.Count);
CollectionAssert.AllItemsAreUnique(nestedTypes.Select(t => t.Name).ToList());

var writer = new TypeProviderWriter(clientOptionsProvider!);
var file = writer.Write();

Assert.AreEqual(Helpers.GetExpectedFromFile(), file.Content);
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3678,12 +3678,12 @@ public void GetApiVersionFieldForService_MultiService_ReturnsMatchingField()
// Should return the matching field for ServiceA
var fieldA = clientProvider!.GetApiVersionFieldForService("Sample.ServiceA");
Assert.IsNotNull(fieldA);
Assert.AreEqual("_sampleServiceAApiVersion", fieldA!.Name);
Assert.AreEqual("_serviceAApiVersion", fieldA!.Name);

// Should return the matching field for ServiceB
var fieldB = clientProvider.GetApiVersionFieldForService("Sample.ServiceB");
Assert.IsNotNull(fieldB);
Assert.AreEqual("_sampleServiceBApiVersion", fieldB!.Name);
Assert.AreEqual("_serviceBApiVersion", fieldB!.Name);
}

[Test]
Expand Down Expand Up @@ -3814,11 +3814,11 @@ public void GetApiVersionFieldForService_MultiService_CaseInsensitiveMatch()
// Should match case-insensitively
var fieldLowerCase = clientProvider!.GetApiVersionFieldForService("sample.serviceA");
Assert.IsNotNull(fieldLowerCase);
Assert.AreEqual("_sampleServiceAApiVersion", fieldLowerCase!.Name);
Assert.AreEqual("_serviceAApiVersion", fieldLowerCase!.Name);

var fieldUpperCase = clientProvider.GetApiVersionFieldForService("SAMPLE.SERVICEa");
Assert.IsNotNull(fieldUpperCase);
Assert.AreEqual("_sampleServiceAApiVersion", fieldUpperCase!.Name);
Assert.AreEqual("_serviceAApiVersion", fieldUpperCase!.Name);
}

[Test]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ public partial class TestClient
{
private readonly global::System.Uri _endpoint;
private readonly string _subscriptionId;
private readonly string _sampleServiceAApiVersion;
private readonly string _sampleServiceBApiVersion;
private readonly string _serviceAApiVersion;
private readonly string _serviceBApiVersion;
private global::Sample.ServiceA.ServiceA _cachedServiceA;
private global::Sample.ServiceB.ServiceB _cachedServiceB;

Expand Down Expand Up @@ -44,8 +44,8 @@ internal TestClient(global::System.ClientModel.Primitives.AuthenticationPolicy a
{
Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>(), new global::System.ClientModel.Primitives.PipelinePolicy[] { new global::System.ClientModel.Primitives.UserAgentPolicy(typeof(global::Sample.TestClient).Assembly) }, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>());
}
_sampleServiceAApiVersion = options.SampleServiceAApiVersion;
_sampleServiceBApiVersion = options.SampleServiceBApiVersion;
_serviceAApiVersion = options.ServiceAApiVersion;
_serviceBApiVersion = options.ServiceBApiVersion;
}

public TestClient(global::System.Uri endpoint, string subscriptionId, global::Sample.TestClientOptions options) : this(null, endpoint, subscriptionId, options)
Expand All @@ -56,12 +56,12 @@ public TestClient(global::System.Uri endpoint, string subscriptionId, global::Sa

public virtual global::Sample.ServiceA.ServiceA GetServiceAClient()
{
return (global::System.Threading.Volatile.Read(ref _cachedServiceA) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedServiceA, new global::Sample.ServiceA.ServiceA(Pipeline, _endpoint, _sampleServiceAApiVersion, _subscriptionId), null) ?? _cachedServiceA));
return (global::System.Threading.Volatile.Read(ref _cachedServiceA) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedServiceA, new global::Sample.ServiceA.ServiceA(Pipeline, _endpoint, _serviceAApiVersion, _subscriptionId), null) ?? _cachedServiceA));
}

public virtual global::Sample.ServiceB.ServiceB GetServiceBClient()
{
return (global::System.Threading.Volatile.Read(ref _cachedServiceB) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedServiceB, new global::Sample.ServiceB.ServiceB(Pipeline, _endpoint, _sampleServiceBApiVersion, _subscriptionId), null) ?? _cachedServiceB));
return (global::System.Threading.Volatile.Read(ref _cachedServiceB) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedServiceB, new global::Sample.ServiceB.ServiceB(Pipeline, _endpoint, _serviceBApiVersion, _subscriptionId), null) ?? _cachedServiceB));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ public partial class TestClient
{
private readonly global::System.Uri _endpoint;
private readonly string _subscriptionId;
private readonly string _sampleComputeApiVersion;
private readonly string _sampleKeyVaultApiVersion;
private readonly string _sampleStorageApiVersion;
private readonly string _computeApiVersion;
private readonly string _keyVaultApiVersion;
private readonly string _storageApiVersion;
private global::Sample.KeyVault.KeyVault _cachedKeyVault;
private global::Sample.Storage.Storage _cachedStorage;
private global::Sample.Compute.Compute _cachedCompute;
Expand Down Expand Up @@ -47,9 +47,9 @@ internal TestClient(global::System.ClientModel.Primitives.AuthenticationPolicy a
{
Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>(), new global::System.ClientModel.Primitives.PipelinePolicy[] { new global::System.ClientModel.Primitives.UserAgentPolicy(typeof(global::Sample.TestClient).Assembly) }, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>());
}
_sampleComputeApiVersion = options.SampleComputeApiVersion;
_sampleKeyVaultApiVersion = options.SampleKeyVaultApiVersion;
_sampleStorageApiVersion = options.SampleStorageApiVersion;
_computeApiVersion = options.ComputeApiVersion;
_keyVaultApiVersion = options.KeyVaultApiVersion;
_storageApiVersion = options.StorageApiVersion;
}

public TestClient(global::System.Uri endpoint, string subscriptionId, global::Sample.TestClientOptions options) : this(null, endpoint, subscriptionId, options)
Expand All @@ -60,17 +60,17 @@ public TestClient(global::System.Uri endpoint, string subscriptionId, global::Sa

public virtual global::Sample.KeyVault.KeyVault GetKeyVaultClient()
{
return (global::System.Threading.Volatile.Read(ref _cachedKeyVault) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedKeyVault, new global::Sample.KeyVault.KeyVault(Pipeline, _endpoint, _sampleKeyVaultApiVersion, _subscriptionId), null) ?? _cachedKeyVault));
return (global::System.Threading.Volatile.Read(ref _cachedKeyVault) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedKeyVault, new global::Sample.KeyVault.KeyVault(Pipeline, _endpoint, _keyVaultApiVersion, _subscriptionId), null) ?? _cachedKeyVault));
}

public virtual global::Sample.Storage.Storage GetStorageClient()
{
return (global::System.Threading.Volatile.Read(ref _cachedStorage) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedStorage, new global::Sample.Storage.Storage(Pipeline, _endpoint, _sampleStorageApiVersion, _subscriptionId), null) ?? _cachedStorage));
return (global::System.Threading.Volatile.Read(ref _cachedStorage) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedStorage, new global::Sample.Storage.Storage(Pipeline, _endpoint, _storageApiVersion, _subscriptionId), null) ?? _cachedStorage));
}

public virtual global::Sample.Compute.Compute GetComputeClient()
{
return (global::System.Threading.Volatile.Read(ref _cachedCompute) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedCompute, new global::Sample.Compute.Compute(Pipeline, _endpoint, _sampleComputeApiVersion, _subscriptionId), null) ?? _cachedCompute));
return (global::System.Threading.Volatile.Read(ref _cachedCompute) ?? (global::System.Threading.Interlocked.CompareExchange(ref _cachedCompute, new global::Sample.Compute.Compute(Pipeline, _endpoint, _computeApiVersion, _subscriptionId), null) ?? _cachedCompute));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace Sample
public partial class TestClient
{
private readonly global::System.Uri _endpoint;
private readonly string _sampleServiceAApiVersion;
private readonly string _sampleServiceBApiVersion;
private readonly string _serviceAApiVersion;
private readonly string _serviceBApiVersion;

protected TestClient()
{
Expand All @@ -39,8 +39,8 @@ internal TestClient(global::System.ClientModel.Primitives.AuthenticationPolicy a
{
Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>(), new global::System.ClientModel.Primitives.PipelinePolicy[] { new global::System.ClientModel.Primitives.UserAgentPolicy(typeof(global::Sample.TestClient).Assembly) }, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>());
}
_sampleServiceAApiVersion = options.SampleServiceAApiVersion;
_sampleServiceBApiVersion = options.SampleServiceBApiVersion;
_serviceAApiVersion = options.ServiceAApiVersion;
_serviceBApiVersion = options.ServiceBApiVersion;
}

public TestClient(global::System.Uri endpoint, global::Sample.TestClientOptions options) : this(null, endpoint, options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ namespace Sample
public partial class TestClient
{
private readonly global::System.Uri _endpoint;
private readonly string _sampleComputeApiVersion;
private readonly string _sampleKeyVaultApiVersion;
private readonly string _sampleStorageApiVersion;
private readonly string _computeApiVersion;
private readonly string _keyVaultApiVersion;
private readonly string _storageApiVersion;

protected TestClient()
{
Expand All @@ -40,9 +40,9 @@ internal TestClient(global::System.ClientModel.Primitives.AuthenticationPolicy a
{
Pipeline = global::System.ClientModel.Primitives.ClientPipeline.Create(options, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>(), new global::System.ClientModel.Primitives.PipelinePolicy[] { new global::System.ClientModel.Primitives.UserAgentPolicy(typeof(global::Sample.TestClient).Assembly) }, Array.Empty<global::System.ClientModel.Primitives.PipelinePolicy>());
}
_sampleComputeApiVersion = options.SampleComputeApiVersion;
_sampleKeyVaultApiVersion = options.SampleKeyVaultApiVersion;
_sampleStorageApiVersion = options.SampleStorageApiVersion;
_computeApiVersion = options.ComputeApiVersion;
_keyVaultApiVersion = options.KeyVaultApiVersion;
_storageApiVersion = options.StorageApiVersion;
}

public TestClient(global::System.Uri endpoint, global::Sample.TestClientOptions options) : this(null, endpoint, options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public partial class TestClient
{
global::Sample.ClientUriBuilder uri = new global::Sample.ClientUriBuilder();
uri.Reset(_endpoint);
uri.AppendQuery("apiVersion", _sampleServiceAApiVersion, true);
uri.AppendQuery("apiVersion", _serviceAApiVersion, true);
global::System.ClientModel.Primitives.PipelineMessage message = Pipeline.CreateMessage(uri.ToUri(), "GET", PipelineMessageClassifier200);
global::System.ClientModel.Primitives.PipelineRequest request = message.Request;
message.Apply(options);
Expand All @@ -27,7 +27,7 @@ public partial class TestClient
{
global::Sample.ClientUriBuilder uri = new global::Sample.ClientUriBuilder();
uri.Reset(_endpoint);
uri.AppendQuery("apiVersion", _sampleServiceBApiVersion, true);
uri.AppendQuery("apiVersion", _serviceBApiVersion, true);
global::System.ClientModel.Primitives.PipelineMessage message = Pipeline.CreateMessage(uri.ToUri(), "GET", PipelineMessageClassifier200);
global::System.ClientModel.Primitives.PipelineRequest request = message.Request;
message.Apply(options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public partial class TestClient
{
global::Sample.ClientUriBuilder uri = new global::Sample.ClientUriBuilder();
uri.Reset(_endpoint);
uri.AppendQuery("apiVersion", _sampleKeyVaultApiVersion, true);
uri.AppendQuery("apiVersion", _keyVaultApiVersion, true);
global::System.ClientModel.Primitives.PipelineMessage message = Pipeline.CreateMessage(uri.ToUri(), "GET", PipelineMessageClassifier200);
global::System.ClientModel.Primitives.PipelineRequest request = message.Request;
message.Apply(options);
Expand All @@ -27,7 +27,7 @@ public partial class TestClient
{
global::Sample.ClientUriBuilder uri = new global::Sample.ClientUriBuilder();
uri.Reset(_endpoint);
uri.AppendQuery("apiVersion", _sampleStorageApiVersion, true);
uri.AppendQuery("apiVersion", _storageApiVersion, true);
global::System.ClientModel.Primitives.PipelineMessage message = Pipeline.CreateMessage(uri.ToUri(), "GET", PipelineMessageClassifier200);
global::System.ClientModel.Primitives.PipelineRequest request = message.Request;
message.Apply(options);
Expand All @@ -38,7 +38,7 @@ public partial class TestClient
{
global::Sample.ClientUriBuilder uri = new global::Sample.ClientUriBuilder();
uri.Reset(_endpoint);
uri.AppendQuery("apiVersion", _sampleComputeApiVersion, true);
uri.AppendQuery("apiVersion", _computeApiVersion, true);
global::System.ClientModel.Primitives.PipelineMessage message = Pipeline.CreateMessage(uri.ToUri(), "GET", PipelineMessageClassifier200);
global::System.ClientModel.Primitives.PipelineRequest request = message.Request;
message.Apply(options);
Expand Down
Loading
Loading