diff --git a/.netconfig b/.netconfig index 743011f..4bd8be5 100644 --- a/.netconfig +++ b/.netconfig @@ -165,4 +165,9 @@ url = https://github.com/andrewlock/NetEscapades.Configuration/blob/master/src/NetEscapades.Configuration.Yaml/YamlConfigurationStreamParser.cs weak sha = a1ec2c6746d96b4f6f140509aa68dcff09271146 - etag = 9e5c6908edc34eb661d647671f79153d8f3a54ebdc848c8765c78d2715f2f657 \ No newline at end of file + etag = 9e5c6908edc34eb661d647671f79153d8f3a54ebdc848c8765c78d2715f2f657 +[file "src/Tests/Extensions/CallHelpers.cs"] + url = https://github.com/grpc/grpc-dotnet/blob/master/examples/Tester/Tests/Client/Helpers/CallHelpers.cs + sha = a04684ab2306e5a17bad26d3da69636b326cce14 + etag = 7faacded709d2bede93356fd58b93af84884949f3bab098b8b8d121a03696449 + weak diff --git a/src/Extensions.Grok/Extensions.Grok.csproj b/src/Extensions.Grok/Extensions.Grok.csproj index 42827f7..9cbaf07 100644 --- a/src/Extensions.Grok/Extensions.Grok.csproj +++ b/src/Extensions.Grok/Extensions.Grok.csproj @@ -26,6 +26,7 @@ + \ No newline at end of file diff --git a/src/Extensions.Grok/GrokChatClient.cs b/src/Extensions.Grok/GrokChatClient.cs index 6c3c473..48eee2a 100644 --- a/src/Extensions.Grok/GrokChatClient.cs +++ b/src/Extensions.Grok/GrokChatClient.cs @@ -15,8 +15,19 @@ class GrokChatClient : IChatClient readonly GrokClientOptions clientOptions; internal GrokChatClient(GrpcChannel channel, GrokClientOptions clientOptions, string defaultModelId) + : this(new ChatClient(channel), clientOptions, defaultModelId) + { } + + /// + /// Test constructor. + /// + internal GrokChatClient(ChatClient client, string defaultModelId) + : this(client, new(), defaultModelId) + { } + + GrokChatClient(ChatClient client, GrokClientOptions clientOptions, string defaultModelId) { - client = new ChatClient(channel); + this.client = client; this.clientOptions = clientOptions; this.defaultModelId = defaultModelId; metadata = new ChatClientMetadata("xai", clientOptions.Endpoint, defaultModelId); @@ -97,7 +108,7 @@ public async Task GetResponseAsync(IEnumerable messag { ResponseId = response.Id, ModelId = response.Model, - CreatedAt = response.Created.ToDateTimeOffset(), + CreatedAt = response.Created?.ToDateTimeOffset(), FinishReason = lastOutput != null ? MapFinishReason(lastOutput.FinishReason) : null, Usage = MapToUsage(response.Usage), }; @@ -210,13 +221,16 @@ static CitationAnnotation MapCitation(string citation) GetCompletionsRequest MapToRequest(IEnumerable messages, ChatOptions? options) { - var request = new GetCompletionsRequest + var request = options?.RawRepresentationFactory?.Invoke(this) as GetCompletionsRequest ?? new GetCompletionsRequest() { // By default always include citations in the final output if available Include = { IncludeOption.InlineCitations }, Model = options?.ModelId ?? defaultModelId, }; + if (string.IsNullOrEmpty(request.Model)) + request.Model = options?.ModelId ?? defaultModelId; + if ((options?.EndUserId ?? clientOptions.EndUserId) is { } user) request.User = user; if (options?.MaxOutputTokens is { } maxTokens) request.MaxTokens = maxTokens; if (options?.Temperature is { } temperature) request.Temperature = temperature; diff --git a/src/Tests/Extensions/CallHelpers.cs b/src/Tests/Extensions/CallHelpers.cs new file mode 100644 index 0000000..77b8050 --- /dev/null +++ b/src/Tests/Extensions/CallHelpers.cs @@ -0,0 +1,46 @@ +#region Copyright notice and license + +// Copyright 2019 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#endregion + +using Grpc.Core; + +namespace Tests.Client.Helpers +{ + static class CallHelpers + { + public static AsyncUnaryCall CreateAsyncUnaryCall(TResponse response) + { + return new AsyncUnaryCall( + Task.FromResult(response), + Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + } + + public static AsyncUnaryCall CreateAsyncUnaryCall(StatusCode statusCode) + { + var status = new Status(statusCode, string.Empty); + return new AsyncUnaryCall( + Task.FromException(new RpcException(status)), + Task.FromResult(new Metadata()), + () => status, + () => new Metadata(), + () => { }); + } + } +} diff --git a/src/Tests/GrokTests.cs b/src/Tests/GrokTests.cs index 3d5c308..a608051 100644 --- a/src/Tests/GrokTests.cs +++ b/src/Tests/GrokTests.cs @@ -4,7 +4,8 @@ using Devlooped.Extensions.AI.Grok; using Devlooped.Grok; using Microsoft.Extensions.AI; -using OpenAI.Realtime; +using Moq; +using Tests.Client.Helpers; using static ConfigurationExtensions; using OpenAIClientOptions = OpenAI.OpenAIClientOptions; @@ -466,5 +467,40 @@ public async Task GrokStreamsUpdatesFromAllTools() Assert.True(typed.Price > 100); } + [Fact] + public async Task GrokCustomFactoryInvokedFromOptions() + { + var invoked = false; + var client = new Mock(MockBehavior.Strict); + client.Setup(x => x.GetCompletionAsync(It.IsAny(), null, null, CancellationToken.None)) + .Returns(CallHelpers.CreateAsyncUnaryCall(new GetChatCompletionResponse + { + Outputs = + { + new CompletionOutput + { + Message = new CompletionMessage + { + Content = "Hey Cazzulino!" + } + } + } + })); + + var grok = new GrokChatClient(client.Object, "grok-4-1-fast"); + var response = await grok.GetResponseAsync("Hi, my internet alias is kzu. Lookup my real full name online.", + new GrokChatOptions + { + RawRepresentationFactory = (client) => + { + invoked = true; + return new GetCompletionsRequest(); + } + }); + + Assert.True(invoked); + Assert.Equal("Hey Cazzulino!", response.Text); + } + record Response(DateOnly Today, string Release, decimal Price); }