Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I went ahead and fixed the setup to work on MacOS Retina (tested M1) #10

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Computer Use - OOTB
# Computer Use - OOTB MacOS working version (tested on m1 retina)

## 🌟 Overview
This is an out-of-the-box (OOTB) solution for Claude's new Computer Use APIs.

MacOS working version (tested on m1 retina)

**No Docker** is required, and it theoretically supports **any platform**, with testing currently done on **Windows**. This project provides a user-friendly interface based on Gradio. 🎨

## Update
Expand Down Expand Up @@ -74,7 +76,7 @@ Desktop Interface
- [ ] **Platform**
- [x] **Windows**
- [x] **Mobile** (Send command)
- [ ] **Mac**
- [x] **Mac**
- [ ] **Mobile** (Be controlled)
- [ ] **Support for More MLLMs**
- [x] **Claude 3.5 Sonnet** 🎵
Expand Down
2 changes: 2 additions & 0 deletions computer_use_demo/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def setup_state(state):
state["api_key"] = load_from_storage("api_key") or os.getenv("ANTHROPIC_API_KEY", "")
if not state["api_key"]:
print("API key not found. Please set it in the environment or storage.")
else:
print(f"API key loaded: {state['api_key'][:5]}...{state['api_key'][-5:]}")
if "provider" not in state:
state["provider"] = os.getenv("API_PROVIDER", "anthropic") or APIProvider.ANTHROPIC
if "provider_radio" not in state:
Expand Down
65 changes: 43 additions & 22 deletions computer_use_demo/tools/computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from pathlib import Path
from typing import Literal, TypedDict
from uuid import uuid4
import io
from PIL import Image

from anthropic.types.beta import BetaToolComputerUse20241022Param

Expand Down Expand Up @@ -186,21 +188,18 @@ async def __call__(

raise ToolError(f"Invalid action: {action}")

async def screenshot(self):
"""Take a screenshot of the current screen and return a ToolResult with the base64 encoded image."""
output_dir = Path(OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"screenshot_{uuid4().hex}.png"

# Take screenshot using pyautogui
async def screenshot(self) -> ToolResult:
screenshot = pyautogui.screenshot()
screenshot.save(str(path))
img_byte_arr = io.BytesIO()
screenshot.save(img_byte_arr, format='PNG')
img_byte_arr = img_byte_arr.getvalue()

if path.exists():
# Return a ToolResult instance instead of a dictionary
return ToolResult(base64_image=base64.b64encode(path.read_bytes()).decode())

raise ToolError(f"Failed to take screenshot: {path} does not exist.")
# Resize the image if it's too large
if len(img_byte_arr) > 5 * 1024 * 1024:
img_byte_arr = self.resize_image(img_byte_arr)

base64_image = base64.b64encode(img_byte_arr).decode('utf-8')
return ToolResult(output="Screenshot taken", base64_image=base64_image)

async def shell(self, command: str, take_screenshot=True) -> ToolResult:
"""Run a shell command and return the output, error, and optionally a screenshot."""
Expand Down Expand Up @@ -240,11 +239,20 @@ def scale_coordinates(self, source: ScalingSource, x: int, y: int):
return round(x * x_scaling_factor), round(y * y_scaling_factor)

def get_screen_size(self):
if platform.system() == "Windows":
# Command to get screen resolution on Windows
if platform.system() == "Darwin": # macOS
try:
output = subprocess.check_output(["system_profiler", "SPDisplaysDataType"]).decode('utf-8')
for line in output.split('\n'):
if "Resolution" in line:
resolution = line.split(':')[1].strip()
width, height = map(lambda x: int(x.split()[0]), resolution.split(' x '))
return width, height
except Exception as e:
print(f"Error getting screen size: {e}")
return 1920, 1080 # Default fallback resolution
elif platform.system() == "Windows":
# Keep existing Windows code
cmd = "wmic path Win32_VideoController get CurrentHorizontalResolution,CurrentVerticalResolution"
elif platform.system() == "Darwin": # macOS
cmd = "system_profiler SPDisplaysDataType | grep Resolution"
else: # Linux or other OS
cmd = "xrandr | grep '*' | awk '{print $1}'"

Expand All @@ -254,9 +262,6 @@ def get_screen_size(self):
if platform.system() == "Windows":
lines = output.strip().split('\n')[1:] # Skip the header
width, height = map(int, lines[0].split())
elif platform.system() == "Darwin":
resolution = output.split()[0]
width, height = map(int, resolution.split('x'))
else:
resolution = output.strip().split()[0]
width, height = map(int, resolution.split('x'))
Expand All @@ -265,7 +270,7 @@ def get_screen_size(self):

except subprocess.CalledProcessError as e:
print(f"Error occurred: {e}")
return None, None # Return None or some default values
return 1920, 1080 # Default fallback resolution


def get_mouse_position(self):
Expand All @@ -281,4 +286,20 @@ def map_keys(self, text: str):
"""Map text to cliclick key codes if necessary."""
# For simplicity, return text as is
# Implement mapping if special keys are needed
return text
return text

def resize_image(self, image_data: bytes, max_size: int = 5 * 1024 * 1024) -> bytes:
img = Image.open(io.BytesIO(image_data))

# Calculate the scaling factor
current_size = len(image_data)
scale_factor = (max_size / current_size) ** 0.5

# Resize the image
new_size = (int(img.width * scale_factor), int(img.height * scale_factor))
img = img.resize(new_size, Image.LANCZOS)

# Save the resized image to a bytes buffer
buffer = io.BytesIO()
img.save(buffer, format="PNG", optimize=True)
return buffer.getvalue()