Conversation

@binary-husky
Collaborator

No description provided.

@gemini-code-assist

Summary of Changes

Hello @binary-husky, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the 'Swarm training' feature, a significant architectural shift designed to enable distributed reinforcement learning for large language models. It establishes a client-server model where the heavy computational load of LLM inference and training resides on a GPU-equipped server, while agent execution and reward calculation can be performed on lighter client machines. This decoupling facilitates development on non-GPU environments and scales data generation by allowing multiple clients to contribute to the training process, ultimately making LLM agent training more accessible and flexible.
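As a rough illustration of the claim → execute → report loop described above, here is a minimal, self-contained sketch. `SwarmStub`, `claim_episode`, and `report_result` are illustrative stand-ins for this explanation, not the actual ajet API.

```python
class SwarmStub:
    """In-memory stand-in for the GPU server's episode queue."""

    def __init__(self, tasks):
        self.pending = list(tasks)
        self.results = []

    def claim_episode(self):
        # Hand out the next pending task, or None when the queue is dry.
        return self.pending.pop(0) if self.pending else None

    def report_result(self, task, reward):
        self.results.append((task, reward))


def run_client(server, rollout_fn):
    # A lightweight (possibly GPU-less) client: claim tasks until the
    # server runs dry, execute the agent locally, report the reward back.
    while (task := server.claim_episode()) is not None:
        reward = rollout_fn(task)
        server.report_result(task, reward)


server = SwarmStub(tasks=["t1", "t2", "t3"])
run_client(server, rollout_fn=lambda task: 1.0)
print(len(server.results))  # 3 episodes reported
```

In the real system the stub would be replaced by HTTP/ZeroMQ calls to the Swarm server, and `rollout_fn` by the agent workflow plus reward calculation.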

Highlights

  • Introducing Swarm Training Mode: This pull request introduces a new 'Swarm training' mode, enabling the decoupling of training logic from agent execution. This allows for distributed rollouts where agents can run on client machines (e.g., laptops without GPUs) while the core training and LLM inference occur on a remote GPU server.
  • Client-Server Architecture for Distributed Rollouts: A new client-server architecture is implemented using FastAPI and ZeroMQ. The Swarm server manages training tasks and model weights, while Swarm clients claim tasks, execute agent logic, and report results, facilitating GPU-less local development and distributed data generation.
  • Dynamic Rollout Management: The DynamicRolloutManager has been significantly refactored to support the new Swarm mode, including thread pooling, shared observation windows, and dynamic stop conditions for efficient sample collection.
  • Lazy Imports for ajet Package: The ajet/__init__.py file now uses lazy imports for several modules, which can improve startup performance by deferring the loading of modules until they are actually needed.
  • Enhanced Task Reading with Generators: The task_to_standard_dataset function and RouterTaskReader have been updated to leverage Python generators, allowing for more memory-efficient processing of potentially large or infinite streams of tasks from datasets.
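The generator-based task reading highlighted above can be sketched as follows. The function names mirror the changelog (`list_to_generator`, `task_to_standard_dataset`), but the bodies and signatures here are assumptions for illustration, not ajet's real implementation.

```python
from typing import Iterator


def list_to_generator(tasks: list) -> Iterator:
    # Trivial helper so list-backed and stream-backed readers
    # can share a single generator-consuming code path.
    yield from tasks


def task_to_standard_dataset(task_gen: Iterator[dict]) -> Iterator[dict]:
    # Consume tasks lazily: each task is normalized only when pulled,
    # so arbitrarily large (even infinite) task streams fit in memory.
    for task in task_gen:
        yield {"prompt": task["prompt"], "meta": task.get("meta", {})}


tasks = [{"prompt": "translate X"}, {"prompt": "summarize Y"}]
dataset = task_to_standard_dataset(list_to_generator(tasks))
first = next(dataset)  # nothing beyond the first task has been materialized
```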


Changelog
  • ajet/__init__.py
    • Removed direct imports for several modules and implemented lazy loading using __getattr__.
    • Added __version__ definition.
  • ajet/backbone/main_verl.py
    • Removed top-level loguru import and print_dict call for runtime environment variables.
    • Added DictConfig type hints to main and run_ppo functions.
    • Moved several imports (e.g., collate_fn, RouterTaskReader) from within run method to top-level.
    • Updated task_to_standard_dataset calls to use generator functions from task_reader.
  • ajet/backbone/main_vllm.py
    • Added enable_swarm_mode configuration parameter.
    • Modified parallel_env.rollout mode selection to dynamically switch between 'sample' and 'sample-ts' based on enable_swarm_mode.
    • Integrated http_change_engine_status call to update the interchange server status when swarm mode is enabled.
  • ajet/backbone/trainer_trinity.py
    • Updated task_to_standard_dataset calls to use generator functions (generate_training_tasks, generate_validation_tasks) for dataset creation.
  • ajet/backbone/trainer_verl.py
    • Modified union_gen_batch_via_task_id to accept a discard_original_batch parameter, enabling different batch union strategies.
    • Introduced _update_interchange_server_status_flag method to communicate engine status changes to the interchange server.
    • Adjusted validation and logging conditions to account for enable_swarm_mode.
    • Replaced critic/round_cnt metric with critic/llm_call_cnt.
    • Removed AdvantageEstimator.REMAX check, indicating it's no longer supported.
    • Updated union_gen_batch_via_task_id calls to pass discard_original_batch based on enable_swarm_mode.
  • ajet/context_tracker/base_tracker.py
    • Refactored __init__ method to group related attribute initializations.
    • Added _read_only and _discarded flags.
    • Introduced a reset method to reinitialize tracker state.
    • Renamed round_cnt to llm_call_cnt for clarity.
  • ajet/context_tracker/basic_tracker.py
    • Removed terminal_rewards_dict from the class docstring.
    • Corrected a typo from split_prompt_reponse_index to split_prompt_response_index.
    • Added an assertion to ensure reward step count matches total steps.
  • ajet/context_tracker/multiagent_tracking.py
    • Split should_interrupt_fn into should_interrupt_soft_fn and should_interrupt_hard_fn for more granular control over interruption.
    • Added a warning message when context overflow occurs.
    • Incremented llm_call_cnt within step_track.
    • Implemented a read-only mode for timelines after merging to prevent further modifications.
  • ajet/copilot/job.py
    • Updated AgentJetJob constructor to default backbone to 'verl' and added grpo_n, batch_size, and swarm_mode parameters.
    • Modified build_job_from_yaml to load ajet_ts_default.yaml when swarm_mode is enabled.
    • Moved check_avail_gpu, get_backbone_target, setup_environment_vars imports to ajet.utils.launch_utils.
  • ajet/default_config/ajet_default.py
    • Added num_repeat field to AjetRollout dataclass.
  • ajet/default_config/ajet_default.yaml
    • Added enable_swarm_mode, swarm_mode_sample_collection_method, and swarm_mode_sample_collection_max_cached_episodes configurations.
    • Updated interchange_server configuration with max_fastapi_threads and already_started.
    • Added a comment for task_reader type.
    • Removed a comment about multi-processing wrapper type.
  • ajet/default_config/ajet_ts_default.yaml
    • Added a new default configuration file specifically for Swarm (tinkerscript) mode, enabling experimental interchange server and swarm mode by default.
  • ajet/default_config/verl/verl_default.yaml
    • Removed enable_oversample from ajet.rollout configuration.
  • ajet/launcher.py
    • Moved several utility functions (check_debugpy_version, check_avail_gpu, dict_to_namespace, get_backbone_target, setup_environment_vars) to ajet.utils.launch_utils.
    • Added --swarm-server command-line argument to launch the interchange server in blocking mode.
    • Implemented start_swarm_server function to initiate the Swarm server process.
    • Updated configuration loading logic to use ajet_ts_default.yaml when --swarm-server is active without a specified config.
  • ajet/schema/task.py
    • Added descriptive comments to fields within the Task BaseModel.
  • ajet/task_reader/__init__.py
    • Added list_to_generator helper function to convert lists of tasks to generators.
    • Modified RouterTaskReader to include generate_training_tasks and generate_validation_tasks methods that return generators.
    • Updated task_to_standard_dataset to accept a generator of tasks instead of a list, improving memory efficiency.
  • ajet/task_reader/document_reader/doc_reader.py
    • Changed a logger.warning message to logger.info when the unstructured dependency is not installed.
  • ajet/task_reader/hf_dataset_reader.py
    • Modified _load_dataset_split to return a generator of tasks and support loading from local .parquet files.
    • Added as_generator flag to track if the reader is operating in generator mode.
    • Introduced generate_training_tasks and generate_validation_tasks methods that return generators, and retained get_training_tasks/get_validation_tasks to return lists by converting the generators.
  • ajet/task_rollout/async_llm_bridge.py
    • Reordered imports for better organization, with no functional changes.
  • ajet/task_rollout/native_parallel_worker.py
    • Introduced spawn_thread_shared_observation_window for managing shared state across rollout threads.
    • Added _write_swarm_rollout_dynamic_log for debugging swarm rollout progress.
    • Replaced the rollout_dynamic method with rollout_swarm, implementing a new thread-pooling strategy with dynamic stop conditions for swarm mode.
    • Implemented count_tasks, enough_sample_stop_condition, enough_finished_task_stop_condition, and enough_non_dummy_task_stop_condition for flexible termination criteria in swarm mode.
    • Added soft and hard stop mechanisms for threads, communicating with the interchange server.
  • ajet/task_rollout/single_worker.py
    • Imported SwarmRunner and SwarmReceiveAbortException.
    • Added enable_swarm_mode attribute to BaseRolloutManager.
    • Modified rollout_env_worker to conditionally use SwarmRunner when enable_swarm_mode is active.
    • Introduced rollout_env_worker_loop to enable continuous task processing in swarm mode, handling thread-specific observations and task completion.
  • ajet/task_runner/base_runner.py
    • Imported is_episode_claimed from interchange_utils.
    • Refactored runner_hooks to provide separate should_interrupt_soft_fn and should_interrupt_hard_fn for context trackers, including logic for checking episode claims in hard stop scenarios.
  • ajet/task_runner/general_runner.py
    • Updated imports to explicitly reference AjetTuner from ajet.tuner and WorkflowOutput, WorkflowTask from ajet.schema.task.
  • ajet/task_runner/swarm_runner.py
    • Added a new SwarmRunner class, inheriting from BaseAgentRunner.
    • Implemented register_episode_and_wait_output to handle communication with the Swarm server via ZeroMQ for episode registration and result reception.
    • The execute method now handles swarm-specific logic, including claiming episodes, executing the workflow, processing rewards, and managing context tracker state in a distributed environment.
  • ajet/tuner.py
    • Updated imports to explicitly reference MultiAgentContextTracker from ajet.context_tracker.multiagent_tracking.
  • ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py
    • Introduced OpenaiBaseUrlAndApiKey Pydantic model to encapsulate OpenAI API credentials and episode UUID.
    • Modified OpenaiClientBaseUrlTuner to dynamically set the interchange server port based on configuration.
  • ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py
    • Updated imports to use get_zmq_socket and is_episode_claimed from interchange_utils.
    • Replaced should_terminate property with should_soft_terminate and should_hard_terminate for more nuanced termination control.
    • Added logic to handle hard termination during the begin_service wait loop, ensuring proper resource cleanup.
  • ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py
    • Updated imports to include Manager, Process, Coroutine, Optional, Tuple from multiprocessing and typing.
    • Imported find_free_port, get_host_ip from ajet.utils.networking and EpisodeStatus from interchange_utils.
    • Modified get_app to accept enable_swarm_mode, shared_mem_dict, and shared_mem_dict_lock for shared memory management.
    • Updated _begin_handle_chat_completion to check engine status in swarm mode before accepting requests.
    • Added logic within chat_completions to update episode activity timestamps in swarm mode.
    • Integrated register_enable_swarm_mode_routes for swarm-specific API endpoints.
    • Modified InterchangeServer to support enable_swarm_mode and handle shared memory, and updated start_interchange_server to support blocking mode and graceful shutdown.
  • ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py
    • Added a new SwarmClient class for interacting with the Swarm server.
    • Implemented methods for begin_episode (claiming tasks), end_episode (submitting results), abort_episode, sync_train_config, start_engine, get_engine_status, can_continue_episode, get_episode_buffer, and auto_sync_train_config_and_start_engine.
  • ajet/tuner_lib/weight_tuner/experimental/as_swarm_server.py
    • Added a new as_swarm_server module, implementing the core logic for the Swarm server.
    • Includes shared memory management (DictProxy), ZeroMQ communication, and FastAPI endpoints for client interaction (e.g., sync_train_config, start_engine, claim_episode, end_episode, abort_episode, update_engine_status).
    • Introduced an episode state machine (EpisodeStatus) to manage the lifecycle of tasks and handle client timeouts.
    • Implemented kill_process_tree for robust termination of training processes.
  • ajet/tuner_lib/weight_tuner/experimental/interchange_utils.py
    • Added a new interchange_utils module, defining Pydantic models for various Swarm communication requests and responses (e.g., SyncTrainConfigRequest, ClaimEpisodeRequest, EpisodeStatus).
    • Provided utility functions like get_interchange_server_url, http_change_engine_status, is_episode_claimed, http_register_episode, and get_zmq_socket.
  • ajet/utils/config_utils.py
    • Added override_param_callback argument to read_ajet_hierarchical_config and prepare_experiment_config to allow dynamic modification of configurations.
  • ajet/utils/core_env_vars.py
    • Added copy import.
    • Refined the logic for setting AJET_DAT_INTERCHANGE_PORT and MASTER_NODE_IP environment variables.
    • Implemented a print formatting utility for runtime_env to limit the length of printed values.
  • ajet/utils/launch_utils.py
    • Moved get_backbone_target, setup_environment_vars, check_debugpy_version, check_avail_gpu, and dict_to_namespace from ajet/launcher.py to centralize utility functions.
  • ajet/utils/retry.py
    • Introduced SwarmReceiveAbortException to handle specific abortion signals in swarm mode.
    • Modified the retry_with_backoff decorator to catch SwarmReceiveAbortException and return None silently, allowing for graceful termination of retries in swarm contexts.
  • ajet/utils/thread_executors.py
    • Switched from the fully qualified concurrent.futures.ThreadPoolExecutor to a directly imported ThreadPoolExecutor for consistency.
    • Added BoundedThreadPoolExecutor to manage a thread pool with a bounded queue size, preventing excessive task queuing.
  • ajet/workflow.py
    • Updated imports to explicitly reference AjetTuner from ajet.tuner.
  • ajet_swarm_threading.py
    • Added a new example client script demonstrating how to use SwarmClient to interact with a remote Swarm server for distributed training rollouts.
  • ajet_tinkerscript.md
    • Added a new markdown file providing an example command for launching the Swarm server.
  • docs/en/platform_comparison.md
    • Updated the 'Without-GPU (Cloud-Computing)' feature description to mention 'AgentJet-Swarm' as an upcoming capability.
  • docs/en/workflow.md
    • Renamed the 'TinkerJet' section to 'Swarm' to reflect the new feature.
  • docs/index.md
    • Updated the 'Any Training Engine' description to mention 'Swarm backbone support will be released soon'.
  • tests/bench/benchmark_learn2ask/benchmark_learn2ask.py
    • Adjusted the benchmark value for step 200 from 2.9 to 2.6.
  • tinkerscript.md
    • Added a new markdown file detailing the 'Swarm Design Blueprint', including its overview, core architecture, detailed workflow, and episode state machine.
  • tinkerscript_1.md
    • Added a new markdown file providing a detailed explanation of the 'Swarm Design Blueprint', covering system architecture, detailed workflow, data structures, API design, and key configurations.
  • tutorial/demo_tinkerjet/README.md
    • Added a new README file for the demo_tinkerjet tutorial, explaining the Swarm concept, its benefits, and core training code examples.
  • tutorial/demo_tinkerjet/ajet_tinkerscript_default.yaml
    • Added a new default YAML configuration file for the tinkerjet tutorial, enabling experimental interchange server and swarm mode.
  • tutorial/example_academic_trans/trans.py
    • Added a new Python script implementing an academic translation agent, including rough translation, proper noun detection, and final translation production using OpenAI API calls.
  • tutorial/example_academic_trans/trans_reward.py
    • Added a new Python script defining TranslationQualityGrader for evaluating academic translation quality based on specific error patterns, using an LLM-based grader.
  • tutorial/example_academic_trans/trans_roll.py
    • Added a new Python script demonstrating how to perform distributed rollouts for academic translation tasks using SwarmClient and a BoundedThreadPoolExecutor.
  • tutorial/example_deep_finance/deep_finance.yaml
    • Removed enable_oversample: False from the rollout configuration.
  • tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml
    • Removed enable_oversample: False from the rollout configuration template.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the "Swarm training" feature, a significant architectural change that decouples the agent execution and data collection from the main training loop. This enables distributed data generation, potentially on machines without GPUs, by using a client-server model with FastAPI and ZMQ. The changes are extensive, affecting configuration, the launcher, task readers, and the core training logic. While the overall direction is promising, I've identified a couple of critical issues, including a leftover debug statement that will crash the application and a bug in path handling. I've also included some medium-severity suggestions to improve the robustness of the new networking components.

```diff
  """
  assert yaml_path.endswith(".yaml"), "Configuration file must be a YAML file"
- exp_base = os.path.dirname(yaml_path)
+ exp_base = os.path.exists(os.path.dirname(yaml_path))
```


critical

This line incorrectly assigns a boolean value from os.path.exists() to exp_base, which is expected to be a path string. This will cause the subsequent check os.path.exists(exp_base) to fail. The intention was likely to get the directory name.

Suggested change

```diff
- exp_base = os.path.exists(os.path.dirname(yaml_path))
+ exp_base = os.path.dirname(yaml_path)
```
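To make the failure mode concrete: `os.path.exists` returns a bool, so the buggy line silently replaces the directory path with `True`/`False`, and any subsequent check against `exp_base` no longer tests a directory. A tiny demonstration (the path below is illustrative, not from the PR):

```python
import os

yaml_path = "/tmp/exp/config.yaml"  # illustrative path

buggy = os.path.exists(os.path.dirname(yaml_path))  # a bool, not a path
fixed = os.path.dirname(yaml_path)                  # the directory string

assert isinstance(buggy, bool)
assert fixed == "/tmp/exp"
```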

```python
self._write_swarm_rollout_dynamic_log(observation_window)

time.sleep(10)
raise RuntimeError("DEBUG")
```


critical

This raise RuntimeError("DEBUG") appears to be a leftover debugging statement. It will cause the swarm rollout process to crash and must be removed for the feature to be usable.

```diff
  self.socket = context.socket(zmq.REP)
  self.socket.bind(f"{self.episode_contect_address}")
- self.socket.setsockopt(zmq.RCVTIMEO, 3*1000) # 3 second timeout for REP
+ self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 3 second timeout for REP
```


medium

The ZMQ receive timeout has been reduced to 1 second. This might be too aggressive and could lead to premature timeouts, especially in environments with high network latency or server load. Consider making this value configurable or reverting to a more conservative timeout to improve robustness.

Suggested change

```diff
- self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 3 second timeout for REP
+ self.socket.setsockopt(zmq.RCVTIMEO, 3*1000) # 3 second timeout for REP
```
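If the timeout is made configurable as suggested, one possible shape is a small resolver with config, environment, and default tiers. The config key `zmq_rcvtimeo_ms` and environment variable `AJET_ZMQ_RCVTIMEO_MS` below are hypothetical names for illustration, not existing ajet options.

```python
import os

DEFAULT_RCVTIMEO_MS = 3 * 1000  # the review's more conservative 3 s default


def resolve_rcvtimeo_ms(config=None):
    """Resolve the ZMQ receive timeout in milliseconds: explicit config
    wins, then an environment override, then the conservative default."""
    if config and config.get("zmq_rcvtimeo_ms") is not None:
        return int(config["zmq_rcvtimeo_ms"])
    return int(os.environ.get("AJET_ZMQ_RCVTIMEO_MS", DEFAULT_RCVTIMEO_MS))


# At socket setup time (pyzmq shown for context only):
# socket.setsockopt(zmq.RCVTIMEO, resolve_rcvtimeo_ms(cfg))
```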

```diff
  host="0.0.0.0",
  port=self.port,
- log_level="error",
+ log_level="info",
```


medium

The uvicorn log level is now hardcoded to "info". While this is helpful for debugging, it can generate a large volume of logs in a production or large-scale training environment. It would be better to make this configurable or set it to a less verbose level like "warning" or "error" by default.

Suggested change

```diff
- log_level="info",
+ log_level="error",
```
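One way to make the log level configurable rather than hardcoded is the same config/env/default tiering; the config key `uvicorn_log_level` and environment variable `AJET_UVICORN_LOG_LEVEL` below are hypothetical names used only for this sketch.

```python
import os


def resolve_uvicorn_log_level(config=None):
    # Explicit config wins, then an environment override, then a quiet
    # default, per the review's concern about log floods at scale.
    if config and config.get("uvicorn_log_level"):
        return config["uvicorn_log_level"]
    return os.environ.get("AJET_UVICORN_LOG_LEVEL", "error")


# uvicorn.run(app, host="0.0.0.0", port=port,
#             log_level=resolve_uvicorn_log_level(cfg))
```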
