diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 599bab4d..7b1c5f7f 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -305,3 +305,38 @@ class TrainingArgs(BaseModel): log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field( default="INFO" ) + + logger_type: str = Field( + default="async", + description="Comma-separated list of loggers to use: tensorboard, wandb, async, mlflow", + ) + + run_name: str | None = Field( + default=None, + description="Run name for logging. Supports placeholders: {time}, {rank}, {utc_time}, {local_rank}", + ) + + mlflow_tracking_uri: str | None = Field( + default=None, + description="MLflow tracking server URI (e.g., 'http://localhost:5000'). Falls back to MLFLOW_TRACKING_URI env var.", + ) + + mlflow_experiment_name: str | None = Field( + default=None, + description="MLflow experiment name. Falls back to MLFLOW_EXPERIMENT_NAME env var.", + ) + + wandb_project: str | None = Field( + default=None, + description="Weights & Biases project name.", + ) + + wandb_entity: str | None = Field( + default=None, + description="Weights & Biases team/entity name.", + ) + + tensorboard_log_dir: str | None = Field( + default=None, + description="Directory for TensorBoard logs. Defaults to ckpt_output_dir if not specified.", + ) diff --git a/src/instructlab/training/logger.py b/src/instructlab/training/logger.py index d92b8975..f09297b9 100644 --- a/src/instructlab/training/logger.py +++ b/src/instructlab/training/logger.py @@ -2,7 +2,7 @@ This module provides a logging system for training machine learning models, supporting multiple logging backends including TensorBoard (tensorboard), Weights & Biases (wandb), -and structured JSONL logging (async). +MLflow (mlflow), and structured JSONL logging (async). Example Usage: ```python @@ -73,6 +73,12 @@ except ImportError: wandb = None # type: ignore +try: + # Third Party + import mlflow +except ImportError: + mlflow = None # type: ignore + # Third Party from rich.logging import RichHandler import torch @@ -581,6 +587,151 @@ def emit(self, record: logging.LogRecord): self._wandb_run.log(flat_dict, step=step) +class MLflowHandler(logging.Handler): + """Logger that sends metrics to MLflow. + + This handler expects a (nested) dictionary of metrics to be logged with string keys. + A step can be specified by passing `extra={"step": }` to the logging method. + To log hyperparameters, pass a (nested) mapping of hyperparameters to the logging method + and set `extra={"hparams": True}`. + + Example: + ```python + import logging + from instructlab.training.logger import MLflowHandler + + # Create handler + handler = MLflowHandler( + level=logging.INFO, + run_name="experiment_{time}", + tracking_uri="http://localhost:5000", + experiment_name="my_experiment" + ) + + # Create logger + logger = logging.getLogger("metrics") + logger.addHandler(handler) + logger.setLevel(logging.INFO) + + # Log metrics + logger.info( + { + "training": { + "loss": 0.5, + "accuracy": 0.95 + } + }, + extra={"step": 100} + ) + + # Log hyperparameters + logger.info( + { + "learning_rate": 0.001, + "batch_size": 32 + }, + extra={"hparams": True} + ) + ``` + """ + + def __init__( + self, + level: int = logging.INFO, + run_name: str | None = None, + tracking_uri: str | None = None, + experiment_name: str | None = None, + **mlflow_init_kwargs: Any, + ): + """Initialize the MLflow logger and check for required dependencies. + + Args: + level: The logging level for this handler + run_name: Name of the run, can contain placeholders + tracking_uri: MLflow tracking server URI (e.g., "http://localhost:5000") + experiment_name: Name of the MLflow experiment + **mlflow_init_kwargs: Additional keyword arguments passed to mlflow.start_run() + """ + super().__init__(level) + + self.run_name = _substitute_placeholders(run_name) + self.tracking_uri = tracking_uri + self.experiment_name = experiment_name + self.mlflow_init_kwargs = mlflow_init_kwargs.copy() + + self._mlflow_run = None + + def _setup(self): + """Initialize the MLflow run with the configured settings.""" + if mlflow is None: + msg = ( + "Could not initialize MLflowHandler because package mlflow could not be imported.\n" + "Please ensure it is installed by running 'pip install mlflow'" + ) + raise RuntimeError(msg) + + if self.tracking_uri: + mlflow.set_tracking_uri(self.tracking_uri) + + if self.experiment_name: + mlflow.set_experiment(self.experiment_name) + + self._mlflow_run = mlflow.start_run( + run_name=self.run_name, **self.mlflow_init_kwargs + ) + + def emit(self, record: logging.LogRecord): + """Emit a log record to MLflow. + + Args: + record: The log record to emit + """ + if self._mlflow_run is None: + self._setup() + + if not isinstance(record.msg, Mapping): + warnings.warn( + f"MLflowHandler expected a mapping, got {type(record.msg)}. Skipping log. " + "Please ensure the handler is configured correctly to filter out non-mapping objects." + ) + return + + flat_dict = _flatten_dict(record.msg, sep=".") + step = getattr(record, "step", None) + + if getattr(record, "hparams", None): + # Log as parameters - MLflow params must be strings + params_dict = {k: str(v) for k, v in flat_dict.items()} + mlflow.log_params(params_dict) + return + + # Filter to only numeric values for metrics + metrics_dict = {} + skipped_keys = [] + for k, v in flat_dict.items(): + try: + metrics_dict[k] = float(v) + except (ValueError, TypeError): + # Skip non-numeric values for metrics + skipped_keys.append(k) + + if skipped_keys: + logging.debug( + f"MLflowHandler skipped non-numeric metrics: {skipped_keys}. " + "Only numeric values can be logged as MLflow metrics." + ) + + if metrics_dict: + mlflow.log_metrics(metrics_dict, step=step) + + def close(self): + """End the MLflow run and cleanup resources.""" + if self._mlflow_run is not None: + mlflow.end_run() + self._mlflow_run = None + super().close() + + class AsyncStructuredHandler(logging.Handler): """Logger that asynchronously writes data to a JSONL file. @@ -708,7 +859,17 @@ def setup_root_logger(level="DEBUG"): ) -def setup_metric_logger(loggers, run_name, output_dir): +def setup_metric_logger( + loggers, + run_name, + output_dir, + *, + mlflow_tracking_uri: str | None = None, + mlflow_experiment_name: str | None = None, + wandb_project: str | None = None, + wandb_entity: str | None = None, + tensorboard_log_dir: str | None = None, +): """Configure the metric logging system with specified backends. This function sets up a comprehensive logging configuration that supports @@ -717,10 +878,17 @@ def setup_metric_logger(loggers, run_name, output_dir): Args: loggers: A string or list of strings specifying which logging backends to use. - Supported values: "tensorboard", "wandb", "async" + Supported values: "tensorboard", "wandb", "mlflow", "async" run_name: Name for the current training run. Can include placeholders like {time}, {rank}, {utc_time}, {local_rank}. output_dir: Directory where log files will be stored + mlflow_tracking_uri: MLflow tracking server URI (e.g., "http://localhost:5000"). + Falls back to MLFLOW_TRACKING_URI environment variable if not provided. + mlflow_experiment_name: MLflow experiment name. + Falls back to MLFLOW_EXPERIMENT_NAME environment variable if not provided. + wandb_project: Weights & Biases project name. + wandb_entity: Weights & Biases team/entity name. + tensorboard_log_dir: Directory for TensorBoard logs. Defaults to output_dir if not provided. Example: ```python @@ -731,11 +899,13 @@ def setup_metric_logger(loggers, run_name, output_dir): output_dir="logs" ) - # Setup logging with a single backend + # Setup logging with MLflow setup_metric_logger( - loggers="tensorboard", + loggers=["mlflow"], run_name="my_run", - output_dir="logs" + output_dir="logs", + mlflow_tracking_uri="http://localhost:5000", + mlflow_experiment_name="my_experiment" ) ``` """ @@ -773,7 +943,7 @@ def setup_metric_logger(loggers, run_name, output_dir): }, "tensorboard": { "()": TensorBoardHandler, - "log_dir": output_dir, + "log_dir": tensorboard_log_dir or output_dir, "run_name": run_name, "filters": ["is_mapping", "is_rank0"], }, @@ -781,6 +951,17 @@ def setup_metric_logger(loggers, run_name, output_dir): "()": WandbHandler, "log_dir": output_dir, "run_name": run_name, + "project": wandb_project, + "entity": wandb_entity, + "filters": ["is_mapping", "is_rank0"], + }, + "mlflow": { + "()": MLflowHandler, + "run_name": run_name, + "tracking_uri": mlflow_tracking_uri + or os.environ.get("MLFLOW_TRACKING_URI"), + "experiment_name": mlflow_experiment_name + or os.environ.get("MLFLOW_EXPERIMENT_NAME"), "filters": ["is_mapping", "is_rank0"], }, }, diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index ff5cea7d..ebb96734 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -272,7 +272,16 @@ def main(args): "DeepSpeed was selected and CPU offloading was requested, but DeepSpeedCPUAdam could not be imported. This likely means you need to build DeepSpeed with the CPU adam flags." ) - setup_metric_logger(args.logger_type, args.run_name, args.output_dir) + setup_metric_logger( + args.logger_type, + args.run_name, + args.output_dir, + mlflow_tracking_uri=args.mlflow_tracking_uri, + mlflow_experiment_name=args.mlflow_experiment_name, + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, + tensorboard_log_dir=args.tensorboard_log_dir, + ) metric_logger = logging.getLogger("instructlab.training.metrics") if os.environ["LOCAL_RANK"] == "0": metric_logger.info(vars(args), extra={"hparams": True}) @@ -457,7 +466,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: # Enable package logging propagation before setting up loggers propagate_package_logs(True) setup_root_logger(train_args.log_level) - setup_metric_logger("async", None, train_args.ckpt_output_dir) + setup_metric_logger( + train_args.logger_type, + train_args.run_name, + train_args.ckpt_output_dir, + mlflow_tracking_uri=train_args.mlflow_tracking_uri, + mlflow_experiment_name=train_args.mlflow_experiment_name, + wandb_project=train_args.wandb_project, + wandb_entity=train_args.wandb_entity, + tensorboard_log_dir=train_args.tensorboard_log_dir, + ) logger = logging.getLogger("instructlab.training") logger.info("Starting training setup...") @@ -545,9 +563,24 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: f"--adamw_beta1={train_args.adamw_betas[0]}", f"--adamw_beta2={train_args.adamw_betas[1]}", f"--adamw_eps={train_args.adamw_eps}", + f"--logger_type={train_args.logger_type}", ] ) + # Add optional logging parameters + if train_args.run_name is not None: + command.append(f"--run_name={train_args.run_name}") + if train_args.mlflow_tracking_uri is not None: + command.append(f"--mlflow_tracking_uri={train_args.mlflow_tracking_uri}") + if train_args.mlflow_experiment_name is not None: + command.append(f"--mlflow_experiment_name={train_args.mlflow_experiment_name}") + if train_args.wandb_project is not None: + command.append(f"--wandb_project={train_args.wandb_project}") + if train_args.wandb_entity is not None: + command.append(f"--wandb_entity={train_args.wandb_entity}") + if train_args.tensorboard_log_dir is not None: + command.append(f"--tensorboard_log_dir={train_args.tensorboard_log_dir}") + if train_args.pretraining_config is not None: command.append(f"--block-size={train_args.pretraining_config.block_size}") command.append( @@ -766,6 +799,36 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: parser.add_argument("--log_level", type=str, default="INFO") parser.add_argument("--run_name", type=str, default=None) parser.add_argument("--logger_type", type=str, default="async") + parser.add_argument( + "--mlflow_tracking_uri", + type=str, + default=None, + help="MLflow tracking server URI (e.g., 'http://localhost:5000')", + ) + parser.add_argument( + "--mlflow_experiment_name", + type=str, + default=None, + help="MLflow experiment name", + ) + parser.add_argument( + "--wandb_project", + type=str, + default=None, + help="Weights & Biases project name", + ) + parser.add_argument( + "--wandb_entity", + type=str, + default=None, + help="Weights & Biases team/entity name", + ) + parser.add_argument( + "--tensorboard_log_dir", + type=str, + default=None, + help="Directory for TensorBoard logs. Defaults to output_dir if not specified.", + ) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--mock_data", action="store_true") parser.add_argument("--mock_len", type=int, default=2600)