Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 34 additions & 10 deletions web-agent/app/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def main() -> None:
# Instantiate RateLimiter for 25 requests per 15 seconds window
rate_limiter = RateLimiter(request_limit=25, time_window=15)
parser = argparse.ArgumentParser()
config_dict, agent_index, debug_mode = get_initial_config(parser)
config_dict, agent_index, debug_mode, enable_stdout_logging = get_initial_config(parser)

logger = setup_logger(agent_index, debug_mode)
logger = setup_logger(agent_index, debug_mode, enable_stdout_logging)

# Initialize metrics logger
metrics_folder = os.path.join(log_folder, 'metrics')
Expand Down Expand Up @@ -791,20 +791,20 @@ def get_s3_upload_url(taskId: str) -> Tuple[Optional[str], Optional[str]]:


# Function to set up logging with timed rotation and log retention
def setup_logger(index: str, debug_mode: bool) -> logging.Logger:
def setup_logger(index: str, debug_mode: bool, enable_stdout: bool = False) -> logging.Logger:
log_filename: str = os.path.join(log_folder, f"app_log{index}.log")

# Create a TimedRotatingFileHandler
handler: TimedRotatingFileHandler = TimedRotatingFileHandler(
# Create a TimedRotatingFileHandler for file logging
file_handler: TimedRotatingFileHandler = TimedRotatingFileHandler(
log_filename, when="midnight", interval=1, backupCount=7
) # This will keep logs for the last 7 days

# Set the log format
# Set the log format (shared by both handlers)
formatter: logging.Formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

# Create the logger instance
logger: logging.Logger = logging.getLogger(__name__)
Expand All @@ -813,7 +813,18 @@ def setup_logger(index: str, debug_mode: bool) -> logging.Logger:
else:
logger.setLevel(logging.INFO) # Set the log level (DEBUG, INFO, etc.)

logger.addHandler(handler)
# Clear any existing handlers to prevent duplicates
logger.handlers.clear()

# Add file handler
logger.addHandler(file_handler)

# Conditionally add console handler for stdout logging
if enable_stdout:
console_handler: logging.StreamHandler = logging.StreamHandler()
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

logger.info("Log folder is created %s", log_folder)
return logger

Expand Down Expand Up @@ -875,7 +886,7 @@ def generate_unique_id():



def get_initial_config(parser) -> tuple[dict[str, Union[Union[bool, None, str, int], Any]], str, bool]:
def get_initial_config(parser) -> tuple[dict[str, Union[Union[bool, None, str, int], Any]], str, bool, bool]:
global rate_limiter
config = {
"api_key": None, # Optional[str]
Expand Down Expand Up @@ -914,6 +925,15 @@ def get_initial_config(parser) -> tuple[dict[str, Union[Union[bool, None, str, i
help="Upload to Armorcode instead of S3 (default: True)"
)

parser.add_argument(
"--enableStdoutLogging",
nargs='?',
type=str2bool,
const=True,
default=False,
help="Enable logging to stdout/console in addition to file (default: False)"
)

args = parser.parse_args()
config['agent_id'] = generate_unique_id()
config['server_url'] = args.serverUrl
Expand All @@ -927,6 +947,7 @@ def get_initial_config(parser) -> tuple[dict[str, Union[Union[bool, None, str, i
config['metrics_retention_days'] = args.metricsRetentionDays

config['upload_to_ac'] = args.uploadToAc
enable_stdout_logging_cmd = args.enableStdoutLogging

rate_limiter.set_limits(rate_limit_per_min, 60)
inward_proxy_https = args.inwardProxyHttps
Expand Down Expand Up @@ -961,6 +982,9 @@ def get_initial_config(parser) -> tuple[dict[str, Union[Union[bool, None, str, i
if str(debug_cmd).lower() == "true":
debug_mode = True

# Enable stdout logging flag (already converted to bool by str2bool)
enable_stdout_logging = enable_stdout_logging_cmd if isinstance(enable_stdout_logging_cmd, bool) else False

if verify_cmd is not None:
if str(verify_cmd).lower() == "false":
config['verify_cert'] = False
Expand All @@ -985,7 +1009,7 @@ def get_initial_config(parser) -> tuple[dict[str, Union[Union[bool, None, str, i
if config.get('api_key', None) is None:
config['api_key'] = os.getenv("api_key")
config['thread_pool'] = Pool(config.get('thread_pool_size', 5))
return config, agent_index, debug_mode
return config, agent_index, debug_mode, enable_stdout_logging


if __name__ == "__main__":
Expand Down
Loading