Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ require (
github.com/hexops/valast v1.4.4
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
github.com/mholt/archives v0.1.0
github.com/pkoukk/tiktoken-go v0.1.7
github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699
github.com/rs/cors v1.11.0
github.com/samber/lo v1.38.1
github.com/sirupsen/logrus v1.9.3
Expand Down Expand Up @@ -62,7 +64,7 @@ require (
github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dlclark/regexp2 v1.4.0 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect
Expand Down
8 changes: 6 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxG
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.4.0 h1:F1rxgk7p4uKjwIQxBs9oAXe5CqrXlCduYEJvrF4u93E=
github.com/dlclark/regexp2 v1.4.0/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/docker/cli v26.0.0+incompatible h1:90BKrx1a1HKYpSnnBFR6AgDq/FqkHxwlUyzJVPxD30I=
github.com/docker/cli v26.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/docker-credential-helpers v0.8.1 h1:j/eKUktUltBtMzKqmfLB0PAgqYyMHOp5vfsD1807oKo=
Expand Down Expand Up @@ -316,6 +316,10 @@ github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFz
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699 h1:Sp8yiuxsitkmCfEvUnmNf8wzuZwlGNkRjI2yF0C3QUQ=
github.com/pkoukk/tiktoken-go-loader v0.0.2-0.20240522064338-c17e8bc0f699/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
Expand Down
27 changes: 22 additions & 5 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,16 +349,29 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return nil, err
}

toolTokenCount, err := countTools(messageRequest.Tools)
if err != nil {
return nil, err
}

if messageRequest.Chat {
// Check the last message. If it is from a tool call, and if it takes up more than 80% of the budget on its own, reject it.
lastMessage := msgs[len(msgs)-1]
if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
lastMessageCount, err := countMessage(lastMessage)
if err != nil {
return nil, err
}

if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && lastMessageCount+toolTokenCount > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
// We need to update it in the msgs slice for right now and in the messageRequest for future calls.
msgs[len(msgs)-1].Content = TooLongMessage
messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
}

msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
msgs, err = dropMessagesOverCount(messageRequest.MaxTokens, toolTokenCount, msgs)
if err != nil {
return nil, err
}
}

if len(msgs) == 0 {
Expand Down Expand Up @@ -439,7 +452,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
// Decrease maxTokens by 10% to make garbage collection more aggressive.
// The retry loop will further decrease maxTokens if needed.
maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status)
result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, toolTokenCount, status)
}
if err != nil {
return nil, err
Expand Down Expand Up @@ -473,15 +486,19 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
return &result, nil
}

func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, toolTokenCount int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) {
var (
response types.CompletionMessage
err error
)

for range 10 { // maximum 10 tries
// Try to drop older messages again, with a decreased max tokens.
request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
request.Messages, err = dropMessagesOverCount(maxTokens, toolTokenCount, request.Messages)
if err != nil {
return types.CompletionMessage{}, err
}

response, err = c.call(ctx, request, id, env, status)
if err == nil {
return response, nil
Expand Down
65 changes: 50 additions & 15 deletions pkg/openai/count.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package openai

import (
"encoding/json"

openai "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/pkoukk/tiktoken-go"
tiktoken_loader "github.com/pkoukk/tiktoken-go-loader"
)

const DefaultMaxTokens = 128_000
Expand All @@ -12,22 +17,26 @@ func decreaseTenPercent(maxTokens int) int {
}

func getBudget(maxTokens int) int {
if maxTokens == 0 {
if maxTokens <= 0 {
return DefaultMaxTokens
}
return maxTokens
}

func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
func dropMessagesOverCount(maxTokens, toolTokenCount int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage, err error) {
var (
lastSystem int
withinBudget int
budget = getBudget(maxTokens)
budget = getBudget(maxTokens) - toolTokenCount
)

for i, msg := range msgs {
if msg.Role == openai.ChatMessageRoleSystem {
budget -= countMessage(msg)
count, err := countMessage(msg)
if err != nil {
return nil, err
}
budget -= count
lastSystem = i
result = append(result, msg)
} else {
Expand All @@ -37,7 +46,11 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (

for i := len(msgs) - 1; i > lastSystem; i-- {
withinBudget = i
budget -= countMessage(msgs[i])
count, err := countMessage(msgs[i])
if err != nil {
return nil, err
}
budget -= count
if budget <= 0 {
break
}
Expand All @@ -54,22 +67,44 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
if withinBudget == len(msgs)-1 {
// We are going to drop all non system messages, which seems useless, so just return them
// all and let it fail
return msgs
return msgs, nil
}

return append(result, msgs[withinBudget:]...)
return append(result, msgs[withinBudget:]...), nil
}

func countMessage(msg openai.ChatCompletionMessage) (count int) {
count += len(msg.Role)
count += len(msg.Content)
func countMessage(msg openai.ChatCompletionMessage) (int, error) {
tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
encoding, err := tiktoken.GetEncoding("o200k_base")
if err != nil {
return 0, err
}

count := len(encoding.Encode(msg.Role, nil, nil))
count += len(encoding.Encode(msg.Content, nil, nil))
for _, content := range msg.MultiContent {
count += len(content.Text)
count += len(encoding.Encode(content.Text, nil, nil))
}
for _, tool := range msg.ToolCalls {
count += len(tool.Function.Name)
count += len(tool.Function.Arguments)
count += len(encoding.Encode(tool.Function.Name, nil, nil))
count += len(encoding.Encode(tool.Function.Arguments, nil, nil))
}
count += len(msg.ToolCallID)
return count / 3
count += len(encoding.Encode(msg.ToolCallID, nil, nil))

return count, nil
}

func countTools(tools []types.ChatCompletionTool) (int, error) {
tiktoken.SetBpeLoader(tiktoken_loader.NewOfflineLoader())
encoding, err := tiktoken.GetEncoding("o200k_base")
if err != nil {
return 0, err
}

toolJSON, err := json.Marshal(tools)
if err != nil {
return 0, err
}

return len(encoding.Encode(string(toolJSON), nil, nil)), nil
}