2929
3030DEBUG = False
3131
32+ OPENAI_API_KEY = os .getenv ("OPENAI_API_KEY" )
33+
3234monitor_size = {
3335 "width" : 1920 ,
3436 "height" : 1080 ,
@@ -190,18 +192,35 @@ def supports_ansi():
190192 ANSI_BRIGHT_MAGENTA = ""
191193
192194
195+ def validate_mode (
196+ model ,
197+ accurate_mode ,
198+ voice_mode ,
199+ ):
200+ if accurate_mode and model != "gpt-4-vision-preview" :
201+ print ("To use accuracy mode, please use gpt-4-vision-preview" )
202+ sys .exit (1 )
203+
204+ if voice_mode and not OPENAI_API_KEY :
205+ print ("To use voice mode, please add an OpenAI API key" )
206+ sys .exit (1 )
207+
208+
193209def main (model , accurate_mode , terminal_prompt , voice_mode = False ):
194210 """
195211 Main function for the Self-Operating Computer
196212 """
197213 if model == "gpt-4-vision-preview" :
198214 client = OpenAI ()
199- client .api_key = os . getenv ( " OPENAI_API_KEY" )
215+ client .api_key = OPENAI_API_KEY
200216 client .base_url = os .getenv ("OPENAI_API_BASE_URL" , client .base_url )
201217 elif model == "gemini-pro-vision" :
202- GOOGLE_API_KEY = os .getenv (' GOOGLE_API_KEY' )
218+ GOOGLE_API_KEY = os .getenv (" GOOGLE_API_KEY" )
203219 mic = None
204220 # Initialize WhisperMic if voice_mode is True if voice_mode is True
221+
222+ validate_mode (model , accurate_mode , voice_mode )
223+
205224 if voice_mode :
206225 try :
207226 from whisper_mic import WhisperMic
@@ -261,7 +280,8 @@ def main(model, accurate_mode, terminal_prompt, voice_mode=False):
261280 print ("[loop] messages before next action:\n \n \n " , messages [1 :])
262281 try :
263282 response = get_next_action (model , messages , objective , accurate_mode )
264- action = parse_oai_response (response )
283+
284+ action = parse_response (response )
265285 action_type = action .get ("type" )
266286 action_detail = action .get ("data" )
267287
@@ -361,7 +381,9 @@ def get_next_action(model, messages, objective, accurate_mode):
361381 elif model == "agent-1" :
362382 return "coming soon"
363383 elif model == "gemini-pro-vision" :
364- content = get_next_action_from_gemini_pro_vision (messages , objective , accurate_mode )
384+ content = get_next_action_from_gemini_pro_vision (
385+ messages , objective , accurate_mode
386+ )
365387 return content
366388
367389 raise ModelNotRecognizedException (model )
@@ -385,6 +407,7 @@ def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
385407 """
386408 Reprompt OAI with additional screenshot of a mini screenshot centered around the cursor for further finetuning of clicked location
387409 """
410+ print ("[get_next_action_from_gemini_pro_vision] accurate_mode_double_check" )
388411 try :
389412 screenshot_filename = os .path .join ("screenshots" , "screenshot_mini.png" )
390413 capture_mini_screenshot_with_cursor (
@@ -425,7 +448,9 @@ def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
425448 content = response .choices [0 ].message .content
426449 elif model == "gemini-pro-vision" :
427450 model = genai .GenerativeModel ("gemini-pro-vision" )
428- response = model .generate_content ([accurate_vision_prompt , Image .open (new_screenshot_filename )])
451+ response = model .generate_content (
452+ [accurate_vision_prompt , Image .open (new_screenshot_filename )]
453+ )
429454 content = response .text [1 :]
430455 print (content )
431456 return content
@@ -510,7 +535,9 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
510535 print (
511536 f"Previous coords before accurate tuning: prev_x { prev_x } prev_y { prev_y } "
512537 )
513- content = accurate_mode_double_check ("gpt-4-vision-preview" , pseudo_messages , prev_x , prev_y )
538+ content = accurate_mode_double_check (
539+ "gpt-4-vision-preview" , pseudo_messages , prev_x , prev_y
540+ )
514541 assert content != "ERROR" , "ERROR: accurate_mode_double_check failed"
515542
516543 return content
@@ -519,10 +546,13 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
519546 print (f"Error parsing JSON: { e } " )
520547 return "Failed take action after looking at the screenshot"
521548
549+
522550def get_next_action_from_gemini_pro_vision (messages , objective , accurate_mode ):
523551 """
524552 Get the next action for Self-Operating Computer using Gemini Pro Vision
525553 """
554+ print ("[get_next_action_from_gemini_pro_vision] " )
555+ print ("[get_next_action_from_gemini_pro_vision] messages" , messages )
526556 # sleep for a second
527557 time .sleep (1 )
528558 try :
@@ -550,12 +580,19 @@ def get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode):
550580 vision_prompt = format_vision_prompt (objective , previous_action )
551581
552582 model = genai .GenerativeModel ("gemini-pro-vision" )
583+ print ("[get_next_action_from_gemini_pro_vision] model.generate_content" )
553584
554- response = model .generate_content ([vision_prompt , Image .open (new_screenshot_filename )])
585+ response = model .generate_content (
586+ [vision_prompt , Image .open (new_screenshot_filename )]
587+ )
555588
556589 # create a copy of messages and save to pseudo_messages
557590 pseudo_messages = messages .copy ()
558591 pseudo_messages .append (response .text )
592+ print (
593+ "[get_next_action_from_gemini_pro_vision] pseudo_messages.append(response.text)" ,
594+ response .text ,
595+ )
559596
560597 messages .append (
561598 {
@@ -564,22 +601,6 @@ def get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode):
564601 }
565602 )
566603 content = response .text [1 :]
567- print (content )
568- if accurate_mode :
569- if content .startswith ("CLICK" ):
570- # Adjust pseudo_messages to include the accurate_mode_message
571-
572- click_data = re .search (r"CLICK \{ (.+) \}" , content ).group (1 )
573- click_data_json = json .loads (f"{{{ click_data } }}" )
574- prev_x = click_data_json ["x" ]
575- prev_y = click_data_json ["y" ]
576-
577- if DEBUG :
578- print (
579- f"Previous coords before accurate tuning: prev_x { prev_x } prev_y { prev_y } "
580- )
581- content = accurate_mode_double_check ("gemini-pro-vision" , pseudo_messages , prev_x , prev_y )
582- assert content != "ERROR" , "ERROR: accurate_mode_double_check failed"
583604
584605 return content
585606
@@ -588,7 +609,8 @@ def get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode):
588609 return "Failed take action after looking at the screenshot"
589610
590611
591- def parse_oai_response (response ):
612+ def parse_response (response ):
613+ print ("[parse_response] response" , response )
592614 if response == "DONE" :
593615 return {"type" : "DONE" , "data" : None }
594616 elif response .startswith ("CLICK" ):
@@ -600,7 +622,7 @@ def parse_oai_response(response):
600622 elif response .startswith ("TYPE" ):
601623 # Extract the text to type
602624 try :
603- type_data = re .search (r' TYPE (.+)' , response , re .DOTALL ).group (1 )
625+ type_data = re .search (r" TYPE (.+)" , response , re .DOTALL ).group (1 )
604626 except :
605627 type_data = re .search (r'TYPE "(.+)"' , response , re .DOTALL ).group (1 )
606628 return {"type" : "TYPE" , "data" : type_data }
@@ -610,7 +632,7 @@ def parse_oai_response(response):
610632 try :
611633 search_data = re .search (r'SEARCH "(.+)"' , response ).group (1 )
612634 except :
613- search_data = re .search (r' SEARCH (.+)' , response ).group (1 )
635+ search_data = re .search (r" SEARCH (.+)" , response ).group (1 )
614636 return {"type" : "SEARCH" , "data" : search_data }
615637
616638 return {"type" : "UNKNOWN" , "data" : response }
0 commit comments