Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions qlib/workflow/expm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
# Licensed under the MIT License.

from urllib.parse import urlparse
from urllib.request import url2pathname
import mlflow
from filelock import FileLock
from mlflow.exceptions import MlflowException, RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.entities import ViewType
import os
from typing import Optional, Text
from pathlib import Path

Expand Down Expand Up @@ -233,7 +233,14 @@ def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (objec
# So we supported it in the interface wrapper
pr = urlparse(self.uri)
if pr.scheme == "file":
with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110
# `pr.path` of an absolute file:// URI is already an absolute path. The previous
# `os.path.join(pr.netloc, pr.path.lstrip("/"), ...)` stripped the leading "/" and
# produced a CWD-relative lock path, so the lock (and its parent dirs) landed wherever
# the process happened to be running instead of at the URI's location. `url2pathname`
# restores the absolute path and also handles Windows drive-letter URIs correctly.
lock_dir = Path(url2pathname(pr.path))
lock_dir.mkdir(parents=True, exist_ok=True)
with FileLock(lock_dir / "filelock"): # pylint: disable=E0110
return self.create_exp(experiment_name), True
# NOTE: for other schemes like http, we double check to avoid create exp conflicts
try:
Expand Down
30 changes: 25 additions & 5 deletions qlib/workflow/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,18 +364,38 @@ def _log_uncommitted_code(self):
MLflow only logs the commit id of the current repo, but users usually have a lot of uncommitted changes.
So this tries to log them all automatically.
"""
# This is an opportunistic reproducibility hook, not a precondition for running an experiment.
# When the CWD is not inside a git work tree (containers, CI sandboxes, a tempdir, ...) the git
# commands below fail; previously `shell=True` without capturing stderr leaked git's multi-line
# "usage: git diff --no-index ..." banner to the parent's stderr (bypassing this logger), and each
# failure emitted a noisy INFO record. So we first check for a work tree and skip silently if absent.
try:
probe = subprocess.run(
["git", "rev-parse", "--is-inside-work-tree"],
stdout=subprocess.PIPE,
stderr=subprocess.DEVNULL,
check=False,
)
except OSError:
# git is not installed / not on PATH
logger.debug("Skip logging uncommitted code: git is not available.")
return
if probe.returncode != 0 or probe.stdout.strip() != b"true":
logger.debug(f"Skip logging uncommitted code: $CWD({os.getcwd()}) is not a git work tree.")
return

# TODO: the sub-directories may themselves be git repos.
# So it would be better to walk the sub-directories and log their uncommitted changes as well.
for cmd, fname in [
("git diff", "code_diff.txt"),
("git status", "code_status.txt"),
("git diff --cached", "code_cached.txt"),
(["git", "diff"], "code_diff.txt"),
(["git", "status"], "code_status.txt"),
(["git", "diff", "--cached"], "code_cached.txt"),
]:
try:
out = subprocess.check_output(cmd, shell=True)
out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL)
self.client.log_text(self.id, out.decode(), fname) # this behaves same as above
except subprocess.CalledProcessError:
logger.info(f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {cmd}.")
logger.debug(f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {' '.join(cmd)}.")

def end_run(self, status: str = Recorder.STATUS_S):
assert status in [
Expand Down
152 changes: 152 additions & 0 deletions tests/test_workflow_cwd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Regression tests for the current-working-directory assumptions in ``qlib.workflow``.

See https://github.com/microsoft/qlib/issues/2252:

1. ``MLflowExpManager`` built its ``FileLock`` path from ``file://`` URIs in a way that
produced a CWD-relative path, so the lock (and its parent dirs) was created wherever
the process happened to be running instead of at the location the URI named.
2. ``MLflowRecorder._log_uncommitted_code`` shelled out to ``git`` without capturing
stderr; outside a git work tree git's usage banner leaked to the parent's stderr and
each command emitted a noisy log record.
"""

import os
import shutil
import tempfile
import unittest
import subprocess
from contextlib import contextmanager
from pathlib import Path

import mlflow

from qlib.config import C
from qlib.workflow.expm import MLflowExpManager
from qlib.workflow.recorder import MLflowRecorder


@contextmanager
def chdir(path):
    """
    Run the ``with`` body with ``path`` as the process working directory.

    The directory that was current on entry is restored on exit, whether the body
    finishes normally or raises.
    """
    previous_dir = os.getcwd()
    os.chdir(str(path))
    try:
        yield
    finally:
        # Always switch back so later code does not inherit the temporary directory.
        os.chdir(previous_dir)


@contextmanager
def capture_fd_stderr():
    """
    Redirect the OS-level stderr (file descriptor 2) into a temporary file.

    ``contextlib.redirect_stderr`` only swaps ``sys.stderr``, so it does not see what a
    child process writes to the fd 2 it inherited. Redirecting the real descriptor lets
    the test observe exactly what would reach a user's terminal / log pipe.

    Yields the binary temporary file that receives the output. On exit fd 2 is pointed
    back at its original target; the temporary file itself is not closed here, so the
    caller can still ``seek(0)`` and read it.
    """
    stderr_fd = 2
    # Keep a duplicate of the real stderr so it can be put back afterwards.
    original_fd = os.dup(stderr_fd)
    sink = tempfile.TemporaryFile(mode="w+b")
    os.dup2(sink.fileno(), stderr_fd)
    try:
        yield sink
    finally:
        os.dup2(original_fd, stderr_fd)
        os.close(original_fd)


class MLflowExpManagerCWDTest(unittest.TestCase):
    """Check that the experiment FileLock is created at the absolute path a file:// URI names."""

    def setUp(self) -> None:
        # Two independent scratch directories: one holds the store, the other serves as CWD.
        self.store = Path(tempfile.mkdtemp(prefix="qlib-store-"))
        self.unrelated_cwd = Path(tempfile.mkdtemp(prefix="qlib-cwd-"))
        # Per the original author's note, constructing MLflowExpManager mutates the global
        # config, so the currently configured uri is remembered for tearDown.
        manager_kwargs = C.exp_manager.get("kwargs", {})
        self._saved_uri = manager_kwargs.get("uri")

    def tearDown(self) -> None:
        # Put the remembered uri back (only when there was one), then drop the scratch dirs.
        if self._saved_uri is not None:
            C.exp_manager.setdefault("kwargs", {})["uri"] = self._saved_uri
        for scratch_dir in (self.store, self.unrelated_cwd):
            shutil.rmtree(scratch_dir, ignore_errors=True)

    def test_filelock_resolves_to_absolute_uri_not_cwd(self):
        # Absolute path that does not exist yet.
        mlruns = self.store / "mlruns"
        store_uri = "file:" + str(mlruns)
        expm = MLflowExpManager(uri=store_uri, default_exp_name="Experiment")

        with chdir(self.unrelated_cwd):
            expm._get_or_create_exp(experiment_name="cwd-test")
            # The buggy code created `<cwd>/<abs-path-without-leading-slash>/filelock`,
            # i.e. a stray tree under the (unrelated) working directory. Nothing may land here.
            leftovers = sorted(entry.name for entry in Path.cwd().iterdir())
            self.assertEqual(
                leftovers,
                [],
                f"FileLock created files relative to CWD instead of the absolute URI: {leftovers}",
            )

        # The lock/experiment store must exist at the absolute location the URI named.
        self.assertTrue(mlruns.exists(), f"experiment store was not created at {mlruns}")


class MLflowRecorderGitSnapshotTest(unittest.TestCase):
    """Check that `_log_uncommitted_code` stays quiet outside a repo and still works inside one."""

    def setUp(self) -> None:
        # A throw-away file-based tracking store with a single experiment in it.
        self.store = Path(tempfile.mkdtemp(prefix="qlib-store-"))
        self.uri = "file:" + str(self.store / "mlruns")
        self.client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
        self.exp_id = self.client.create_experiment("git-snapshot-test")
        # One directory that is used as a non-git CWD, one that a test turns into a repo.
        self.non_git = Path(tempfile.mkdtemp(prefix="qlib-nongit-"))
        self.git_repo = Path(tempfile.mkdtemp(prefix="qlib-gitrepo-"))

    def tearDown(self) -> None:
        shutil.rmtree(self.store, ignore_errors=True)
        shutil.rmtree(self.non_git, ignore_errors=True)
        shutil.rmtree(self.git_repo, ignore_errors=True)

    def _new_recorder(self) -> MLflowRecorder:
        """Return a recorder whose ``id`` is set to a freshly created run of the test experiment."""
        recorder = MLflowRecorder(self.exp_id, self.uri)
        created_run = recorder.client.create_run(self.exp_id)
        recorder.id = created_run.info.run_id
        return recorder

    def test_non_git_cwd_is_silent_and_logs_no_artifacts(self):
        rec = self._new_recorder()
        with chdir(self.non_git), capture_fd_stderr() as err:
            rec._log_uncommitted_code()
            err.seek(0)
            leaked = err.read()

        # git's "Not a git repository" / "usage: git ..." banner must not leak to stderr.
        self.assertNotIn(b"Not a git repository", leaked)
        self.assertNotIn(b"usage: git", leaked)
        # No code snapshot artifacts should be produced outside a work tree.
        artifact_paths = [a.path for a in rec.client.list_artifacts(rec.id)]
        self.assertEqual(artifact_paths, [])

    def test_git_work_tree_logs_code_snapshot(self):
        # Inherit the environment and add an explicit git identity for the setup commands.
        env = dict(os.environ)
        env.update(
            {
                "GIT_AUTHOR_NAME": "qlib-test",
                "GIT_AUTHOR_EMAIL": "qlib-test@example.com",
                "GIT_COMMITTER_NAME": "qlib-test",
                "GIT_COMMITTER_EMAIL": "qlib-test@example.com",
            }
        )

        def run_git(*args):
            subprocess.run(["git", *args], cwd=self.git_repo, check=True, env=env)

        tracked_file = self.git_repo / "a.txt"
        run_git("init", "-q")
        tracked_file.write_text("hello\n")
        run_git("add", "a.txt")
        # Unstaged change on top of the staged one.
        tracked_file.write_text("hello\nworld\n")

        rec = self._new_recorder()
        with chdir(self.git_repo):
            rec._log_uncommitted_code()

        artifacts = sorted(a.path for a in rec.client.list_artifacts(rec.id))
        self.assertEqual(artifacts, ["code_cached.txt", "code_diff.txt", "code_status.txt"])


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()