diff --git a/web-agent/app/worker.py b/web-agent/app/worker.py
index 2c1d9b3..672bbcb 100644
--- a/web-agent/app/worker.py
+++ b/web-agent/app/worker.py
@@ -1,17 +1,19 @@
 #!/usr/bin/env python3
-import threading
-
 from gevent import monkey; monkey.patch_all()
+from gevent import Timeout

 import argparse
+import atexit
 import base64
 import gzip
 import json
 import logging
 import os
+import random
 import shutil
 import secrets
+import signal
 import string
 import tempfile
 import time
@@ -19,13 +21,16 @@ from collections import deque
 from logging.handlers import TimedRotatingFileHandler
 from pathlib import Path
-from typing import Optional, Tuple, Any, Dict, Union
+from typing import Optional, Tuple, Any, Dict, Union, List, Callable
+from urllib.parse import urlparse, urlunparse

+import gevent
 import requests
+from gevent.lock import Semaphore
 from gevent.pool import Pool

 # Global variables
-__version__ = "1.1.8"
+__version__ = "1.1.10"

 letters: str = string.ascii_letters
 rand_string: str = ''.join(secrets.choice(letters) for _ in range(10))
@@ -46,16 +51,26 @@ rate_limiter = None
 config_dict: dict = None

+# Limit concurrent teleport calls (max 2 per worker)
+teleport_semaphore: Optional[Semaphore] = None
+
+# Timeout for teleport operations (prevents semaphore deadlock if operation hangs)
+TELEPORT_TIMEOUT = int(os.getenv('TELEPORT_TIMEOUT_SECONDS', '60'))
+
+# HTTP request timeout for teleport endpoints
+TELEPORT_REQUEST_TIMEOUT = 30
+

 def main() -> None:
-    global config_dict, logger, rate_limiter
+    global config_dict, logger, rate_limiter, teleport_semaphore

-    # Instantiate RateLimiter for 25 requests per 15 seconds window
     rate_limiter = RateLimiter(request_limit=25, time_window=15)
+    teleport_semaphore = Semaphore(2)  # Max 2 concurrent teleport calls

     parser = argparse.ArgumentParser()
     config_dict, agent_index, debug_mode = get_initial_config(parser)
     logger = setup_logger(agent_index, debug_mode)
+
     logger.info("Agent Started for url %s, verify %s, timeout %s, outgoing proxy %s, inward %s, uploadToAc %s",
                 config_dict.get('server_url'), config_dict.get('verify_cert', False),
                 config_dict.get('timeout', 10), config_dict['outgoing_proxy'],
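
Note: teleport_semaphore, TELEPORT_TIMEOUT, and TELEPORT_REQUEST_TIMEOUT above define the pattern every teleport call in this file follows: a gevent Semaphore caps concurrency at two, and a gevent Timeout bounds how long a caller may wait for a slot plus run the request, so a hung call cannot hold a slot forever. A minimal standalone sketch of the idea (names here are illustrative, not part of the file):

    import gevent
    from gevent import Timeout
    from gevent.lock import Semaphore

    sem = Semaphore(2)  # at most two concurrent holders

    def guarded(fn, deadline_s=60):
        # The Timeout covers both the wait for a slot and fn() itself;
        # leaving the 'with sem' block releases the slot on every exit
        # path, including when the Timeout fires.
        with Timeout(deadline_s):
            with sem:
                return fn()
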
@@ -68,10 +83,110 @@ def main() -> None:
     process()


+def _get_task_from_server(headers: Dict[str, str], params: Dict[str, str], get_task_server_url: str) -> Tuple[requests.Response, float]:
+    """Execute get-task request in a separate greenlet to prevent LoopExit."""
+    get_task_start_time = time.time()
+
+    try:
+        with Timeout(TELEPORT_TIMEOUT):
+            with teleport_semaphore:
+                get_task_response: requests.Response = requests.get(
+                    get_task_server_url,
+                    headers=headers,
+                    timeout=25, verify=config_dict.get('verify_cert', False),
+                    proxies=config_dict['outgoing_proxy'],
+                    params=params
+                )
+    except Timeout:
+        logger.error(f"Get-task timed out after {TELEPORT_TIMEOUT}s, semaphore released")
+        raise
+
+    get_task_duration_ms = (time.time() - get_task_start_time) * 1000
+    return get_task_response, get_task_duration_ms
+
+
+def is_concurrent_limit_error(response: requests.Response) -> bool:
+    """Check if a 429 is due to the concurrent request limit vs the rate limit."""
+    if response.status_code == 429:
+        try:
+            return "Too many concurrent requests" in response.text
+        except Exception:
+            # If response.text fails (rare), assume not a concurrency error
+            return False
+    return False
+
+
+def get_retry_delay(response: requests.Response, default_delay: int = 2) -> float:
+    """Get retry delay, in order of precedence: concurrent error (random 0-10s) > header > default."""
+    if is_concurrent_limit_error(response):
+        delay = random.uniform(0, 10)
+        logger.info(f"Concurrent limit error, random delay: {delay:.2f}s")
+        return delay
+
+    retry_after = response.headers.get('X-Rate-Limit-Retry-After-Seconds')
+    if retry_after:
+        try:
+            delay = int(retry_after)
+            if delay < 0:
+                logger.warning(f"Negative retry delay {delay}s, using default {default_delay}s")
+                return default_delay
+            if delay > 60:
+                logger.warning(f"Excessive retry delay {delay}s, capping at 60s")
+                return 60
+            logger.info(f"Using header delay: {delay}s")
+            return delay
+        except ValueError:
+            logger.warning(f"Invalid retry delay '{retry_after}', using default {default_delay}s")
+
+    return default_delay
+
+
+def retry_request(
+    func: Callable[[], requests.Response],
+    max_retries: int = 5,
+    operation_name: str = "request"
+) -> Optional[requests.Response]:
+    """Retry on 429 (rate limit) or 504 (timeout). 429 uses smart delay, 504 uses exponential backoff."""
+    for attempt in range(max_retries + 1):
+        try:
+            response = func()
+
+            if response.status_code not in (429, 504):
+                return response
+
+            if attempt >= max_retries:
+                logger.error(f"{operation_name} failed after {max_retries} retries")
+                return response
+
+            # Calculate delay based on error type
+            if response.status_code == 429:
+                delay = get_retry_delay(response)
+                error_type = "concurrent limit" if is_concurrent_limit_error(response) else "rate limit"
+            else:  # 504
+                delay = min(1 * (2 ** attempt), 30)  # Exponential backoff: 1s, 2s, 4s, 8s, 16s, capped at 30s
+                error_type = "gateway timeout"
+
+            logger.warning(f"{operation_name} {error_type} (attempt {attempt + 1}/{max_retries + 1}), retry in {delay:.2f}s")
+            gevent.sleep(delay)
+
+        except requests.exceptions.RequestException as e:
+            logger.error(f"{operation_name} request error: {e}")
+            raise
+
+    logger.error(f"{operation_name} unexpected loop exit")
+    return None
+
+
+def delayed_retry(delay_seconds: int) -> None:
+    """Wait by spawning a timer greenlet. Keeps the hub alive during main loop delays."""
+    timer = gevent.spawn(lambda: gevent.sleep(delay_seconds))
+    timer.join()  # Wait for the timer; the timer greenlet keeps the hub active
+
+
 def process() -> None:
     headers: Dict[str, str] = _get_headers()
     thread_backoff_time: int = min_backoff_time
-    # thread_pool = Pool(config_dict['thread_pool_size'])
+
     while True:
         try:
             # Get the next task for the agent
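
Note: retry_request assumes the wrapped callable is safe to re-invoke; anything other than 429 or 504 (including other 4xx/5xx) is returned to the caller on the first attempt. A hypothetical usage sketch, with an endpoint that is not part of the agent API:

    import requests

    def _ping() -> requests.Response:
        return requests.get("https://example.com/api/ping", timeout=10)

    resp = retry_request(_ping, max_retries=3, operation_name="ping")
    if resp is not None and resp.status_code == 200:
        print(resp.text)
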
@@ -79,27 +194,31 @@ def process() -> None:
             rate_limiter.throttle()

             params = {
-                'agentId' : config_dict['agent_id']
+                'agentId': config_dict['agent_id'],
+                'agentVersion': __version__
             }
             get_task_server_url = f"{config_dict.get('server_url')}/api/http-teleport/get-task"
             if len(config_dict.get('env_name', '')) > 0:
                 params['envName'] = config_dict['env_name']

             logger.info("Requesting task from %s", get_task_server_url)
-            get_task_response: requests.Response = requests.get(
-                get_task_server_url,
-                headers=headers,
-                timeout=25, verify=config_dict.get('verify_cert', False),
-                proxies=config_dict['outgoing_proxy'],
-                params=params
-            )
+
+            # Spawn get-task in a separate greenlet to keep the main loop active (prevents LoopExit)
+            get_task_greenlet = gevent.spawn(_get_task_from_server, headers, params, get_task_server_url)
+
+            try:
+                get_task_response, get_task_duration_ms = get_task_greenlet.get(timeout=TELEPORT_REQUEST_TIMEOUT)
+            except gevent.Timeout:
+                logger.error(f"Get-task request timed out after {TELEPORT_REQUEST_TIMEOUT} seconds")
+                delayed_retry(5)
+                continue

             if get_task_response.status_code == 200:
                 thread_backoff_time = min_backoff_time
                 task: Optional[Dict[str, Any]] = get_task_response.json().get('data', None)
+
                 if task is None:
                     logger.info("Received empty task")
-                    time.sleep(5)  # Wait before requesting next task
                     continue

                 logger.info("Received task: %s", task['taskId'])
@@ -109,25 +228,28 @@ def process() -> None:
                 if thread_pool is None:
                     process_task_async(task)
                 else:
-                    thread_pool.wait_available()  # Wait if the thread_pool is full
-                    thread_pool.spawn(process_task_async, task)  # Submit the task when free
-            elif get_task_response.status_code == 204:
+                    thread_pool.wait_available()
+                    thread_pool.spawn(process_task_async, task)
+            elif get_task_response.status_code == 429:
                 logger.info("No task available. Waiting...")
-                time.sleep(5)
+                delayed_retry(5)
             elif get_task_response.status_code > 500:
                 logger.error("Getting 5XX error %d, increasing backoff time", get_task_response.status_code)
-                time.sleep(thread_backoff_time)
+                delayed_retry(thread_backoff_time)
                 thread_backoff_time = min(max_backoff_time, thread_backoff_time * 2)
             else:
                 logger.error("Unexpected response: %d", get_task_response.status_code)
-                time.sleep(5)
+                delayed_retry(5)

         except requests.exceptions.RequestException as e:
             logger.error("Network error: %s", e)
-            time.sleep(10)  # Wait longer on network errors
+            delayed_retry(10)  # Wait longer on network errors
+        except gevent.hub.LoopExit as e:
+            logger.error("Got LoopExit error: %s, resetting the thread pool", e)
+            config_dict['thread_pool'] = Pool(config_dict['thread_pool_size'])
         except Exception as e:
             logger.error("Unexpected error while processing: %s", e, exc_info=True)
-            time.sleep(5)
+            delayed_retry(5)


 def process_task_async(task: Dict[str, Any]) -> None:
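
Note: the spawn + get(timeout=...) combination above keeps the gevent hub scheduling other greenlets while the fetch runs. One caveat: when the outer timeout fires, the fetch greenlet is abandoned rather than killed, so it keeps its semaphore slot until its own inner timeout or the HTTP timeout fires. A sketch of the pattern, with the optional explicit kill:

    import gevent

    def fetch():
        gevent.sleep(1)  # stands in for the blocking HTTP call
        return "task"

    g = gevent.spawn(fetch)
    try:
        result = g.get(timeout=30)  # raises gevent.Timeout if not done in time
    except gevent.Timeout:
        g.kill()  # optional: reclaim the abandoned greenlet immediately
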
Waiting...") - time.sleep(5) + delayed_retry(5) elif get_task_response.status_code > 500: logger.error("Getting 5XX error %d, increasing backoff time", get_task_response.status_code) - time.sleep(thread_backoff_time) + delayed_retry(thread_backoff_time) thread_backoff_time = min(max_backoff_time, thread_backoff_time * 2) else: logger.error("Unexpected response: %d", get_task_response.status_code) - time.sleep(5) + delayed_retry(5) except requests.exceptions.RequestException as e: logger.error("Network error: %s", e) - time.sleep(10) # Wait longer on network errors + delayed_retry(10) # Wait longer on network errors + except gevent.hub.LoopExit as e: + logger.error("Getting LoopExit Error, resetting the thread pool") + config_dict['thread_pool'] = Pool(config_dict['thread_pool_size']) except Exception as e: logger.error("Unexpected error while processing: %s", e, exc_info=True) - time.sleep(5) + delayed_retry(5) def process_task_async(task: Dict[str, Any]) -> None: @@ -142,41 +264,45 @@ def process_task_async(task: Dict[str, Any]) -> None: except Exception as e: logger.info("Unexpected error while processing task id: %s, method: %s url: %s, error: %s", taskId, method, url, e) - time.sleep(5) -def update_task(task: Optional[Dict[str, Any]], count: int = 0) -> None: +def update_task(task: Optional[Dict[str, Any]]) -> None: + """Update task result with 429 retry and semaphore protection.""" if task is None: return - # Update the task status - if count > max_retry: - logger.error("Retry count exceeds for task %s", task['taskId']) - return - try: + + def _make_update_request() -> requests.Response: rate_limiter.throttle() - update_task_response: requests.Response = requests.post( - f"{config_dict.get('server_url')}/api/http-teleport/put-result", - headers=_get_headers(), - json=task, - timeout=30, verify=config_dict.get('verify_cert'), proxies=config_dict['outgoing_proxy'] - ) - - if update_task_response.status_code == 200: - logger.info("Task %s updated successfully. Response: %s", task['taskId'], - update_task_response.text) - elif update_task_response.status_code == 429 or update_task_response.status_code == 504: - time.sleep(2) - logger.warning("Rate limit hit while updating the task output, retrying again for task %s", task['taskId']) - count = count + 1 - update_task(task, count) - else: - logger.warning("Failed to update task %s: %s", task['taskId'], update_task_response.text) + try: + with Timeout(TELEPORT_TIMEOUT): + with teleport_semaphore: + response = requests.post( + f"{config_dict.get('server_url')}/api/http-teleport/put-result", + headers=_get_headers(), + json=task, + timeout=TELEPORT_REQUEST_TIMEOUT, + verify=config_dict.get('verify_cert'), + proxies=config_dict['outgoing_proxy'] + ) + return response + except Timeout: + logger.error(f"Put-result timed out after {TELEPORT_TIMEOUT}s, semaphore released") + raise + try: + response = retry_request(_make_update_request, max_retries=5, operation_name=f"update_task[{task['taskId']}]") + + if response and response.status_code == 200: + logger.info(f"Task {task['taskId']} updated successfully. 
Response: {response.text}") + elif response and response.status_code == 504: + logger.warning(f"Timeout updating task {task['taskId']}: {response.text}") + elif response and response.status_code == 429: + logger.warning(f"Rate limit updating task {task['taskId']} after all retries") + elif response: + logger.warning(f"Failed to update task {task['taskId']}: {response.text}") except requests.exceptions.RequestException as e: - logger.error("Network error processing task %s: %s", task['taskId'], e) - count = count + 1 - update_task(task, count) + logger.error(f"Network error updating task {task['taskId']}: {e}") def _get_headers() -> Dict[str, str]: @@ -188,39 +314,50 @@ def _get_headers() -> Dict[str, str]: def check_for_logs_fetch(url, task, temp_output_file_zip): + """Upload agent logs if this is a fetch-logs request.""" if 'agent/fetch-logs' in url and 'fetchLogs' in task.get('taskId'): try: - - # Zip the logs_folder shutil.make_archive(temp_output_file_zip.name[:-4], 'zip', log_folder) - - # Update the task with the zip file information task['responseZipped'] = True - headers: Dict[str, str] = { - "Authorization": f"Bearer {config_dict['api_key']}", - } logger.info(f"Logs zipped successfully: {temp_output_file_zip.name}") + + headers: Dict[str, str] = {"Authorization": f"Bearer {config_dict['api_key']}"} task_json = json.dumps(task) files = { - # 'fileFieldName' is the name of the form field expected by the server - "file": (temp_output_file_zip.name, open(temp_output_file_zip.name, "rb"), f"{'application/zip'}"), + "file": (temp_output_file_zip.name, open(temp_output_file_zip.name, "rb"), "application/zip"), "task": (None, task_json, "application/json") } - rate_limiter.throttle() + upload_logs_url = f"{config_dict.get('server_url')}/api/http-teleport/upload-logs" if len(config_dict.get('env_name', '')) > 0: - upload_logs_url = f"{config_dict.get('server_url')}/api/http-teleport/upload-logs?envName={config_dict.get('env_name')}" - upload_result: requests.Response = requests.post( - upload_logs_url, - headers=headers, - timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'], - files=files - ) - if upload_result.status_code == 200: + upload_logs_url += f"?envName={config_dict.get('env_name')}" + + def _upload_logs() -> requests.Response: + rate_limiter.throttle() + try: + with Timeout(TELEPORT_TIMEOUT): + with teleport_semaphore: + return requests.post( + upload_logs_url, + headers=headers, + timeout=TELEPORT_REQUEST_TIMEOUT, + verify=config_dict.get('verify_cert', False), + proxies=config_dict['outgoing_proxy'], + files=files + ) + except Timeout: + logger.error(f"Upload-logs timed out after {TELEPORT_TIMEOUT}s, semaphore released") + raise + + response = retry_request(_upload_logs, max_retries=5, operation_name="upload_logs") + + if response and response.status_code == 200: + logger.info("Logs uploaded successfully") return True else: - logger.error("Response code while uploading is not 200 , response code {} and error {} ", upload_result.status_code, upload_result.content) - return True + logger.error(f"Failed to upload logs: code={response.status_code if response else 'None'}, error={response.content if response else 'None'}") + return True + except Exception as e: logger.error(f"Error zipping logs: {str(e)}") raise e @@ -236,6 +373,8 @@ def process_task(task: Dict[str, Any]) -> Optional[dict[str, Any]]: expiryTime: int = task.get('expiryTsMs', round((time.time() + 300) * 1000)) logger.info("Processing task %s: %s %s", taskId, method, url) + 
@@ -236,6 +373,8 @@ def process_task(task: Dict[str, Any]) -> Optional[dict[str, Any]]:
     expiryTime: int = task.get('expiryTsMs', round((time.time() + 300) * 1000))
     logger.info("Processing task %s: %s %s", taskId, method, url)

+    task_start_time = time.time()
+
     # creating temp file to store outputs
     _createFolder(log_folder)  # create folder to store log files
     _createFolder(output_file_folder)  # create folder to store output files
@@ -277,9 +416,11 @@ def process_task(task: Dict[str, Any]) -> Optional[dict[str, Any]]:

         logger.debug("Input data is not str or bytes %s", input_data)

+        http_start_time = time.time()
         response: requests.Response = requests.request(method, url, headers=headers, data=encoded_input_data,
                                                        stream=True, timeout=(15, timeout),
                                                        verify=config_dict.get('verify_cert'), proxies=config_dict['inward_proxy'])
+        http_duration_ms = (time.time() - http_start_time) * 1000

         logger.info("Response: %d", response.status_code)
         data: Any = None
@@ -322,6 +463,7 @@ def process_task(task: Dict[str, Any]) -> Optional[dict[str, Any]]:
             base64_string = base64.b64encode(file_data).decode('utf-8')
             task['responseBase64'] = True
             task['output'] = base64_string
+            return task

     return upload_response(temp_output_file.name, temp_output_file_zip.name, taskId, task)

@@ -363,31 +505,45 @@ def zip_response(temp_file, temp_file_zip) -> bool:


 def upload_response(temp_file, temp_file_zip, taskId: str, task: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+    """Upload task response with 429 retry and semaphore protection."""
     if config_dict.get('upload_to_ac', True):
         try:
             success = zip_response(temp_file, temp_file_zip)
             file_path = temp_file_zip if success else temp_file
             task['responseZipped'] = success
             file_name = f"{taskId}_{uuid.uuid4().hex}.{'zip' if success else 'txt'}"
-            headers: Dict[str, str] = {
-                "Authorization": f"Bearer {config_dict['api_key']}",
-            }
+
+            headers: Dict[str, str] = {"Authorization": f"Bearer {config_dict['api_key']}"}
             task_json = json.dumps(task)
             files = {
-                # 'fileFieldName' is the name of the form field expected by the server
                 "file": (file_name, open(file_path, "rb"), f"{'application/zip' if success else 'text/plain'}"),
                 "task": (None, task_json, "application/json")
-                # If you have multiple files, you can add them here as more entries
             }
-            rate_limiter.throttle()
-            upload_result: requests.Response = requests.post(
-                f"{config_dict.get('server_url')}/api/http-teleport/upload-result",
-                headers=headers,
-                timeout=300, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy'],
-                files=files
-            )
-            logger.info("Upload result response: %s, code: %d", upload_result.text, upload_result.status_code)
-            upload_result.raise_for_status()
+
+            def _upload_result_file() -> requests.Response:
+                rate_limiter.throttle()
+                # Rewind the file stream so a retried attempt re-sends the full payload
+                files["file"][1].seek(0)
+                try:
+                    with Timeout(TELEPORT_TIMEOUT):
+                        with teleport_semaphore:
+                            response = requests.post(
+                                f"{config_dict.get('server_url')}/api/http-teleport/upload-result",
+                                headers=headers,
+                                timeout=TELEPORT_REQUEST_TIMEOUT,
+                                verify=config_dict.get('verify_cert', False),
+                                proxies=config_dict['outgoing_proxy'],
+                                files=files
+                            )
+                            return response
+                except Timeout:
+                    logger.error(f"Upload-result timed out after {TELEPORT_TIMEOUT}s, semaphore released")
+                    raise
+
+            upload_result = retry_request(_upload_result_file, max_retries=5, operation_name=f"upload_result[{taskId}]")
+
+            if upload_result:
+                logger.info("Upload result response: %s, code: %d", upload_result.text, upload_result.status_code)
+                upload_result.raise_for_status()
+
             return None
         except Exception as e:
             logger.error("Unable to upload file to armorcode: %s", e)
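
Note: the next two hunks make RateLimiter explicitly gevent-native (a gevent Semaphore for the lock, gevent.sleep in throttle). Under monkey.patch_all() the old threading.Lock and time.sleep were already cooperative, so this is a clarity change that also lets the threading import go. For context, the deque-based limiter presumably works along these lines (a sketch, not the file's actual implementation):

    import time
    from collections import deque

    class SlidingWindowLimiter:
        def __init__(self, limit: int, window_s: float):
            self.limit, self.window_s = limit, window_s
            self.stamps: deque = deque()

        def allow(self) -> bool:
            now = time.time()
            # drop timestamps that have aged out of the window
            while self.stamps and now - self.stamps[0] > self.window_s:
                self.stamps.popleft()
            if len(self.stamps) < self.limit:
                self.stamps.append(now)
                return True
            return False
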
@@ -416,7 +572,7 @@ def __init__(self, request_limit: int, time_window: int) -> None:
         self.request_limit = request_limit
         self.time_window = time_window
         self.timestamps = deque()
-        self.lock = threading.Lock()
+        self.lock = Semaphore()

     def set_limits(self, request_limit: int, time_window: int):
         self.request_limit = request_limit
@@ -438,7 +594,13 @@ def allow_request(self) -> bool:

     def throttle(self) -> None:
         while not self.allow_request():
-            time.sleep(0.5)
+            gevent.sleep(0.5)
+
+
+def _get_url_without_params(url: str) -> str:
+    """Remove query parameters from URL."""
+    parsed = urlparse(url)
+    return urlunparse((parsed.scheme, parsed.netloc, parsed.path, '', '', ''))


 def upload_s3(temp_file, preSignedUrl: str, headers: Dict[str, Any]) -> bool:
@@ -476,26 +638,44 @@ def _createFolder(folder_path: str) -> None:


 def get_s3_upload_url(taskId: str) -> Tuple[Optional[str], Optional[str]]:
+    """Get S3 upload URL with 429 retry and semaphore protection."""
     params: Dict[str, str] = {'fileName': f"{taskId}{uuid.uuid4().hex}"}
-    try:
+
+    def _request_upload_url() -> requests.Response:
         rate_limiter.throttle()
-        get_s3_url: requests.Response = requests.get(
-            f"{config_dict.get('server_url')}/api/http-teleport/upload-url",
-            params=params,
-            headers=_get_headers(),
-            timeout=25, verify=config_dict.get('verify_cert', False), proxies=config_dict['outgoing_proxy']
-        )
-        get_s3_url.raise_for_status()
-
-        data: Optional[Dict[str, str]] = get_s3_url.json().get('data', None)
-        if data is not None:
-            return data.get('putUrl'), data.get('getUrl')
-        logger.warning("No data returned when requesting S3 upload URL")
+        try:
+            with Timeout(TELEPORT_TIMEOUT):
+                with teleport_semaphore:
+                    return requests.get(
+                        f"{config_dict.get('server_url')}/api/http-teleport/upload-url",
+                        params=params,
+                        headers=_get_headers(),
+                        timeout=25,
+                        verify=config_dict.get('verify_cert', False),
+                        proxies=config_dict['outgoing_proxy']
+                    )
+        except Timeout:
+            logger.error(f"Get-s3-upload-url timed out after {TELEPORT_TIMEOUT}s, semaphore released")
+            raise
+
+    try:
+        response = retry_request(_request_upload_url, max_retries=5, operation_name="get_s3_upload_url")
+
+        if response and response.status_code == 200:
+            data: Optional[Dict[str, str]] = response.json().get('data')
+            if data:
+                return data.get('putUrl'), data.get('getUrl')
+            logger.warning("No data in S3 upload URL response")
+        else:
+            logger.warning(f"Failed to get S3 URL: {response.status_code if response else 'None'}")
+
+        return None, None
+
     except requests.exceptions.RequestException as e:
-        logger.error("Network error getting S3 upload URL: %s", e)
+        logger.error(f"Network error getting S3 upload URL: {e}")
     except Exception as e:
-        logger.exception("Unexpected error getting S3 upload URL: %s", e)
+        logger.exception(f"Unexpected error getting S3 upload URL: {e}")
+
     return None, None


@@ -561,8 +741,8 @@ def update_agent_config(global_config: dict[str, Any]) -> None:

     if global_config.get("verifyCert", False):
         config_dict['verify_cert'] = global_config.get("verifyCert", False)
-    if global_config.get("threadPoolSize", 5):
-        config_dict['thread_pool_size'] = global_config.get("poolSize", 5)
+    if global_config.get("threadPoolSize"):
+        config_dict['thread_pool_size'] = global_config.get("threadPoolSize", 25)
         config_dict['thread_pool'] = Pool(config_dict['thread_pool_size'])
     if global_config.get("uploadToAC") is not None:
         config_dict['upload_to_ac'] = global_config.get("uploadToAC", True)
@@ -595,7 +775,7 @@ def get_initial_config(parser) -> tuple[dict[str, Union[Union[bool, None, str, int], Any]], Any, Any]:
         "outgoing_proxy": None,  # Proxy for outgoing requests (e.g., to ArmorCode)
         "upload_to_ac": False,  # Whether to upload to ArmorCode
         "env_name": None,  # Environment name (Optional[str])
-        "thread_pool_size": 5  # Connection thread_pool size
+        "thread_pool_size": 25  # Connection thread_pool size
     }
     parser.add_argument("--serverUrl", required=False, help="Server Url")
     parser.add_argument("--apiKey", required=False, help="Api Key")
@@ -610,7 +790,7 @@ def get_initial_config(parser) -> tuple[dict[str, Union[Union[bool, None, str, int], Any]], Any, Any]:
     parser.add_argument("--outgoingProxyHttps", required=False, help="Pass outgoing Https proxy", default=None)
     parser.add_argument("--outgoingProxyHttp", required=False, help="Pass outgoing Http proxy", default=None)

-    parser.add_argument("--poolSize", required=False, help="Multi threading thread_pool size", default=5)
+    parser.add_argument("--poolSize", required=False, help="Multi threading thread_pool size", default=25)
     parser.add_argument("--rateLimitPerMin", required=False, help="Rate limit per min", default=250)

     parser.add_argument(
@@ -688,7 +868,7 @@ def get_initial_config(parser) -> tuple[dict[str, Union[Union[bool, None, str, int], Any]], Any, Any]:
     if config.get('server_url', None) is None:
         config['server_url'] = os.getenv('server_url')
     if config.get('api_key', None) is None:
         config['api_key'] = os.getenv("api_key")
-    config['thread_pool'] = Pool(config.get('thread_pool_size', 5))
+    config['thread_pool'] = Pool(config.get('thread_pool_size', 25))

     return config, agent_index, debug_mode
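
Note: the update_agent_config hunk above fixes a silent config bug: the old code tested threadPoolSize but read poolSize, so a server-pushed pool size was ignored and the default always won. In miniature:

    cfg = {"threadPoolSize": 40}   # what the server actually sends
    cfg.get("poolSize", 5)         # 5  -- old code read the wrong key
    cfg.get("threadPoolSize", 25)  # 40 -- fixed key picks up the value
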
diff --git a/web-agent/entrypoint.sh b/web-agent/entrypoint.sh
index 35ede4c..6d48219 100644
--- a/web-agent/entrypoint.sh
+++ b/web-agent/entrypoint.sh
@@ -1,3 +1,3 @@
 #!/bin/sh
 # Pass all arguments to the Python script and redirect output
-/usr/src/venv/bin/python3 worker.py "$@" > /tmp/armorcode/console.log 2>&1
\ No newline at end of file
+/usr/src/venv/bin/python3 -W ignore worker.py "$@" > /dev/null 2> /tmp/armorcode/console.log