@@ -68,12 +68,12 @@ def create_step(
6868 "label" : "production" ,
6969 "groundTruthColumnName" : "groundTruth" ,
7070 "latencyColumnName" : "latency" ,
71+ "costColumnName" : "cost" ,
72+ "numOfTokenColumnName" : "tokens" ,
7173 }
7274 if isinstance (new_step , steps .OpenAIChatCompletionStep ):
7375 config .update (
7476 {
75- "costColumnName" : "cost" ,
76- "numOfTokenColumnName" : "tokens" ,
7777 "prompt" : new_step .inputs .get ("prompt" ),
7878 }
7979 )
@@ -99,25 +99,54 @@ def process_trace_for_upload(trace: traces.Trace) -> Tuple[Dict[str, Any], List[
9999 input_variables = root_step .inputs
100100 input_variable_names = list (input_variables .keys ())
101101
102+ processed_steps = bubble_up_costs_and_tokens (trace .to_dict ())
103+
102104 trace_data = {
103105 ** input_variables ,
104106 "output" : root_step .output ,
105107 "groundTruth" : root_step .ground_truth ,
106108 "latency" : root_step .latency ,
107- "steps" : trace .to_dict (),
109+ "cost" : processed_steps [0 ].get ("cost" , 0 ),
110+ "tokens" : processed_steps [0 ].get ("tokens" , 0 ),
111+ "steps" : processed_steps ,
108112 }
109- # Extra fields for openai_chat_completion step
110- if isinstance (root_step , steps .OpenAIChatCompletionStep ):
111- trace_data .update (
112- {
113- "cost" : root_step .cost ,
114- "tokens" : root_step .prompt_tokens + root_step .completion_tokens ,
115- }
116- )
117113
118114 return trace_data , input_variable_names
119115
120116
117+ def bubble_up_costs_and_tokens (
118+ trace_dict : List [Dict [str , Any ]]
119+ ) -> List [Dict [str , Any ]]:
120+ """Adds the cost and number of tokens of nested steps to their parent steps."""
121+
122+ def add_step_costs_and_tokens (step : Dict [str , Any ]) -> Tuple [float , int ]:
123+ step_cost = step_tokens = 0
124+
125+ if "cost" in step :
126+ step_cost += step ["cost" ]
127+ if "tokens" in step :
128+ step_tokens += step ["tokens" ]
129+
130+ # Recursively add costs and tokens from nested steps
131+ for nested_step in step .get ("steps" , []):
132+ nested_cost , nested_tokens = add_step_costs_and_tokens (nested_step )
133+ step_cost += nested_cost
134+ step_tokens += nested_tokens
135+
136+ if "steps" in step :
137+ if step_cost > 0 and "cost" not in step :
138+ step ["cost" ] = step_cost
139+ if step_tokens > 0 and "tokens" not in step :
140+ step ["tokens" ] = step_tokens
141+
142+ return step_cost , step_tokens
143+
144+ for root_step_dict in trace_dict :
145+ add_step_costs_and_tokens (root_step_dict )
146+
147+ return trace_dict
148+
149+
121150def trace (* step_args , ** step_kwargs ):
122151 def decorator (func ):
123152 func_signature = inspect .signature (func )
0 commit comments