@@ -22,8 +22,9 @@ import (
2222)
2323
2424type Client struct {
25- modelsLock sync.Mutex
25+ clientsLock sync.Mutex
2626 cache * cache.Client
27+ clients map [string ]clientInfo
2728 modelToProvider map [string ]string
2829 runner * runner.Runner
2930 envs []string
@@ -38,13 +39,15 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
3839 envs : envs ,
3940 credStore : credStore ,
4041 defaultProvider : defaultProvider ,
42+ modelToProvider : make (map [string ]string ),
43+ clients : make (map [string ]clientInfo ),
4144 }
4245}
4346
4447func (c * Client ) Call (ctx context.Context , messageRequest types.CompletionRequest , status chan <- types.CompletionStatus ) (* types.CompletionMessage , error ) {
45- c .modelsLock .Lock ()
48+ c .clientsLock .Lock ()
4649 provider , ok := c .modelToProvider [messageRequest .Model ]
47- c .modelsLock .Unlock ()
50+ c .clientsLock .Unlock ()
4851
4952 if ! ok {
5053 return nil , fmt .Errorf ("failed to find remote model %s" , messageRequest .Model )
@@ -105,12 +108,8 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error)
105108 return false , err
106109 }
107110
108- c .modelsLock .Lock ()
109- defer c .modelsLock .Unlock ()
110-
111- if c .modelToProvider == nil {
112- c .modelToProvider = map [string ]string {}
113- }
111+ c .clientsLock .Lock ()
112+ defer c .clientsLock .Unlock ()
114113
115114 c .modelToProvider [modelString ] = providerName
116115 return true , nil
@@ -145,11 +144,23 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
145144}
146145
147146func (c * Client ) load (ctx context.Context , toolName string ) (* openai.Client , error ) {
147+ c .clientsLock .Lock ()
148+ defer c .clientsLock .Unlock ()
149+
150+ client , ok := c .clients [toolName ]
151+ if ok && ! isHTTPURL (toolName ) && engine .IsDaemonRunning (client .url ) {
152+ return client .client , nil
153+ }
154+
148155 if isHTTPURL (toolName ) {
149156 remoteClient , err := c .clientFromURL (ctx , toolName )
150157 if err != nil {
151158 return nil , err
152159 }
160+ c .clients [toolName ] = clientInfo {
161+ client : remoteClient ,
162+ url : toolName ,
163+ }
153164 return remoteClient , nil
154165 }
155166
@@ -165,7 +176,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
165176 return nil , err
166177 }
167178
168- client , err := openai .NewClient (ctx , c .credStore , openai.Options {
179+ oClient , err := openai .NewClient (ctx , c .credStore , openai.Options {
169180 BaseURL : strings .TrimSuffix (url , "/" ) + "/v1" ,
170181 Cache : c .cache ,
171182 CacheKey : prg .EntryToolID ,
@@ -174,7 +185,11 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
174185 return nil , err
175186 }
176187
177- return client , nil
188+ c .clients [toolName ] = clientInfo {
189+ client : oClient ,
190+ url : url ,
191+ }
192+ return client .client , nil
178193}
179194
180195func (c * Client ) retrieveAPIKey (ctx context.Context , env , url string ) (string , error ) {
@@ -185,3 +200,8 @@ func isLocalhost(url string) bool {
185200 return strings .HasPrefix (url , "http://localhost" ) || strings .HasPrefix (url , "http://127.0.0.1" ) ||
186201 strings .HasPrefix (url , "https://localhost" ) || strings .HasPrefix (url , "https://127.0.0.1" )
187202}
203+
204+ type clientInfo struct {
205+ client * openai.Client
206+ url string
207+ }
0 commit comments