Skip to content

Commit b816047

Browse files
authored
Merge pull request #110 from linusaltacc/gemini-pro-vision-support
Added Gemini Pro Vision Support to Self Operating Computer
2 parents 86639de + 154ad21 commit b816047

File tree

4 files changed

+155
-34
lines changed

4 files changed

+155
-34
lines changed

.example.env

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
OPENAI_API_KEY='your-key-here'
1+
OPENAI_API_KEY='your-key-here'
2+
GOOGLE_API_KEY='your-key-here'

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
## Key Features
2222
- **Compatibility**: Designed for various multimodal models.
23-
- **Integration**: Currently integrated with **GPT-4v** as the default model.
23+
- **Integration**: Currently integrated with **GPT-4v** as the default model, with extended support for Gemini Pro Vision.
2424
- **Future Plans**: Support for additional models.
2525

2626
## Current Challenges
@@ -75,6 +75,12 @@ mv .example.env .env
7575
```
7676
OPENAI_API_KEY='your-key-here'
7777
```
78+
OR
79+
80+
**Add your Google AI Studio API key to your new .env file. If you don't have one, you can obtain a key [here](https://makersuite.google.com/app/apikey) after setting up your Google AI Studio account**:
81+
```
82+
GOOGLE_API_KEY='your-key-here'
83+
```
7884
8. **Run it**!
7985
```
8086
operate

operate/main.py

Lines changed: 145 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import Xlib.display
1515
import Xlib.X
1616
import Xlib.Xutil # not sure if Xutil is necessary
17-
17+
import google.generativeai as genai
1818
from prompt_toolkit import prompt
1919
from prompt_toolkit.shortcuts import message_dialog
2020
from prompt_toolkit.styles import Style as PromptStyle
@@ -29,9 +29,14 @@
2929

3030
DEBUG = False
3131

32-
client = OpenAI()
33-
client.api_key = os.getenv("OPENAI_API_KEY")
34-
client.base_url = os.getenv("OPENAI_API_BASE_URL", client.base_url)
32+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
33+
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
34+
35+
if OPENAI_API_KEY:
36+
client = OpenAI()
37+
client.api_key = OPENAI_API_KEY
38+
client.base_url = os.getenv("OPENAI_API_BASE_URL", client.base_url)
39+
3540

3641
monitor_size = {
3742
"width": 1920,
@@ -194,12 +199,37 @@ def supports_ansi():
194199
ANSI_BRIGHT_MAGENTA = ""
195200

196201

202+
def validation(
203+
model,
204+
accurate_mode,
205+
voice_mode,
206+
):
207+
if accurate_mode and model != "gpt-4-vision-preview":
208+
print("To use accuracy mode, please use gpt-4-vision-preview")
209+
sys.exit(1)
210+
211+
if voice_mode and not OPENAI_API_KEY:
212+
print("To use voice mode, please add an OpenAI API key")
213+
sys.exit(1)
214+
215+
if model == "gpt-4-vision-preview" and not OPENAI_API_KEY:
216+
print("To use `gpt-4-vision-preview` add an OpenAI API key")
217+
sys.exit(1)
218+
219+
if model == "gemini-pro-vision" and not GOOGLE_API_KEY:
220+
print("To use `gemini-pro-vision` add a Google API key")
221+
sys.exit(1)
222+
223+
197224
def main(model, accurate_mode, terminal_prompt, voice_mode=False):
198225
"""
199226
Main function for the Self-Operating Computer
200227
"""
201228
mic = None
202229
# Initialize WhisperMic if voice_mode is True
230+
231+
validation(model, accurate_mode, voice_mode)
232+
203233
if voice_mode:
204234
try:
205235
from whisper_mic import WhisperMic
@@ -259,7 +289,8 @@ def main(model, accurate_mode, terminal_prompt, voice_mode=False):
259289
print("[loop] messages before next action:\n\n\n", messages[1:])
260290
try:
261291
response = get_next_action(model, messages, objective, accurate_mode)
262-
action = parse_oai_response(response)
292+
293+
action = parse_response(response)
263294
action_type = action.get("type")
264295
action_detail = action.get("data")
265296

@@ -358,6 +389,11 @@ def get_next_action(model, messages, objective, accurate_mode):
358389
return content
359390
elif model == "agent-1":
360391
return "coming soon"
392+
elif model == "gemini-pro-vision":
393+
content = get_next_action_from_gemini_pro_vision(
394+
messages, objective, accurate_mode
395+
)
396+
return content
361397

362398
raise ModelNotRecognizedException(model)
363399

@@ -376,10 +412,11 @@ def get_last_assistant_message(messages):
376412
return None # Return None if no assistant message is found
377413

378414

379-
def accurate_mode_double_check(pseudo_messages, prev_x, prev_y):
415+
def accurate_mode_double_check(model, pseudo_messages, prev_x, prev_y):
380416
"""
381417
Reprompt OAI with additional screenshot of a mini screenshot centered around the cursor for further finetuning of clicked location
382418
"""
419+
print("[get_next_action_from_gemini_pro_vision] accurate_mode_double_check")
383420
try:
384421
screenshot_filename = os.path.join("screenshots", "screenshot_mini.png")
385422
capture_mini_screenshot_with_cursor(
@@ -394,31 +431,37 @@ def accurate_mode_double_check(pseudo_messages, prev_x, prev_y):
394431
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
395432

396433
accurate_vision_prompt = format_accurate_mode_vision_prompt(prev_x, prev_y)
434+
if model == "gpt-4-vision-preview":
435+
accurate_mode_message = {
436+
"role": "user",
437+
"content": [
438+
{"type": "text", "text": accurate_vision_prompt},
439+
{
440+
"type": "image_url",
441+
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
442+
},
443+
],
444+
}
397445

398-
accurate_mode_message = {
399-
"role": "user",
400-
"content": [
401-
{"type": "text", "text": accurate_vision_prompt},
402-
{
403-
"type": "image_url",
404-
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
405-
},
406-
],
407-
}
408-
409-
pseudo_messages.append(accurate_mode_message)
410-
411-
response = client.chat.completions.create(
412-
model="gpt-4-vision-preview",
413-
messages=pseudo_messages,
414-
presence_penalty=1,
415-
frequency_penalty=1,
416-
temperature=0.7,
417-
max_tokens=300,
418-
)
446+
pseudo_messages.append(accurate_mode_message)
419447

420-
content = response.choices[0].message.content
448+
response = client.chat.completions.create(
449+
model="gpt-4-vision-preview",
450+
messages=pseudo_messages,
451+
presence_penalty=1,
452+
frequency_penalty=1,
453+
temperature=0.7,
454+
max_tokens=300,
455+
)
421456

457+
content = response.choices[0].message.content
458+
elif model == "gemini-pro-vision":
459+
model = genai.GenerativeModel("gemini-pro-vision")
460+
response = model.generate_content(
461+
[accurate_vision_prompt, Image.open(new_screenshot_filename)]
462+
)
463+
content = response.text[1:]
464+
print(content)
422465
return content
423466
except Exception as e:
424467
print(f"Error reprompting model for accurate_mode: {e}")
@@ -501,7 +544,9 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
501544
print(
502545
f"Previous coords before accurate tuning: prev_x {prev_x} prev_y {prev_y}"
503546
)
504-
content = accurate_mode_double_check(pseudo_messages, prev_x, prev_y)
547+
content = accurate_mode_double_check(
548+
"gpt-4-vision-preview", pseudo_messages, prev_x, prev_y
549+
)
505550
assert content != "ERROR", "ERROR: accurate_mode_double_check failed"
506551

507552
return content
@@ -511,7 +556,69 @@ def get_next_action_from_openai(messages, objective, accurate_mode):
511556
return "Failed take action after looking at the screenshot"
512557

513558

514-
def parse_oai_response(response):
559+
def get_next_action_from_gemini_pro_vision(messages, objective, accurate_mode):
560+
"""
561+
Get the next action for Self-Operating Computer using Gemini Pro Vision
562+
"""
563+
print("[get_next_action_from_gemini_pro_vision] messages", messages)
564+
# sleep for a second
565+
time.sleep(1)
566+
try:
567+
screenshots_dir = "screenshots"
568+
if not os.path.exists(screenshots_dir):
569+
os.makedirs(screenshots_dir)
570+
571+
screenshot_filename = os.path.join(screenshots_dir, "screenshot.png")
572+
# Call the function to capture the screen with the cursor
573+
capture_screen_with_cursor(screenshot_filename)
574+
575+
new_screenshot_filename = os.path.join(
576+
"screenshots", "screenshot_with_grid.png"
577+
)
578+
579+
add_grid_to_image(screenshot_filename, new_screenshot_filename, 500)
580+
# sleep for a second
581+
time.sleep(1)
582+
583+
with open(new_screenshot_filename, "rb") as img_file:
584+
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
585+
586+
previous_action = get_last_assistant_message(messages)
587+
588+
vision_prompt = format_vision_prompt(objective, previous_action)
589+
590+
model = genai.GenerativeModel("gemini-pro-vision")
591+
print("[get_next_action_from_gemini_pro_vision] model.generate_content")
592+
593+
response = model.generate_content(
594+
[vision_prompt, Image.open(new_screenshot_filename)]
595+
)
596+
597+
# create a copy of messages and save to pseudo_messages
598+
pseudo_messages = messages.copy()
599+
pseudo_messages.append(response.text)
600+
print(
601+
"[get_next_action_from_gemini_pro_vision] pseudo_messages.append(response.text)",
602+
response.text,
603+
)
604+
605+
messages.append(
606+
{
607+
"role": "user",
608+
"content": "`screenshot.png`",
609+
}
610+
)
611+
content = response.text[1:]
612+
613+
return content
614+
615+
except Exception as e:
616+
print(f"Error parsing JSON: {e}")
617+
return "Failed take action after looking at the screenshot"
618+
619+
620+
def parse_response(response):
621+
print("[parse_response] response", response)
515622
if response == "DONE":
516623
return {"type": "DONE", "data": None}
517624
elif response.startswith("CLICK"):
@@ -522,12 +629,18 @@ def parse_oai_response(response):
522629

523630
elif response.startswith("TYPE"):
524631
# Extract the text to type
525-
type_data = re.search(r'TYPE "(.+)"', response, re.DOTALL).group(1)
632+
try:
633+
type_data = re.search(r"TYPE (.+)", response, re.DOTALL).group(1)
634+
except:
635+
type_data = re.search(r'TYPE "(.+)"', response, re.DOTALL).group(1)
526636
return {"type": "TYPE", "data": type_data}
527637

528638
elif response.startswith("SEARCH"):
529639
# Extract the search query
530-
search_data = re.search(r'SEARCH "(.+)"', response).group(1)
640+
try:
641+
search_data = re.search(r'SEARCH "(.+)"', response).group(1)
642+
except:
643+
search_data = re.search(r"SEARCH (.+)", response).group(1)
531644
return {"type": "SEARCH", "data": search_data}
532645

533646
return {"type": "UNKNOWN", "data": response}

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,4 @@ typing_extensions==4.8.0
4747
urllib3==2.0.7
4848
wcwidth==0.2.9
4949
zipp==3.17.0
50+
google-generativeai==0.3.0

0 commit comments

Comments
 (0)