1414import Xlib .display
1515import Xlib .X
1616import Xlib .Xutil # not sure if Xutil is necessary
17-
17+ import google . generativeai as genai
1818from prompt_toolkit import prompt
1919from prompt_toolkit .shortcuts import message_dialog
2020from prompt_toolkit .styles import Style as PromptStyle
2929
3030DEBUG = False
3131
32- client = OpenAI ()
33- client .api_key = os .getenv ("OPENAI_API_KEY" )
34- client .base_url = os .getenv ("OPENAI_API_BASE_URL" , client .base_url )
# API credentials come from the environment; either one may be absent.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")

# The OpenAI client is only built when a key is configured; `client` is left
# undefined otherwise.
if OPENAI_API_KEY:
    client = OpenAI()
    client.api_key = OPENAI_API_KEY
    # Keep the client's existing endpoint unless OPENAI_API_BASE_URL is set.
    client.base_url = os.environ.get("OPENAI_API_BASE_URL", client.base_url)
39+
3540
3641monitor_size = {
3742 "width" : 1920 ,
@@ -194,12 +199,37 @@ def supports_ansi():
194199 ANSI_BRIGHT_MAGENTA = ""
195200
196201
def validation(
    model,
    accurate_mode,
    voice_mode,
):
    """
    Check that the requested model and modes are usable with the configured
    API keys. On the first failed check, print an explanation and exit the
    process with status 1; return None when every check passes.
    """

    def _abort(reason):
        # Report the problem to the user and stop the program.
        print(reason)
        sys.exit(1)

    wants_gpt4_vision = model == "gpt-4-vision-preview"

    # Accurate mode is only accepted together with gpt-4-vision-preview.
    if accurate_mode and not wants_gpt4_vision:
        _abort("To use accuracy mode, please use gpt-4-vision-preview")

    # Voice mode requires an OpenAI key regardless of the selected model.
    if voice_mode and not OPENAI_API_KEY:
        _abort("To use voice mode, please add an OpenAI API key")

    # Each vision model needs the key of its own provider.
    if wants_gpt4_vision and not OPENAI_API_KEY:
        _abort("To use `gpt-4-vision-preview` add an OpenAI API key")

    if model == "gemini-pro-vision" and not GOOGLE_API_KEY:
        _abort("To use `gemini-pro-vision` add a Google API key")
222+
223+
197224def main (model , accurate_mode , terminal_prompt , voice_mode = False ):
198225 """
199226 Main function for the Self-Operating Computer
200227 """
201228 mic = None
202229 # Initialize WhisperMic if voice_mode is True
230+
231+ validation (model , accurate_mode , voice_mode )
232+
203233 if voice_mode :
204234 try :
205235 from whisper_mic import WhisperMic
@@ -259,7 +289,8 @@ def main(model, accurate_mode, terminal_prompt, voice_mode=False):
259289 print ("[loop] messages before next action:\n \n \n " , messages [1 :])
260290 try :
261291 response = get_next_action (model , messages , objective , accurate_mode )
262- action = parse_oai_response (response )
292+
293+ action = parse_response (response )
263294 action_type = action .get ("type" )
264295 action_detail = action .get ("data" )
265296
@@ -358,6 +389,11 @@ def get_next_action(model, messages, objective, accurate_mode):
358389 return content
359390 elif model == "agent-1" :
360391 return "coming soon"
392+ elif model == "gemini-pro-vision" :
393+ content = get_next_action_from_gemini_pro_vision (
394+ messages , objective , accurate_mode
395+ )
396+ return content
361397
362398 raise ModelNotRecognizedException (model )
363399
@@ -376,10 +412,11 @@ def get_last_assistant_message(messages):
376412 return None # Return None if no assistant message is found
377413
378414
379- def accurate_mode_double_check (pseudo_messages , prev_x , prev_y ):
415+ def accurate_mode_double_check (model , pseudo_messages , prev_x , prev_y ):
380416 """
381417 Reprompt OAI with additional screenshot of a mini screenshot centered around the cursor for further finetuning of clicked location
382418 """
419+ print ("[get_next_action_from_gemini_pro_vision] accurate_mode_double_check" )
383420 try :
384421 screenshot_filename = os .path .join ("screenshots" , "screenshot_mini.png" )
385422 capture_mini_screenshot_with_cursor (
@@ -394,31 +431,37 @@ def accurate_mode_double_check(pseudo_messages, prev_x, prev_y):
394431 img_base64 = base64 .b64encode (img_file .read ()).decode ("utf-8" )
395432
396433 accurate_vision_prompt = format_accurate_mode_vision_prompt (prev_x , prev_y )
434+ if model == "gpt-4-vision-preview" :
435+ accurate_mode_message = {
436+ "role" : "user" ,
437+ "content" : [
438+ {"type" : "text" , "text" : accurate_vision_prompt },
439+ {
440+ "type" : "image_url" ,
441+ "image_url" : {"url" : f"data:image/jpeg;base64,{ img_base64 } " },
442+ },
443+ ],
444+ }
397445
398- accurate_mode_message = {
399- "role" : "user" ,
400- "content" : [
401- {"type" : "text" , "text" : accurate_vision_prompt },
402- {
403- "type" : "image_url" ,
404- "image_url" : {"url" : f"data:image/jpeg;base64,{ img_base64 } " },
405- },
406- ],
407- }
408-
409- pseudo_messages .append (accurate_mode_message )
410-
411- response = client .chat .completions .create (
412- model = "gpt-4-vision-preview" ,
413- messages = pseudo_messages ,
414- presence_penalty = 1 ,
415- frequency_penalty = 1 ,
416- temperature = 0.7 ,
417- max_tokens = 300 ,
418- )
446+ pseudo_messages .append (accurate_mode_message )
419447
420- content = response .choices [0 ].message .content
448+ response = client .chat .completions .create (
449+ model = "gpt-4-vision-preview" ,
450+ messages = pseudo_messages ,
451+ presence_penalty = 1 ,
452+ frequency_penalty = 1 ,
453+ temperature = 0.7 ,
454+ max_tokens = 300 ,
455+ )
421456
457+ content = response .choices [0 ].message .content
458+ elif model == "gemini-pro-vision" :
459+ model = genai .GenerativeModel ("gemini-pro-vision" )
460+ response = model .generate_content (
461+ [accurate_vision_prompt , Image .open (new_screenshot_filename )]
462+ )
463+ content = response .text [1 :]
464+ print (content )
422465 return content
423466 except Exception as e :
424467 print (f"Error reprompting model for accurate_mode: { e } " )
@@ -501,7 +544,9 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
501544 print (
502545 f"Previous coords before accurate tuning: prev_x { prev_x } prev_y { prev_y } "
503546 )
504- content = accurate_mode_double_check (pseudo_messages , prev_x , prev_y )
547+ content = accurate_mode_double_check (
548+ "gpt-4-vision-preview" , pseudo_messages , prev_x , prev_y
549+ )
505550 assert content != "ERROR" , "ERROR: accurate_mode_double_check failed"
506551
507552 return content
@@ -511,7 +556,69 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
511556 return "Failed take action after looking at the screenshot"
512557
513558
514- def parse_oai_response (response ):
def get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode):
    """
    Get the next action for Self-Operating Computer using Gemini Pro Vision

    Captures the screen, overlays a grid on the capture, and sends the gridded
    screenshot together with the formatted vision prompt to the
    `gemini-pro-vision` model.

    Args:
        messages: Conversation history. A placeholder user entry standing in
            for the screenshot is appended to it in place.
        objective: The user's objective, forwarded to `format_vision_prompt`.
        accurate_mode: Accepted so the signature matches
            `get_next_action_from_openai`; it is not used in this function.

    Returns:
        The model's reply with surrounding whitespace removed, or the fixed
        failure message if any step raises.
    """
    print("[get_next_action_from_gemini_pro_vision] messages", messages)
    # sleep for a second
    time.sleep(1)
    try:
        screenshots_dir = "screenshots"
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(screenshots_dir, exist_ok=True)

        screenshot_filename = os.path.join(screenshots_dir, "screenshot.png")
        # Call the function to capture the screen with the cursor
        capture_screen_with_cursor(screenshot_filename)

        new_screenshot_filename = os.path.join(
            screenshots_dir, "screenshot_with_grid.png"
        )

        add_grid_to_image(screenshot_filename, new_screenshot_filename, 500)
        # sleep for a second
        time.sleep(1)

        previous_action = get_last_assistant_message(messages)

        vision_prompt = format_vision_prompt(objective, previous_action)

        model = genai.GenerativeModel("gemini-pro-vision")
        print("[get_next_action_from_gemini_pro_vision] model.generate_content")

        # Open the image in a context manager so its file handle is released
        # once the request has been made.
        with Image.open(new_screenshot_filename) as grid_image:
            response = model.generate_content([vision_prompt, grid_image])

        response_text = response.text
        # Debug trace of the raw model reply; the label is kept unchanged so
        # existing log output stays the same.
        print(
            "[get_next_action_from_gemini_pro_vision] pseudo_messages.append(response.text)",
            response_text,
        )

        messages.append(
            {
                "role": "user",
                "content": "`screenshot.png`",
            }
        )

        # Strip whitespace rather than dropping the first character blindly:
        # a reply without a leading space would otherwise lose the first
        # letter of its action keyword.
        return response_text.strip()

    except Exception as e:
        print(f"Error parsing JSON: {e}")
        return "Failed take action after looking at the screenshot"
618+
619+
620+ def parse_response (response ):
621+ print ("[parse_response] response" , response )
515622 if response == "DONE" :
516623 return {"type" : "DONE" , "data" : None }
517624 elif response .startswith ("CLICK" ):
@@ -522,12 +629,18 @@ def parse_oai_response(response):
522629
523630 elif response .startswith ("TYPE" ):
524631 # Extract the text to type
525- type_data = re .search (r'TYPE "(.+)"' , response , re .DOTALL ).group (1 )
632+ try :
633+ type_data = re .search (r"TYPE (.+)" , response , re .DOTALL ).group (1 )
634+ except :
635+ type_data = re .search (r'TYPE "(.+)"' , response , re .DOTALL ).group (1 )
526636 return {"type" : "TYPE" , "data" : type_data }
527637
528638 elif response .startswith ("SEARCH" ):
529639 # Extract the search query
530- search_data = re .search (r'SEARCH "(.+)"' , response ).group (1 )
640+ try :
641+ search_data = re .search (r'SEARCH "(.+)"' , response ).group (1 )
642+ except :
643+ search_data = re .search (r"SEARCH (.+)" , response ).group (1 )
531644 return {"type" : "SEARCH" , "data" : search_data }
532645
533646 return {"type" : "UNKNOWN" , "data" : response }
0 commit comments