35 changes: 35 additions & 0 deletions src/instructlab/training/config.py
@@ -305,3 +305,38 @@ class TrainingArgs(BaseModel):
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
default="INFO"
)

logger_type: str = Field(
default="async",
description="Comma-separated list of loggers to use: tensorboard, wandb, async, mlflow",
)

run_name: str | None = Field(
default=None,
description="Run name for logging. Supports placeholders: {time}, {rank}, {utc_time}, {local_rank}",
)

mlflow_tracking_uri: str | None = Field(
default=None,
description="MLflow tracking server URI (e.g., 'http://localhost:5000'). Falls back to MLFLOW_TRACKING_URI env var.",
)

mlflow_experiment_name: str | None = Field(
default=None,
description="MLflow experiment name. Falls back to MLFLOW_EXPERIMENT_NAME env var.",
)

wandb_project: str | None = Field(
default=None,
description="Weights & Biases project name.",
)

wandb_entity: str | None = Field(
default=None,
description="Weights & Biases team/entity name.",
)

tensorboard_log_dir: str | None = Field(
default=None,
description="Directory for TensorBoard logs. Defaults to ckpt_output_dir if not specified.",
)
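
For context, a minimal sketch of how these new fields are meant to combine at setup time. The comma-splitting of `logger_type` and the env-var fallback follow the field descriptions above; this is illustrative code, not part of the diff:

```python
import os

# Assumed parsing of the comma-separated logger_type field.
logger_type = "mlflow,tensorboard,async"
backends = [name.strip() for name in logger_type.split(",") if name.strip()]

# Env-var fallback described for mlflow_tracking_uri (mirrors the
# `value or os.environ.get(...)` pattern used in setup_metric_logger below).
mlflow_tracking_uri = None  # e.g. the unset config value
tracking_uri = mlflow_tracking_uri or os.environ.get("MLFLOW_TRACKING_URI")

print(backends)      # ['mlflow', 'tensorboard', 'async']
print(tracking_uri)  # None unless MLFLOW_TRACKING_URI is set
```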
195 changes: 188 additions & 7 deletions src/instructlab/training/logger.py
@@ -2,7 +2,7 @@

This module provides a logging system for training machine learning models,
supporting multiple logging backends including TensorBoard (tensorboard), Weights & Biases (wandb),
and structured JSONL logging (async).
MLflow (mlflow), and structured JSONL logging (async).

Example Usage:
```python
@@ -73,6 +73,12 @@
except ImportError:
wandb = None # type: ignore

try:
# Third Party
import mlflow
except ImportError:
mlflow = None # type: ignore

# Third Party
from rich.logging import RichHandler
import torch
@@ -581,6 +587,151 @@ def emit(self, record: logging.LogRecord):
self._wandb_run.log(flat_dict, step=step)


class MLflowHandler(logging.Handler):
"""Logger that sends metrics to MLflow.

This handler expects a (nested) dictionary of metrics with string keys to be logged.
A step can be specified by passing `extra={"step": <step>}` to the logging method.
To log hyperparameters, pass a (nested) mapping of hyperparameters to the logging method
and set `extra={"hparams": True}`.

Example:
```python
import logging
from instructlab.training.logger import MLflowHandler

# Create handler
handler = MLflowHandler(
level=logging.INFO,
run_name="experiment_{time}",
tracking_uri="http://localhost:5000",
experiment_name="my_experiment"
)

# Create logger
logger = logging.getLogger("metrics")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Log metrics
logger.info(
{
"training": {
"loss": 0.5,
"accuracy": 0.95
}
},
extra={"step": 100}
)

# Log hyperparameters
logger.info(
{
"learning_rate": 0.001,
"batch_size": 32
},
extra={"hparams": True}
)
```
"""

def __init__(
self,
level: int = logging.INFO,
run_name: str | None = None,
tracking_uri: str | None = None,
experiment_name: str | None = None,
**mlflow_init_kwargs: Any,
):
"""Initialize the MLflow logger and check for required dependencies.

Args:
level: The logging level for this handler
run_name: Name of the run, can contain placeholders
tracking_uri: MLflow tracking server URI (e.g., "http://localhost:5000")
experiment_name: Name of the MLflow experiment
**mlflow_init_kwargs: Additional keyword arguments passed to mlflow.start_run()
"""
super().__init__(level)

self.run_name = _substitute_placeholders(run_name)
self.tracking_uri = tracking_uri
self.experiment_name = experiment_name
self.mlflow_init_kwargs = mlflow_init_kwargs.copy()

self._mlflow_run = None

def _setup(self):
"""Initialize the MLflow run with the configured settings."""
if mlflow is None:
msg = (
"Could not initialize MLflowHandler because package mlflow could not be imported.\n"
"Please ensure it is installed by running 'pip install mlflow'"
)
raise RuntimeError(msg)

if self.tracking_uri:
mlflow.set_tracking_uri(self.tracking_uri)

if self.experiment_name:
mlflow.set_experiment(self.experiment_name)

self._mlflow_run = mlflow.start_run(
run_name=self.run_name, **self.mlflow_init_kwargs
)

def emit(self, record: logging.LogRecord):
"""Emit a log record to MLflow.

Args:
record: The log record to emit
"""
if self._mlflow_run is None:
self._setup()

if not isinstance(record.msg, Mapping):
warnings.warn(
f"MLflowHandler expected a mapping, got {type(record.msg)}. Skipping log. "
"Please ensure the handler is configured correctly to filter out non-mapping objects."
)
return

flat_dict = _flatten_dict(record.msg, sep=".")
step = getattr(record, "step", None)

if getattr(record, "hparams", None):
# Log as parameters - MLflow params must be strings
params_dict = {k: str(v) for k, v in flat_dict.items()}
mlflow.log_params(params_dict)
return

# Filter to only numeric values for metrics
metrics_dict = {}
skipped_keys = []
for k, v in flat_dict.items():
try:
metrics_dict[k] = float(v)
except (ValueError, TypeError):
# Skip non-numeric values for metrics
skipped_keys.append(k)

if skipped_keys:
logging.debug(
f"MLflowHandler skipped non-numeric metrics: {skipped_keys}. "
"Only numeric values can be logged as MLflow metrics."
)

if metrics_dict:
mlflow.log_metrics(metrics_dict, step=step)

def close(self):
"""End the MLflow run and cleanup resources."""
if self._mlflow_run is not None:
mlflow.end_run()
self._mlflow_run = None
super().close()
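
A quick illustration of the numeric filtering in `emit` above: nested keys are flattened with `.` separators, and any leaf that cannot be converted to a float is dropped from the metrics payload. This is a self-contained sketch of that branch (the real flattening is done by the module's `_flatten_dict` helper; no MLflow calls are made here):

```python
# Sketch of MLflowHandler.emit's metric filtering.
flat = {"training.loss": 0.42, "training.phase": "warmup"}  # post-_flatten_dict
metrics, skipped = {}, []
for k, v in flat.items():
    try:
        metrics[k] = float(v)  # bools and numeric strings also pass
    except (ValueError, TypeError):
        skipped.append(k)  # e.g. "training.phase", reported via logging.debug
print(metrics)  # {'training.loss': 0.42}
print(skipped)  # ['training.phase']
```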


class AsyncStructuredHandler(logging.Handler):
"""Logger that asynchronously writes data to a JSONL file.

@@ -708,7 +859,17 @@ def setup_root_logger(level="DEBUG"):
)


def setup_metric_logger(loggers, run_name, output_dir):
def setup_metric_logger(
loggers,
run_name,
output_dir,
*,
mlflow_tracking_uri: str | None = None,
mlflow_experiment_name: str | None = None,
wandb_project: str | None = None,
wandb_entity: str | None = None,
tensorboard_log_dir: str | None = None,
):
"""Configure the metric logging system with specified backends.

This function sets up a comprehensive logging configuration that supports
Expand All @@ -717,10 +878,17 @@ def setup_metric_logger(loggers, run_name, output_dir):

Args:
loggers: A string or list of strings specifying which logging backends to use.
Supported values: "tensorboard", "wandb", "async"
Supported values: "tensorboard", "wandb", "mlflow", "async"
run_name: Name for the current training run. Can include placeholders like
{time}, {rank}, {utc_time}, {local_rank}.
output_dir: Directory where log files will be stored
mlflow_tracking_uri: MLflow tracking server URI (e.g., "http://localhost:5000").
Falls back to MLFLOW_TRACKING_URI environment variable if not provided.
mlflow_experiment_name: MLflow experiment name.
Falls back to MLFLOW_EXPERIMENT_NAME environment variable if not provided.
wandb_project: Weights & Biases project name.
wandb_entity: Weights & Biases team/entity name.
tensorboard_log_dir: Directory for TensorBoard logs. Defaults to output_dir if not provided.

Example:
```python
@@ -731,11 +899,13 @@ def setup_metric_logger(loggers, run_name, output_dir):
output_dir="logs"
)

# Setup logging with a single backend
# Setup logging with MLflow
setup_metric_logger(
loggers="tensorboard",
loggers=["mlflow"],
run_name="my_run",
output_dir="logs"
output_dir="logs",
mlflow_tracking_uri="http://localhost:5000",
mlflow_experiment_name="my_experiment"
)
```
"""
@@ -773,14 +943,25 @@ def setup_metric_logger(loggers, run_name, output_dir):
},
"tensorboard": {
"()": TensorBoardHandler,
"log_dir": output_dir,
"log_dir": tensorboard_log_dir or output_dir,
"run_name": run_name,
"filters": ["is_mapping", "is_rank0"],
},
"wandb": {
"()": WandbHandler,
"log_dir": output_dir,
"run_name": run_name,
"project": wandb_project,
"entity": wandb_entity,
"filters": ["is_mapping", "is_rank0"],
},
"mlflow": {
"()": MLflowHandler,
"run_name": run_name,
"tracking_uri": mlflow_tracking_uri
or os.environ.get("MLFLOW_TRACKING_URI"),
"experiment_name": mlflow_experiment_name
or os.environ.get("MLFLOW_EXPERIMENT_NAME"),
"filters": ["is_mapping", "is_rank0"],
},
},
67 changes: 65 additions & 2 deletions src/instructlab/training/main_ds.py
@@ -272,7 +272,16 @@ def main(args):
"DeepSpeed was selected and CPU offloading was requested, but DeepSpeedCPUAdam could not be imported. This likely means you need to build DeepSpeed with the CPU adam flags."
)

setup_metric_logger(args.logger_type, args.run_name, args.output_dir)
setup_metric_logger(
args.logger_type,
args.run_name,
args.output_dir,
mlflow_tracking_uri=args.mlflow_tracking_uri,
mlflow_experiment_name=args.mlflow_experiment_name,
wandb_project=args.wandb_project,
wandb_entity=args.wandb_entity,
tensorboard_log_dir=args.tensorboard_log_dir,
)
metric_logger = logging.getLogger("instructlab.training.metrics")
if os.environ["LOCAL_RANK"] == "0":
metric_logger.info(vars(args), extra={"hparams": True})
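
For reference, the `extra={"hparams": True}` above routes the record into each handler's hparams branch; in `MLflowHandler.emit` this becomes `mlflow.log_params` with stringified values. A sketch of just that transformation (the MLflow call itself is omitted):

```python
# What MLflowHandler.emit does with an hparams record.
flat = {"learning_rate": 0.001, "batch_size": 32}  # flattened vars(args) subset
params = {k: str(v) for k, v in flat.items()}  # MLflow params must be strings
print(params)  # {'learning_rate': '0.001', 'batch_size': '32'}
```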
@@ -457,7 +466,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
# Enable package logging propagation before setting up loggers
propagate_package_logs(True)
setup_root_logger(train_args.log_level)
setup_metric_logger("async", None, train_args.ckpt_output_dir)
setup_metric_logger(
train_args.logger_type,
train_args.run_name,
train_args.ckpt_output_dir,
mlflow_tracking_uri=train_args.mlflow_tracking_uri,
mlflow_experiment_name=train_args.mlflow_experiment_name,
wandb_project=train_args.wandb_project,
wandb_entity=train_args.wandb_entity,
tensorboard_log_dir=train_args.tensorboard_log_dir,
)

logger = logging.getLogger("instructlab.training")
logger.info("Starting training setup...")
@@ -545,9 +563,24 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
f"--adamw_beta1={train_args.adamw_betas[0]}",
f"--adamw_beta2={train_args.adamw_betas[1]}",
f"--adamw_eps={train_args.adamw_eps}",
f"--logger_type={train_args.logger_type}",
]
)

# Add optional logging parameters
if train_args.run_name is not None:
command.append(f"--run_name={train_args.run_name}")
if train_args.mlflow_tracking_uri is not None:
command.append(f"--mlflow_tracking_uri={train_args.mlflow_tracking_uri}")
if train_args.mlflow_experiment_name is not None:
command.append(f"--mlflow_experiment_name={train_args.mlflow_experiment_name}")
if train_args.wandb_project is not None:
command.append(f"--wandb_project={train_args.wandb_project}")
if train_args.wandb_entity is not None:
command.append(f"--wandb_entity={train_args.wandb_entity}")
if train_args.tensorboard_log_dir is not None:
command.append(f"--tensorboard_log_dir={train_args.tensorboard_log_dir}")

if train_args.pretraining_config is not None:
command.append(f"--block-size={train_args.pretraining_config.block_size}")
command.append(
@@ -766,6 +799,36 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
parser.add_argument("--log_level", type=str, default="INFO")
parser.add_argument("--run_name", type=str, default=None)
parser.add_argument("--logger_type", type=str, default="async")
parser.add_argument(
"--mlflow_tracking_uri",
type=str,
default=None,
help="MLflow tracking server URI (e.g., 'http://localhost:5000')",
)
parser.add_argument(
"--mlflow_experiment_name",
type=str,
default=None,
help="MLflow experiment name",
)
parser.add_argument(
"--wandb_project",
type=str,
default=None,
help="Weights & Biases project name",
)
parser.add_argument(
"--wandb_entity",
type=str,
default=None,
help="Weights & Biases team/entity name",
)
parser.add_argument(
"--tensorboard_log_dir",
type=str,
default=None,
help="Directory for TensorBoard logs. Defaults to output_dir if not specified.",
)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--mock_data", action="store_true")
parser.add_argument("--mock_len", type=int, default=2600)