diff --git a/.netconfig b/.netconfig
index 743011f..4bd8be5 100644
--- a/.netconfig
+++ b/.netconfig
@@ -165,4 +165,9 @@
url = https://github.com/andrewlock/NetEscapades.Configuration/blob/master/src/NetEscapades.Configuration.Yaml/YamlConfigurationStreamParser.cs
weak
sha = a1ec2c6746d96b4f6f140509aa68dcff09271146
- etag = 9e5c6908edc34eb661d647671f79153d8f3a54ebdc848c8765c78d2715f2f657
\ No newline at end of file
+ etag = 9e5c6908edc34eb661d647671f79153d8f3a54ebdc848c8765c78d2715f2f657
+[file "src/Tests/Extensions/CallHelpers.cs"]
+ url = https://github.com/grpc/grpc-dotnet/blob/master/examples/Tester/Tests/Client/Helpers/CallHelpers.cs
+ sha = a04684ab2306e5a17bad26d3da69636b326cce14
+ etag = 7faacded709d2bede93356fd58b93af84884949f3bab098b8b8d121a03696449
+ weak
diff --git a/src/Extensions.Grok/Extensions.Grok.csproj b/src/Extensions.Grok/Extensions.Grok.csproj
index 42827f7..9cbaf07 100644
--- a/src/Extensions.Grok/Extensions.Grok.csproj
+++ b/src/Extensions.Grok/Extensions.Grok.csproj
@@ -26,6 +26,7 @@
+
\ No newline at end of file
diff --git a/src/Extensions.Grok/GrokChatClient.cs b/src/Extensions.Grok/GrokChatClient.cs
index 6c3c473..48eee2a 100644
--- a/src/Extensions.Grok/GrokChatClient.cs
+++ b/src/Extensions.Grok/GrokChatClient.cs
@@ -15,8 +15,19 @@ class GrokChatClient : IChatClient
readonly GrokClientOptions clientOptions;
internal GrokChatClient(GrpcChannel channel, GrokClientOptions clientOptions, string defaultModelId)
+ : this(new ChatClient(channel), clientOptions, defaultModelId)
+ { }
+
+ ///
+ /// Test constructor.
+ ///
+ internal GrokChatClient(ChatClient client, string defaultModelId)
+ : this(client, new(), defaultModelId)
+ { }
+
+ GrokChatClient(ChatClient client, GrokClientOptions clientOptions, string defaultModelId)
{
- client = new ChatClient(channel);
+ this.client = client;
this.clientOptions = clientOptions;
this.defaultModelId = defaultModelId;
metadata = new ChatClientMetadata("xai", clientOptions.Endpoint, defaultModelId);
@@ -97,7 +108,7 @@ public async Task GetResponseAsync(IEnumerable messag
{
ResponseId = response.Id,
ModelId = response.Model,
- CreatedAt = response.Created.ToDateTimeOffset(),
+ CreatedAt = response.Created?.ToDateTimeOffset(),
FinishReason = lastOutput != null ? MapFinishReason(lastOutput.FinishReason) : null,
Usage = MapToUsage(response.Usage),
};
@@ -210,13 +221,16 @@ static CitationAnnotation MapCitation(string citation)
GetCompletionsRequest MapToRequest(IEnumerable messages, ChatOptions? options)
{
- var request = new GetCompletionsRequest
+ var request = options?.RawRepresentationFactory?.Invoke(this) as GetCompletionsRequest ?? new GetCompletionsRequest()
{
// By default always include citations in the final output if available
Include = { IncludeOption.InlineCitations },
Model = options?.ModelId ?? defaultModelId,
};
+ if (string.IsNullOrEmpty(request.Model))
+ request.Model = options?.ModelId ?? defaultModelId;
+
if ((options?.EndUserId ?? clientOptions.EndUserId) is { } user) request.User = user;
if (options?.MaxOutputTokens is { } maxTokens) request.MaxTokens = maxTokens;
if (options?.Temperature is { } temperature) request.Temperature = temperature;
diff --git a/src/Tests/Extensions/CallHelpers.cs b/src/Tests/Extensions/CallHelpers.cs
new file mode 100644
index 0000000..77b8050
--- /dev/null
+++ b/src/Tests/Extensions/CallHelpers.cs
@@ -0,0 +1,46 @@
+#region Copyright notice and license
+
+// Copyright 2019 The gRPC Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#endregion
+
+using Grpc.Core;
+
+namespace Tests.Client.Helpers
+{
+ static class CallHelpers
+ {
+ public static AsyncUnaryCall CreateAsyncUnaryCall(TResponse response)
+ {
+ return new AsyncUnaryCall(
+ Task.FromResult(response),
+ Task.FromResult(new Metadata()),
+ () => Status.DefaultSuccess,
+ () => new Metadata(),
+ () => { });
+ }
+
+ public static AsyncUnaryCall CreateAsyncUnaryCall(StatusCode statusCode)
+ {
+ var status = new Status(statusCode, string.Empty);
+ return new AsyncUnaryCall(
+ Task.FromException(new RpcException(status)),
+ Task.FromResult(new Metadata()),
+ () => status,
+ () => new Metadata(),
+ () => { });
+ }
+ }
+}
diff --git a/src/Tests/GrokTests.cs b/src/Tests/GrokTests.cs
index 3d5c308..a608051 100644
--- a/src/Tests/GrokTests.cs
+++ b/src/Tests/GrokTests.cs
@@ -4,7 +4,8 @@
using Devlooped.Extensions.AI.Grok;
using Devlooped.Grok;
using Microsoft.Extensions.AI;
-using OpenAI.Realtime;
+using Moq;
+using Tests.Client.Helpers;
using static ConfigurationExtensions;
using OpenAIClientOptions = OpenAI.OpenAIClientOptions;
@@ -466,5 +467,40 @@ public async Task GrokStreamsUpdatesFromAllTools()
Assert.True(typed.Price > 100);
}
+ [Fact]
+ public async Task GrokCustomFactoryInvokedFromOptions()
+ {
+ var invoked = false;
+ var client = new Mock(MockBehavior.Strict);
+ client.Setup(x => x.GetCompletionAsync(It.IsAny(), null, null, CancellationToken.None))
+ .Returns(CallHelpers.CreateAsyncUnaryCall(new GetChatCompletionResponse
+ {
+ Outputs =
+ {
+ new CompletionOutput
+ {
+ Message = new CompletionMessage
+ {
+ Content = "Hey Cazzulino!"
+ }
+ }
+ }
+ }));
+
+ var grok = new GrokChatClient(client.Object, "grok-4-1-fast");
+ var response = await grok.GetResponseAsync("Hi, my internet alias is kzu. Lookup my real full name online.",
+ new GrokChatOptions
+ {
+ RawRepresentationFactory = (client) =>
+ {
+ invoked = true;
+ return new GetCompletionsRequest();
+ }
+ });
+
+ Assert.True(invoked);
+ Assert.Equal("Hey Cazzulino!", response.Text);
+ }
+
record Response(DateOnly Today, string Release, decimal Price);
}