Skip to content

Commit c546c07

Browse files
committed
Add extension method GetStructuredResponse for IOpenAiClient
1 parent 16d08e8 commit c546c07

File tree

16 files changed

+448
-26
lines changed

16 files changed

+448
-26
lines changed

OpenAI.ChatGpt.AspNetCore/ChatGPTFactory.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ public async Task<ChatGPT> Create(
127127

128128
public void Dispose()
129129
{
130-
if (!_isHttpClientInjected)
130+
if (!_isHttpClientInjected && _client is IDisposable disposableClient)
131131
{
132-
_client.Dispose();
132+
disposableClient.Dispose();
133133
}
134134
}
135135
}

OpenAI.ChatGpt.Modules.Translator/ChatGPTTranslatorService.cs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public ChatGPTTranslatorService(
2525

2626
public ChatGPTTranslatorService(
2727
string apiKey,
28-
string? host,
28+
string? host = null,
2929
string? defaultSourceLanguage = null,
3030
string? defaultTargetLanguage = null,
3131
string? extraPrompt = null)
@@ -39,9 +39,9 @@ public ChatGPTTranslatorService(
3939

4040
public void Dispose()
4141
{
42-
if (!_isHttpClientInjected)
42+
if (!_isHttpClientInjected && _client is IDisposable disposableClient)
4343
{
44-
_client.Dispose();
44+
disposableClient.Dispose();
4545
}
4646
}
4747

@@ -55,6 +55,14 @@ public async Task<string> Translate(
5555
if (text == null) throw new ArgumentNullException(nameof(text));
5656
var sourceLanguageOrDefault = sourceLanguage ?? _defaultSourceLanguage;
5757
var targetLanguageOrDefault = targetLanguage ?? _defaultTargetLanguage;
58+
if (sourceLanguageOrDefault is null)
59+
{
60+
throw new ArgumentNullException(nameof(sourceLanguage), "Source language is not specified");
61+
}
62+
if (targetLanguageOrDefault is null)
63+
{
64+
throw new ArgumentNullException(nameof(targetLanguage), "Target language is not specified");
65+
}
5866
var prompt = GetPrompt(sourceLanguageOrDefault, targetLanguageOrDefault);
5967
var response = await _client.GetChatCompletions(
6068
Dialog.StartAsSystem(prompt).ThenUser(text),

OpenAI.ChatGpt/ChatGPT.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ public void Dispose()
9292
{
9393
Stop();
9494
_currentChat?.Dispose();
95-
if (!_isClientInjected)
95+
if (!_isClientInjected && _client is IDisposable disposableClient)
9696
{
97-
_client.Dispose();
97+
disposableClient.Dispose();
9898
}
9999
}
100100

OpenAI.ChatGpt/ChatService.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,16 @@ private async Task<string> GetNextMessageResponse(
9393
cancellationToken = _cts.Token;
9494

9595
var history = await LoadHistory(cancellationToken);
96-
var messages = history.Append(message);
96+
var messages = history.Append(message).ToArray();
9797

9898
IsWriting = true;
9999
try
100100
{
101+
var (model, maxTokens) = FindOptimalModelAndMaxToken(messages);
101102
var response = await _client.GetChatCompletionsRaw(
102103
messages,
104+
maxTokens: maxTokens,
105+
model: model,
103106
user: Topic.Config.PassUserIdToOpenAiRequests is true ? UserId : null,
104107
requestModifier: Topic.Config.ModifyRequest,
105108
cancellationToken: cancellationToken
@@ -117,7 +120,13 @@ await _chatHistoryStorage.SaveMessages(
117120
IsWriting = false;
118121
}
119122
}
120-
123+
124+
private (string model, int maxTokens) FindOptimalModelAndMaxToken(ChatCompletionMessage[] messages)
125+
{
126+
return ChatCompletionMessage.FindOptimalModelAndMaxToken(
127+
messages, Topic.Config.Model, Topic.Config.MaxTokens);
128+
}
129+
121130
public IAsyncEnumerable<string> StreamNextMessageResponse(
122131
string message,
123132
bool throwOnCancellation = true,
@@ -143,11 +152,14 @@ private async IAsyncEnumerable<string> StreamNextMessageResponse(
143152
cancellationToken = _cts.Token;
144153

145154
var history = await LoadHistory(cancellationToken);
146-
var messages = history.Append(message);
155+
var messages = history.Append(message).ToArray();
147156
var sb = new StringBuilder();
148157
IsWriting = true;
158+
var (model, maxTokens) = FindOptimalModelAndMaxToken(messages);
149159
var stream = _client.StreamChatCompletions(
150160
messages,
161+
maxTokens: maxTokens,
162+
model: model,
151163
user: Topic.Config.PassUserIdToOpenAiRequests is true ? UserId : null,
152164
requestModifier: Topic.Config.ModifyRequest,
153165
cancellationToken: cancellationToken

OpenAI.ChatGpt/IOpenAiClient.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
using System.Runtime.CompilerServices;
2-
using OpenAI.ChatGpt.Models.ChatCompletion;
1+
using OpenAI.ChatGpt.Models.ChatCompletion;
32
using OpenAI.ChatGpt.Models.ChatCompletion.Messaging;
43

54
namespace OpenAI.ChatGpt;
65

7-
public interface IOpenAiClient : IDisposable
6+
public interface IOpenAiClient
87
{
98
Task<string> GetChatCompletions(
109
UserOrSystemMessage dialog,
@@ -13,6 +12,7 @@ Task<string> GetChatCompletions(
1312
float temperature = ChatCompletionTemperatures.Default,
1413
string? user = null,
1514
Action<ChatCompletionRequest>? requestModifier = null,
15+
Action<ChatCompletionResponse>? rawResponseGetter = null,
1616
CancellationToken cancellationToken = default);
1717

1818
Task<string> GetChatCompletions(
@@ -22,6 +22,7 @@ Task<string> GetChatCompletions(
2222
float temperature = ChatCompletionTemperatures.Default,
2323
string? user = null,
2424
Action<ChatCompletionRequest>? requestModifier = null,
25+
Action<ChatCompletionResponse>? rawResponseGetter = null,
2526
CancellationToken cancellationToken = default);
2627

2728
Task<ChatCompletionResponse> GetChatCompletionsRaw(
@@ -81,8 +82,7 @@ IAsyncEnumerable<string> StreamChatCompletions(
8182
CancellationToken cancellationToken = default);
8283

8384
IAsyncEnumerable<string> StreamChatCompletions(
84-
ChatCompletionRequest request,
85-
[EnumeratorCancellation] CancellationToken cancellationToken = default);
85+
ChatCompletionRequest request, CancellationToken cancellationToken = default);
8686

8787
IAsyncEnumerable<ChatCompletionResponse> StreamChatCompletionsRaw(
8888
ChatCompletionRequest request, CancellationToken cancellationToken = default);

OpenAI.ChatGpt/Models/ChatCompletion/ChatCompletionModels.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ public static void EnsureMaxTokensIsSupported(string model, int maxTokens)
199199
{
200200
throw new ArgumentOutOfRangeException(
201201
nameof(maxTokens),
202-
$"Max tokens must be less than or equal to {limit} for model {model}"
202+
$"Max tokens must be less than or equal to {limit} for model {model} but was {maxTokens}"
203203
);
204204
}
205205
}
@@ -210,7 +210,7 @@ public static void EnsureMaxTokensIsSupportedByAnyModel(int maxTokens)
210210
if (maxTokens > limit)
211211
{
212212
throw new ArgumentOutOfRangeException(
213-
nameof(maxTokens), $"Max tokens must be less than or equal to {limit}");
213+
nameof(maxTokens), $"Max tokens must be less than or equal to {limit} but was {maxTokens}");
214214
}
215215
}
216216
}

OpenAI.ChatGpt/Models/ChatCompletion/Messaging/ChatCompletionMessage.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,34 @@ public override string ToString()
9595
? $"{Role}: {Content}"
9696
: string.Join(Environment.NewLine, _messages.Select(m => $"{m.Role}: {m.Content}"));
9797
}
98+
99+
public static (string model, int maxTokens) FindOptimalModelAndMaxToken(
100+
IEnumerable<ChatCompletionMessage> messages,
101+
string? model,
102+
int? maxTokens,
103+
string smallModel = ChatCompletionModels.Default,
104+
string bigModel = ChatCompletionModels.Gpt3_5_Turbo_16k,
105+
bool useMaxPossibleTokens = true)
106+
{
107+
var tokenCount = CalculateApproxTotalTokenCount(messages);
108+
switch (model, maxTokens)
109+
{
110+
case (null, null):
111+
{
112+
model = tokenCount > 6000 ? bigModel : smallModel;
113+
maxTokens = GetMaxPossibleTokens(model);
114+
break;
115+
}
116+
case (null, _):
117+
model = smallModel;
118+
break;
119+
case (_, null):
120+
maxTokens = useMaxPossibleTokens ? GetMaxPossibleTokens(model) : ChatCompletionRequest.MaxTokensDefault;
121+
break;
122+
}
123+
124+
return (model, maxTokens.Value);
125+
126+
int GetMaxPossibleTokens(string s) => ChatCompletionModels.GetMaxTokensLimitForModel(s) - tokenCount - 500;
127+
}
98128
}
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ namespace OpenAI.ChatGpt;
1212

1313
/// <summary> Thread-safe OpenAI client. </summary>
1414
[Fody.ConfigureAwait(false)]
15-
public class OpenAiClient : IDisposable, IOpenAiClient
15+
public class OpenAiClient : IOpenAiClient, IDisposable
1616
{
1717
private const string DefaultHost = "https://api.openai.com/v1/";
1818
private const string ImagesEndpoint = "images/generations";
@@ -122,6 +122,7 @@ public async Task<string> GetChatCompletions(
122122
float temperature = ChatCompletionTemperatures.Default,
123123
string? user = null,
124124
Action<ChatCompletionRequest>? requestModifier = null,
125+
Action<ChatCompletionResponse>? rawResponseGetter = null,
125126
CancellationToken cancellationToken = default)
126127
{
127128
if (dialog == null) throw new ArgumentNullException(nameof(dialog));
@@ -135,6 +136,7 @@ public async Task<string> GetChatCompletions(
135136
requestModifier
136137
);
137138
var response = await GetChatCompletionsRaw(request, cancellationToken);
139+
rawResponseGetter?.Invoke(response);
138140
return response.Choices[0].Message!.Content;
139141
}
140142

@@ -145,6 +147,7 @@ public async Task<string> GetChatCompletions(
145147
float temperature = ChatCompletionTemperatures.Default,
146148
string? user = null,
147149
Action<ChatCompletionRequest>? requestModifier = null,
150+
Action<ChatCompletionResponse>? rawResponseGetter = null,
148151
CancellationToken cancellationToken = default)
149152
{
150153
if (messages == null) throw new ArgumentNullException(nameof(messages));
@@ -158,6 +161,7 @@ public async Task<string> GetChatCompletions(
158161
requestModifier
159162
);
160163
var response = await GetChatCompletionsRaw(request, cancellationToken);
164+
rawResponseGetter?.Invoke(response);
161165
return response.GetMessageContent();
162166
}
163167

0 commit comments

Comments (0)