Skip to content

Commit 468e2bd

Browse files
committed
Add validate_mode
1 parent fe53d76 commit 468e2bd

File tree

1 file changed

+48
-26
lines changed

1 file changed

+48
-26
lines changed

operate/main.py

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
DEBUG = False
3131

32+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
33+
3234
monitor_size = {
3335
"width": 1920,
3436
"height": 1080,
@@ -190,18 +192,35 @@ def supports_ansi():
190192
ANSI_BRIGHT_MAGENTA = ""
191193

192194

195+
def validate_mode(
    model,
    accurate_mode,
    voice_mode,
):
    """
    Abort the program when the requested mode combination is unsupported.

    Args:
        model: Name of the selected model, compared against
            "gpt-4-vision-preview".
        accurate_mode: Truthy when accurate mode was requested; only allowed
            with the "gpt-4-vision-preview" model.
        voice_mode: Truthy when voice mode was requested; requires the
            module-level OPENAI_API_KEY to be set.

    Returns:
        None when the combination is accepted.

    Raises:
        SystemExit: With exit status 1 when a check fails.
    """
    if accurate_mode and model != "gpt-4-vision-preview":
        # Diagnostics go to stderr so they do not mix with regular output.
        print(
            "To use accuracy mode, please use gpt-4-vision-preview",
            file=sys.stderr,
        )
        sys.exit(1)

    if voice_mode and not OPENAI_API_KEY:
        print("To use voice mode, please add an OpenAI API key", file=sys.stderr)
        sys.exit(1)
207+
208+
193209
def main(model, accurate_mode, terminal_prompt, voice_mode=False):
194210
"""
195211
Main function for the Self-Operating Computer
196212
"""
197213
if model == "gpt-4-vision-preview":
198214
client = OpenAI()
199-
client.api_key = os.getenv("OPENAI_API_KEY")
215+
client.api_key = OPENAI_API_KEY
200216
client.base_url = os.getenv("OPENAI_API_BASE_URL", client.base_url)
201217
elif model == "gemini-pro-vision":
202-
GOOGLE_API_KEY=os.getenv('GOOGLE_API_KEY')
218+
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
203219
mic = None
204220
# Initialize WhisperMic if voice_mode is True
221+
222+
validate_mode(model, accurate_mode, voice_mode)
223+
205224
if voice_mode:
206225
try:
207226
from whisper_mic import WhisperMic
@@ -261,7 +280,8 @@ def main(model, accurate_mode, terminal_prompt, voice_mode=False):
261280
print("[loop] messages before next action:\n\n\n", messages[1:])
262281
try:
263282
response = get_next_action(model, messages, objective, accurate_mode)
264-
action = parse_oai_response(response)
283+
284+
action = parse_response(response)
265285
action_type = action.get("type")
266286
action_detail = action.get("data")
267287

@@ -361,7 +381,9 @@ def get_next_action(model, messages, objective, accurate_mode):
361381
elif model == "agent-1":
362382
return "coming soon"
363383
elif model == "gemini-pro-vision":
364-
content = get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode)
384+
content = get_next_action_from_gemini_pro_vision(
385+
messages, objective, accurate_mode
386+
)
365387
return content
366388

367389
raise ModelNotRecognizedException(model)
@@ -385,6 +407,7 @@ def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
385407
"""
386408
Reprompt OAI with an additional mini screenshot centered around the cursor to further fine-tune the clicked location
387409
"""
410+
print("[get_next_action_from_gemini_pro_vision] accurate_mode_double_check")
388411
try:
389412
screenshot_filename = os.path.join("screenshots", "screenshot_mini.png")
390413
capture_mini_screenshot_with_cursor(
@@ -425,7 +448,9 @@ def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
425448
content = response.choices[0].message.content
426449
elif model == "gemini-pro-vision":
427450
model = genai.GenerativeModel("gemini-pro-vision")
428-
response = model.generate_content([accurate_vision_prompt, Image.open(new_screenshot_filename)])
451+
response = model.generate_content(
452+
[accurate_vision_prompt, Image.open(new_screenshot_filename)]
453+
)
429454
content = response.text[1:]
430455
print(content)
431456
return content
@@ -510,7 +535,9 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
510535
print(
511536
f"Previous coords before accurate tuning: prev_x {prev_x} prev_y {prev_y}"
512537
)
513-
content = accurate_mode_double_check("gpt-4-vision-preview", pseudo_messages, prev_x, prev_y)
538+
content = accurate_mode_double_check(
539+
"gpt-4-vision-preview", pseudo_messages, prev_x, prev_y
540+
)
514541
assert content != "ERROR", "ERROR: accurate_mode_double_check failed"
515542

516543
return content
@@ -519,10 +546,13 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
519546
print(f"Error parsing JSON: {e}")
520547
return "Failed take action after looking at the screenshot"
521548

549+
522550
def get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode):
523551
"""
524552
Get the next action for Self-Operating Computer using Gemini Pro Vision
525553
"""
554+
print("[get_next_action_from_gemini_pro_vision] ")
555+
print("[get_next_action_from_gemini_pro_vision] messages", messages)
526556
# sleep for a second
527557
time.sleep(1)
528558
try:
@@ -550,12 +580,19 @@ def get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode):
550580
vision_prompt = format_vision_prompt(objective, previous_action)
551581

552582
model = genai.GenerativeModel("gemini-pro-vision")
583+
print("[get_next_action_from_gemini_pro_vision] model.generate_content")
553584

554-
response = model.generate_content([vision_prompt, Image.open(new_screenshot_filename)])
585+
response = model.generate_content(
586+
[vision_prompt, Image.open(new_screenshot_filename)]
587+
)
555588

556589
# create a copy of messages and save to pseudo_messages
557590
pseudo_messages = messages.copy()
558591
pseudo_messages.append(response.text)
592+
print(
593+
"[get_next_action_from_gemini_pro_vision] pseudo_messages.append(response.text)",
594+
response.text,
595+
)
559596

560597
messages.append(
561598
{
@@ -564,22 +601,6 @@ def get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode):
564601
}
565602
)
566603
content = response.text[1:]
567-
print(content)
568-
if accurate_mode:
569-
if content.startswith("CLICK"):
570-
# Adjust pseudo_messages to include the accurate_mode_message
571-
572-
click_data = re.search(r"CLICK \{ (.+) \}", content).group(1)
573-
click_data_json = json.loads(f"{{{click_data}}}")
574-
prev_x = click_data_json["x"]
575-
prev_y = click_data_json["y"]
576-
577-
if DEBUG:
578-
print(
579-
f"Previous coords before accurate tuning: prev_x {prev_x} prev_y {prev_y}"
580-
)
581-
content = accurate_mode_double_check("gemini-pro-vision", pseudo_messages, prev_x, prev_y)
582-
assert content != "ERROR", "ERROR: accurate_mode_double_check failed"
583604

584605
return content
585606

@@ -588,7 +609,8 @@ def get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode):
588609
return "Failed take action after looking at the screenshot"
589610

590611

591-
def parse_oai_response(response):
612+
def parse_response(response):
613+
print("[parse_response] response", response)
592614
if response == "DONE":
593615
return {"type": "DONE", "data": None}
594616
elif response.startswith("CLICK"):
@@ -600,7 +622,7 @@ def parse_oai_response(response):
600622
elif response.startswith("TYPE"):
601623
# Extract the text to type
602624
try:
603-
type_data = re.search(r'TYPE (.+)', response, re.DOTALL).group(1)
625+
type_data = re.search(r"TYPE (.+)", response, re.DOTALL).group(1)
604626
except:
605627
type_data = re.search(r'TYPE "(.+)"', response, re.DOTALL).group(1)
606628
return {"type": "TYPE", "data": type_data}
@@ -610,7 +632,7 @@ def parse_oai_response(response):
610632
try:
611633
search_data = re.search(r'SEARCH "(.+)"', response).group(1)
612634
except:
613-
search_data = re.search(r'SEARCH (.+)', response).group(1)
635+
search_data = re.search(r"SEARCH (.+)", response).group(1)
614636
return {"type": "SEARCH", "data": search_data}
615637

616638
return {"type": "UNKNOWN", "data": response}

0 commit comments

Comments
 (0)