@@ -116,10 +116,35 @@ var (
116116 EventTypeCallFinish = EventType ("callFinish" )
117117)
118118
119+ func (r * Runner ) getContext (callCtx engine.Context , monitor Monitor , env []string ) (result []engine.InputContext , _ error ) {
120+ toolIDs , err := callCtx .Program .GetContextToolIDs (callCtx .Tool .ID )
121+ if err != nil {
122+ return nil , err
123+ }
124+
125+ for _ , toolID := range toolIDs {
126+ content , err := r .subCall (callCtx .Ctx , callCtx , monitor , env , toolID , "" , "" )
127+ if err != nil {
128+ return nil , err
129+ }
130+ result = append (result , engine.InputContext {
131+ ToolID : toolID ,
132+ Content : content ,
133+ })
134+ }
135+ return result , nil
136+ }
137+
119138func (r * Runner ) call (callCtx engine.Context , monitor Monitor , env []string , input string ) (string , error ) {
120139 progress , progressClose := streamProgress (& callCtx , monitor )
121140 defer progressClose ()
122141
142+ var err error
143+ callCtx .InputContext , err = r .getContext (callCtx , monitor , env )
144+ if err != nil {
145+ return "" , err
146+ }
147+
123148 e := engine.Engine {
124149 Model : r .c ,
125150 RuntimeManager : r .runtimeManager ,
@@ -221,6 +246,15 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
221246 }
222247}
223248
249+ func (r * Runner ) subCall (ctx context.Context , parentContext engine.Context , monitor Monitor , env []string , toolID , input , callID string ) (string , error ) {
250+ callCtx , err := parentContext .SubCall (ctx , toolID , callID )
251+ if err != nil {
252+ return "" , err
253+ }
254+
255+ return r .call (callCtx , monitor , env , input )
256+ }
257+
224258func (r * Runner ) subCalls (callCtx engine.Context , monitor Monitor , env []string , lastReturn * engine.Return ) (callResults []engine.CallResult , _ error ) {
225259 var (
226260 resultLock sync.Mutex
@@ -229,12 +263,7 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
229263 eg , subCtx := errgroup .WithContext (callCtx .Ctx )
230264 for id , call := range lastReturn .Calls {
231265 eg .Go (func () error {
232- callCtx , err := callCtx .SubCall (subCtx , call .ToolID , id )
233- if err != nil {
234- return err
235- }
236-
237- result , err := r .call (callCtx , monitor , env , call .Input )
266+ result , err := r .subCall (subCtx , callCtx , monitor , env , call .ToolID , call .Input , id )
238267 if err != nil {
239268 return err
240269 }
0 commit comments