From 9568145e69cf81c87d6d0665a77a53cdb00bbc8d Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 6 Oct 2025 10:57:07 -0700 Subject: [PATCH 01/39] Add mode for cog-runtime to use signals --- cmd/cog/main.go | 2 ++ internal/config/config.go | 1 + internal/runner/manager.go | 8 +++++- internal/runner/runner.go | 47 ++++++++++++++++++++++++++++++++++++ python/coglet/__main__.py | 5 +++- python/coglet/file_runner.py | 39 ++++++++++++++++++++++++++---- 6 files changed, 95 insertions(+), 7 deletions(-) diff --git a/cmd/cog/main.go b/cmd/cog/main.go index 96404148..90f4109c 100644 --- a/cmd/cog/main.go +++ b/cmd/cog/main.go @@ -23,6 +23,7 @@ type ServerCmd struct { UseProcedureMode bool `help:"Enable procedure mode for concurrent predictions" name:"use-procedure-mode" env:"COG_USE_PROCEDURE_MODE"` AwaitExplicitShutdown bool `help:"Wait for explicit shutdown signal instead of auto-shutdown" name:"await-explicit-shutdown" env:"COG_AWAIT_EXPLICIT_SHUTDOWN"` OneShot bool `help:"Enable one-shot mode (single runner, wait for cleanup before ready)" name:"one-shot" env:"COG_ONE_SHOT"` + SignalMode bool `help:"Enable signal mode (use signals instead of webhooks for IPC communication)" name:"signal-mode" env:"COG_SIGNAL_MODE"` UploadURL string `help:"Base URL for uploading prediction output files" name:"upload-url" env:"COG_UPLOAD_URL"` WorkingDirectory string `help:"Override the working directory for predictions" name:"working-directory" env:"COG_WORKING_DIRECTORY"` RunnerShutdownGracePeriod time.Duration `help:"Grace period before force-killing prediction runners" name:"runner-shutdown-grace-period" default:"600s" env:"COG_RUNNER_SHUTDOWN_GRACE_PERIOD"` @@ -78,6 +79,7 @@ func buildServiceConfig(s *ServerCmd) (config.Config, error) { WorkingDirectory: workingDir, UploadURL: s.UploadURL, IPCUrl: fmt.Sprintf("http://localhost:%d/_ipc", s.Port), + SignalMode: s.SignalMode, MaxRunners: s.MaxRunners, RunnerShutdownGracePeriod: s.RunnerShutdownGracePeriod, CleanupTimeout: s.CleanupTimeout, diff --git a/internal/config/config.go b/internal/config/config.go index 7514e562..2ab1e1a7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -19,6 +19,7 @@ type Config struct { UseProcedureMode bool AwaitExplicitShutdown bool OneShot bool + SignalMode bool // Directory configuration WorkingDirectory string diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 5da8658f..a1dfd20f 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -272,6 +272,7 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { log.Debugw("creating default runner", "working_dir", workingDir, "ipc_url", m.cfg.IPCUrl, + "signal_mode", m.cfg.SignalMode, "python_bin", m.cfg.PythonBinPath, ) @@ -284,10 +285,15 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { "-u", "-m", "coglet", "--name", DefaultRunnerName, - "--ipc-url", m.cfg.IPCUrl, "--working-dir", workingDir, } + if m.cfg.SignalMode { + args = append(args, "--signal-mode") + } else { + args = append(args, "--ipc-url", m.cfg.IPCUrl) + } + log.Debugw("runner command", "python_path", pythonPath, "args", args, "working_dir", workingDir) tmpDir, err := os.MkdirTemp("", "cog-runner-tmp-") diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 69f78421..89321cdb 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -30,6 +30,12 @@ import ( "github.com/replicate/cog-runtime/internal/webhook" ) +const ( + SigOutput = syscall.SIGHUP + SigReady = 
syscall.SIGUSR1 + SigBusy = syscall.SIGUSR2 +) + var ( LogRegex = regexp.MustCompile(`^\[pid=(?P[^]]+)] (?P.*)$`) ResponseRegex = regexp.MustCompile(`^response-(?P\S+)-(?P\d+).json$`) @@ -968,6 +974,47 @@ func (r *Runner) HandleIPC(status string) error { return nil } +// HandleSignal does the exact same things as HandleIPC just using signals +// instead of webhooks. This only can be used in non-pipeline use cases +func (r *Runner) HandleSignal(status syscall.Signal) error { + switch status { + case SigReady: + if r.status == StatusStarting { + r.updateSchema() + r.updateSetupResult() + // Close setupComplete channel to signal first READY after setup + r.mu.Lock() + select { + case <-r.setupComplete: + // Already closed + default: + close(r.setupComplete) + } + r.mu.Unlock() + } + if err := r.updateStatus("READY"); err != nil { + return fmt.Errorf("failed to update status: %w", err) + } + case SigBusy: + if err := r.updateStatus("BUSY"); err != nil { + return fmt.Errorf("failed to update status: %w", err) + } + case SigOutput: + // Notify all active prediction watchers of OUTPUT event + r.mu.RLock() + for _, pending := range r.pending { + select { + case pending.outputNotify <- struct{}{}: + // Notification sent + default: + // Channel full, skip (watcher will poll anyway) + } + } + r.mu.RUnlock() + } + return nil +} + func (r *Runner) updateSchema() { r.mu.Lock() defer r.mu.Unlock() diff --git a/python/coglet/__main__.py b/python/coglet/__main__.py index 46abedee..b057d5d5 100644 --- a/python/coglet/__main__.py +++ b/python/coglet/__main__.py @@ -53,7 +53,9 @@ def pre_setup(logger: logging.Logger, working_dir: str) -> Optional[file_runner. def main() -> int: parser = argparse.ArgumentParser() parser.add_argument('--name', metavar='NAME', required=True, help='name') - parser.add_argument('--ipc-url', metavar='URL', required=True, help='IPC URL') + group = parser.add_mutually_exclusive_group() + group.add_argument('--ipc-url', metavar='URL', required=True, help='IPC URL') + group.add_argument('--signal_mode', action='store_true') parser.add_argument( '--working-dir', metavar='DIR', required=True, help='working directory' ) @@ -85,6 +87,7 @@ def main() -> int: logger=logger, name=args.name, ipc_url=args.ipc_url, + signal_mode=args.signal_mode, working_dir=args.working_dir, config=config, ).start() diff --git a/python/coglet/file_runner.py b/python/coglet/file_runner.py index 582253b9..67a37d9a 100644 --- a/python/coglet/file_runner.py +++ b/python/coglet/file_runner.py @@ -5,6 +5,7 @@ import pathlib import re import signal +import sys import tempfile import urllib.request from dataclasses import dataclass @@ -25,6 +26,13 @@ class FileRunner: REQUEST_RE = re.compile(r'^request-(?P\S+).json$') RESPONSE_FMT = 'response-{pid}-{epoch:05d}.json' + # Signal parent to scan output + SIG_OUTPUT = signal.SIGHUP + + # Signal ready or busy status + SIG_READY = signal.SIGUSR1 + SIG_BUSY = signal.SIGUSR2 + # IPC status updates to Go server IPC_READY = 'READY' IPC_BUSY = 'BUSY' @@ -35,16 +43,21 @@ def __init__( *, logger: logging.Logger, name: str, - ipc_url: str, + ipc_url: str|None, working_dir: str, config: Config, + signal_mode: bool=False, ): + if not signal_mode and not ipc_url: + raise ValueError("IPC URL cannot be null if signal mode is false") + self.signal_mode = signal_mode self.logger = logger self.name = name self.ipc_url = ipc_url self.working_dir = working_dir self.config = config self.runner: Optional[runner.Runner] = None + self.isatty = sys.stdout.isatty() async def start(self) -> int: 
self.logger.info( @@ -117,7 +130,10 @@ def _cancel_handler(signum, _) -> None: signal.signal(signal.SIGINT, signal.SIG_IGN) ready = True - self._send_ipc(FileRunner.IPC_READY) + if self.signal_mode: + self._signal(FileRunner.SIG_READY) + else: + self._send_ipc(FileRunner.IPC_READY) # Go server cannot receive IPC yet when a procedure is starting # Write a ready file as signal with open(ready_file, 'w') as f: @@ -127,7 +143,10 @@ def _cancel_handler(signum, _) -> None: while True: if not ready and len(pending) < self.config.max_concurrency: ready = True - self._send_ipc(FileRunner.IPC_READY) + if self.signal_mode: + self._signal(FileRunner.SIG_READY) + else: + self._send_ipc(FileRunner.IPC_READY) if os.path.exists(stop_file): self.logger.info('stopping file runner') @@ -172,7 +191,10 @@ def _cancel_handler(signum, _) -> None: if ready and len(pending) + 1 == self.config.max_concurrency: ready = False - self._send_ipc(FileRunner.IPC_BUSY) + if self.signal_mode: + self._signal(FileRunner.SIG_BUSY) + else: + self._send_ipc(FileRunner.IPC_BUSY) pending[pid] = asyncio.create_task(self._predict(pid, req)) self.logger.info('prediction started: id=%s', pid) @@ -284,7 +306,10 @@ def _respond( ) os.rename(temp_path, resp_path) - self._send_ipc(FileRunner.IPC_OUTPUT) + if self.signal_mode: + self._signal(FileRunner.SIG_OUTPUT) + else: + self._send_ipc(FileRunner.IPC_OUTPUT) def _send_ipc(self, status: str) -> None: try: @@ -297,3 +322,7 @@ def _send_ipc(self, status: str) -> None: urllib.request.urlopen(self.ipc_url, data=data).read() except Exception as e: self.logger.exception('IPC failed: %s', e) + + def _signal(self, signum: int) -> None: + if not self.isatty: + os.kill(os.getppid(), signum) From dd441e608182c317c7cd3e2eb9a28107f7150223 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 6 Oct 2025 11:07:54 -0700 Subject: [PATCH 02/39] Use proper type hint --- python/coglet/file_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/coglet/file_runner.py b/python/coglet/file_runner.py index 67a37d9a..592b38eb 100644 --- a/python/coglet/file_runner.py +++ b/python/coglet/file_runner.py @@ -43,7 +43,7 @@ def __init__( *, logger: logging.Logger, name: str, - ipc_url: str|None, + ipc_url: Optional[str], working_dir: str, config: Config, signal_mode: bool=False, From dae112b7c03aadbbd5be8917e5bf259bfc8188a3 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 6 Oct 2025 11:26:05 -0700 Subject: [PATCH 03/39] Make sure IPC URL exists when calling send_ipc --- python/coglet/file_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/coglet/file_runner.py b/python/coglet/file_runner.py index 592b38eb..1649bac4 100644 --- a/python/coglet/file_runner.py +++ b/python/coglet/file_runner.py @@ -313,6 +313,8 @@ def _respond( def _send_ipc(self, status: str) -> None: try: + if not self.ipc_url: + raise RuntimeError("IPC invoked but IPC URL not provided") payload = { 'name': self.name, 'pid': os.getpid(), From 71f255b2fc8e702361f2c254498a92a36e0929bd Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 6 Oct 2025 12:33:42 -0700 Subject: [PATCH 04/39] Linting --- internal/server/server.go | 20 ++++++++++++++++++-- python/coglet/file_runner.py | 6 +++--- python/coglet/file_runner.pyi | 6 +++++- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index a26f7b33..30110bc2 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,6 +10,7 @@ import ( "net/http" "os" "path" + 
"path/filepath" "sync/atomic" "time" @@ -394,11 +395,26 @@ func writeReadyFile() error { dir := "/var/run/cog" file := path.Join(dir, "ready") - if _, err := os.Stat(file); os.IsNotExist(err) { + return writeFileIfNotExists(file) +} + +// If the checkpoint flow is turned on, write the ready file for checkpointing +func writeCheckpointReadyFile() error { + file := os.Getenv("ENTRYPOINT_CUDA_READY_LOCK_FILE") + if file == "" { + return nil + } + + return writeFileIfNotExists(file) +} + +func writeFileIfNotExists(path string) error { + dir := filepath.Dir(path) + if _, err := os.Stat(path); os.IsNotExist(err) { if err := os.MkdirAll(dir, 0o700); err != nil { return err } - if err := os.WriteFile(file, nil, 0o600); err != nil { + if err := os.WriteFile(path, nil, 0o600); err != nil { return err } } diff --git a/python/coglet/file_runner.py b/python/coglet/file_runner.py index 1649bac4..ed51d8a6 100644 --- a/python/coglet/file_runner.py +++ b/python/coglet/file_runner.py @@ -46,10 +46,10 @@ def __init__( ipc_url: Optional[str], working_dir: str, config: Config, - signal_mode: bool=False, + signal_mode: bool = False, ): if not signal_mode and not ipc_url: - raise ValueError("IPC URL cannot be null if signal mode is false") + raise ValueError('IPC URL cannot be null if signal mode is false') self.signal_mode = signal_mode self.logger = logger self.name = name @@ -314,7 +314,7 @@ def _respond( def _send_ipc(self, status: str) -> None: try: if not self.ipc_url: - raise RuntimeError("IPC invoked but IPC URL not provided") + raise RuntimeError('IPC invoked but IPC URL not provided') payload = { 'name': self.name, 'pid': os.getpid(), diff --git a/python/coglet/file_runner.pyi b/python/coglet/file_runner.pyi index 56f71f5a..907555fd 100644 --- a/python/coglet/file_runner.pyi +++ b/python/coglet/file_runner.pyi @@ -4,6 +4,7 @@ This type stub file was generated by pyright. import logging from dataclasses import dataclass +from typing import Optional @dataclass(frozen=True) class Config: @@ -17,10 +18,13 @@ class FileRunner: CANCEL_RE = ... REQUEST_RE = ... RESPONSE_FMT = ... + SIG_OUTPUT = ... + SIG_READY = ... + SIG_BUSY = ... IPC_READY = ... IPC_BUSY = ... IPC_OUTPUT = ... - def __init__(self, *, logger: logging.Logger, name: str, ipc_url: str, working_dir: str, config: Config) -> None: + def __init__(self, *, logger: logging.Logger, name: str, ipc_url: Optional[str], working_dir: str, config: Config, signal_mode: bool = ...) -> None: ... 
async def start(self) -> int: From 29d6d6a1b982436ccd72cb04f050551ce12efd7f Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 6 Oct 2025 12:48:11 -0700 Subject: [PATCH 05/39] Use func --- internal/server/server.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 30110bc2..72b9ec20 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -98,6 +98,11 @@ func (h *Handler) healthCheck() (*HealthCheck, error) { return nil, err } + if err := writeCheckpointReadyFile(); err != nil { + log.Errorw("failed to write checkpoint ready file", "error", err) + return nil, err + } + runnerSetupResult := h.runnerManager.SetupResult() concurrency := h.runnerManager.Concurrency() runnerStatus := h.runnerManager.Status() @@ -408,13 +413,13 @@ func writeCheckpointReadyFile() error { return writeFileIfNotExists(file) } -func writeFileIfNotExists(path string) error { - dir := filepath.Dir(path) - if _, err := os.Stat(path); os.IsNotExist(err) { +func writeFileIfNotExists(fpath string) error { + dir := filepath.Dir(fpath) + if _, err := os.Stat(fpath); os.IsNotExist(err) { if err := os.MkdirAll(dir, 0o700); err != nil { return err } - if err := os.WriteFile(path, nil, 0o600); err != nil { + if err := os.WriteFile(fpath, nil, 0o600); err != nil { return err } } From 17d67770478f37c2154823d34d50234c095ab04b Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 7 Oct 2025 12:30:46 -0700 Subject: [PATCH 06/39] Wire up checkpointing in cog-runtime --- internal/checkpointer/checkpointer.go | 253 ++++++++++++++++++++++++++ internal/checkpointer/utils.go | 184 +++++++++++++++++++ internal/runner/manager.go | 99 ++++++++-- internal/runner/runner.go | 2 +- internal/server/server.go | 38 ++-- internal/service/service.go | 6 + 6 files changed, 552 insertions(+), 30 deletions(-) create mode 100644 internal/checkpointer/checkpointer.go create mode 100644 internal/checkpointer/utils.go diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go new file mode 100644 index 00000000..b3c21c2f --- /dev/null +++ b/internal/checkpointer/checkpointer.go @@ -0,0 +1,253 @@ +package checkpointer + +import ( + "context" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" + "time" +) + +const ( + // Configuration environment variables + locationEnvVar = "R8_LOCATION" + shouldCheckpointEnvVar = "R8_CUDA_CHECKPOINT" + leaseFileEnvVar = "R8_LEASE_FILE" + cudaCheckpointDirEnvVar = "R8_CUDA_CHECKPOINT_DIR" + cudaReadyFileEnvVar = "R8_CUDA_READY_LOCK_FILE" + + // Dependencies for the checkpoint process + cudaCheckpointURLFmtStr = "https://r8-public-assets-%s.cwobject.com/cuda-checkpoint" + criuURLFmtStr = "https://r8-public-assets-%s.cwobject.com/criu.tar.gz" + cudaCheckpointPath = "/tmp/cuda-checkpoint" + criuPath = "/tmp/criu" + + // Metadata storage paths + cudaCmdFileName = "cuda-cmd" + checkpointSubdirName = "checkpoint" +) + +var ( + errNoCheckpointDir = errors.New("Could not find checkpoint directory environment variable") +) + +type FatalCheckpointErr struct { + err error +} + +func (e *FatalCheckpointErr) Error() string { + return e.Error() +} + +type Checkpointer interface { + Disable() + HasCheckpoint() bool + Prepare(ctx context.Context) error + Checkpoint(ctx context.Context, cmd *exec.Cmd) error + Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error) +} + +type checkpointer struct { + enabled bool + hasCheckpoint bool + checkpointDir 
string + leaseFile string +} + +func NewCheckpointer(ctx context.Context) Checkpointer { + return &checkpointer{ + enabled: os.Getenv(shouldCheckpointEnvVar) == "true", + checkpointDir: os.Getenv(cudaCheckpointDirEnvVar), + leaseFile: os.Getenv(leaseFileEnvVar), + } +} + +func (c *checkpointer) Disable() { + c.enabled = false +} + +func (c *checkpointer) HasCheckpoint() bool { + if !c.enabled { + return false + } + + return c.hasCheckpoint +} + +func (c *checkpointer) Prepare(ctx context.Context) error { + if !c.enabled { + return nil + } + + // Download dependencies + err := downloadCUDACheckpointBinaries(ctx) + if err != nil { + return err + } + + // Wait for IPC lease file to be deleted + if c.leaseFile != "" { + err = pollForFileDeletion(c.leaseFile, 5*time.Minute, 10*time.Second) + if err != nil { + return err + } + } + + empty, err := isDirEmpty(filepath.Join(c.checkpointDir, checkpointSubdirName)) + // If the err is not nil, it probably means the directory does not exist + if err == nil && !empty { + c.hasCheckpoint = true + } + + return nil +} + +func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) error { + if !c.enabled { + return nil + } + + if c.checkpointDir == "" { + return errNoCheckpointDir + } + + err := os.MkdirAll(filepath.Join(c.checkpointDir, checkpointSubdirName), 0o666) + if err != nil { + return err + } + + pid := strconv.Itoa(cogletCmd.Process.Pid) + + // Find the PID of the command that is actually using the GPU + cudaPIDBytes, err := exec.CommandContext(ctx, "nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader").Output() + if err != nil { + return err + } + + cudaPID := strings.TrimSpace(string(cudaPIDBytes)) + + // Get the command for this PID - it is _not_ always the root python process + data, err := exec.CommandContext(ctx, "ps", "-o", "cmd=", cudaPID).Output() + if err != nil { + return err + } + + cudaCmd := strings.TrimSpace(string(data)) + + // Write said command to a file for later + err = os.WriteFile(filepath.Join(c.checkpointDir, cudaCmdFileName), []byte(cudaCmd), 0o666) + if err != nil { + return err + } + + // Toggle CUDA off + cmd := exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) + if err := cmd.Run(); err != nil { + return err + } + + // CRIU checkpoint (leaving process running) + cmd = exec.CommandContext(ctx, criuPath, "dump", "--leave-running", "--shell-job", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) + if err := cmd.Run(); err != nil { + // Try to toggle CUDA back on. If we aren't able to restart CUDA, the process + // will hang indefinitely, so we should kill it and try to start a new one + // without checkpointing + cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) + if cudaErr := cmd.Run(); cudaErr != nil { + // Return a fatal error so upstream knows we cannot continue in the current state + return &FatalCheckpointErr{ + err: cudaErr, + } + } + // Return the original checkpointing error + return err + } + + // Toggle CUDA back on. 
If we aren't able to restart CUDA, the process + // will hang indefinitely, so we should kill it and try to start a new + // one without checkpointing + cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) + if err := cmd.Run(); err != nil { + // Return a fatal error so upstream knows we cannot continue in the current state + return &FatalCheckpointErr{ + err: err, + } + } + + return setStatusReady() +} + +func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error) { + if !c.enabled { + return nil, nil, nil + } + + // Read process from sentinel file + cudaCmd, err := os.ReadFile(filepath.Join(c.checkpointDir, cudaCmdFileName)) + if err != nil { + return nil, nil, err + } + + // Set up restore command + restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) + + // Set up callback function once restore is started + callback := func(con context.Context) error { + // Get the PID for the command + cudaPID, err := exec.CommandContext(con, "pgrep", "-fx", string(cudaCmd)).Output() + if err != nil { + // If this command failed, we want to best effort try to kill the started process, + // since we'll start a new one + restoreCmd.Process.Kill() + + return err + } + + // Toggle CUDA on for the restored process + cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) + if err := cmd.Run(); err != nil { + // If this command failed, we want to best effort try to kill the started process, + // since we'll start a new one + restoreCmd.Process.Kill() + + return err + } + + err = setStatusReady() + if err != nil { + // If this command failed, we want to best effort try to kill the started process, + // since we'll start a new one + restoreCmd.Process.Kill() + + return err + } + + return nil + } + + // The restored command is a running instance of coglet + return restoreCmd, callback, nil +} + +func downloadCUDACheckpointBinaries(ctx context.Context) error { + location := os.Getenv("R8_LOCATION") + + // Download the cuda-checkpoint binary + err := downloadAndChmod(fmt.Sprintf(cudaCheckpointURLFmtStr, location), cudaCheckpointPath) + if err != nil { + return fmt.Errorf("failed to download and chmod cuda-checkpoint binary: %w", err) + } + // CRIU gets downloaded as a tar with its dependencies. 
So we need to extract the tar, then + // link the LD_LIBRARY_PATH to the dependencies + dir := filepath.Dir(criuPath) + err = downloadAndUntar(ctx, fmt.Sprintf(criuURLFmtStr, location), dir) + if err != nil { + return fmt.Errorf("failed to download and untar CRIU: %w", err) + } + updateEnvVar("LD_LIBRARY_PATH", filepath.Join(dir, "criu-lib")) + return nil +} diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go new file mode 100644 index 00000000..c27195f6 --- /dev/null +++ b/internal/checkpointer/utils.go @@ -0,0 +1,184 @@ +package checkpointer + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "time" +) + +var ( + errTimedOutPolling = errors.New("Timed out while polling for file") +) + +// updateEnvVar updates an environment variable in-place, adding an item to it +// if it exists or creating it if it doesn't exist +func updateEnvVar(envVarName, newItem string) { + old := os.Getenv(envVarName) + if old == "" { + os.Setenv(envVarName, newItem) + return + } + path := newItem + string(os.PathListSeparator) + os.Getenv(envVarName) + os.Setenv(envVarName, path) +} + +// downloadFile downloads a file from the URL provided to the path provided +func downloadFile(url, path string) error { + filename := filepath.Base(path) + os.MkdirAll(filepath.Dir(path), 0o755) + + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("failed to download %s: %w", filename, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to download %s: %w", filename, err) + } + + binary, err := os.Create(path) + if err != nil { + return fmt.Errorf("failed to touch file: %w", err) + } + defer binary.Close() + + _, err = io.Copy(binary, resp.Body) + if err != nil { + return fmt.Errorf("failed to save %s: %w", filename, err) + } + + return nil +} + +// downloadAndChmod downloads a file to the path provided and chmods it for +// execution. This expects the downloaded file to be a binary +func downloadAndChmod(url, path string) error { + err := downloadFile(url, path) + if err != nil { + return err + } + + if err := os.Chmod(path, 0o755); err != nil { + return fmt.Errorf("failed to chmod file: %w", err) + } + return nil +} + +// downloadAndUntar downloads a tar and extracts it to a path. The path is expected +// to be a directory +func downloadAndUntar(ctx context.Context, url, path string) error { + // Download to `${path}/tmp.tar.gz` + downloadPath := filepath.Join(path, "tmp.tar.gz") + err := downloadFile(url, downloadPath) + if err != nil { + return err + } + + // Untar into the `${path}` dir + cmd := exec.CommandContext(ctx, "tar", "-xf", downloadPath, "-C", path) + devnull, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0o755) + if err != nil { + return err + } + cmd.Stdout = devnull + cmd.Stderr = devnull + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to extract tar: %w", err) + } + + return nil +} + +// pollForFileDeletion waits for a file to be deleted, up until a timeout. 
It returns an error if the +// timeout is hit +func pollForFileDeletion(target string, timeout time.Duration, pollInterval time.Duration) error { + deadline := time.After(timeout) + + for { + // Check if the file still exists, if it does keep looping + if _, err := os.Stat(target); errors.Is(err, os.ErrNotExist) { + return nil + } + + // Check for timeout before sleeping for the polling interval + select { + case <-deadline: + return errTimedOutPolling + default: + time.Sleep(pollInterval) + } + } +} + +// pollForFileExistance waits for a file to exist, up until a timeout. It returns an error if the +// timeout is hit +func pollForFileExistance(target string, timeout time.Duration, pollInterval time.Duration) error { + deadline := time.After(timeout) + + for { + // Check if the file exists, if it doesn't keep looping + if _, err := os.Stat(target); err == nil { + return nil + } + + // Check for timeout before sleeping for the polling interval + select { + case <-deadline: + return errTimedOutPolling + default: + time.Sleep(pollInterval) + } + } +} + +// https://stackoverflow.com/a/30708914/30548878 +func isDirEmpty(name string) (bool, error) { + f, err := os.Open(name) + if err != nil { + return false, err + } + defer f.Close() + + _, err = f.Readdirnames(1) + if err == io.EOF { + return true, nil + } + return false, err +} + +// Touch a file if it doesn't exist, otherwise wipes the contents of the file +func touchFile(name string) error { + // Ensure upstream directory exists for file + err := os.MkdirAll(filepath.Dir(name), 0o755) + if err != nil { + return err + } + + f, err := os.Create(name) + if err != nil { + return err + } + f.Close() + return nil +} + +// setStatusReady ensures the ready files exist +func setStatusReady() error { + cudaReadyFilePath := os.Getenv(cudaReadyFileEnvVar) + + // Touch CUDA ready file + err := touchFile(cudaReadyFilePath) + if err != nil { + return err + } + + return nil +} diff --git a/internal/runner/manager.go b/internal/runner/manager.go index a1dfd20f..eb6dbced 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -15,6 +15,7 @@ import ( "go.uber.org/zap" + "github.com/replicate/cog-runtime/internal/checkpointer" "github.com/replicate/cog-runtime/internal/config" "github.com/replicate/cog-runtime/internal/logging" "github.com/replicate/cog-runtime/internal/webhook" @@ -294,6 +295,13 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { args = append(args, "--ipc-url", m.cfg.IPCUrl) } + // This returns an object that does nothing if it is not enabled. + cp := checkpointer.NewCheckpointer(ctx) + err := cp.Prepare(ctx) + if err != nil { + cp.Disable() + } + log.Debugw("runner command", "python_path", pythonPath, "args", args, "working_dir", workingDir) tmpDir, err := os.MkdirTemp("", "cog-runner-tmp-") @@ -301,11 +309,6 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { return nil, fmt.Errorf("failed to create temp directory: %w", err) } - // Derive the runtime context from the manager's context - runtimeContext, runtimeCancel := context.WithCancel(ctx) - cmd := exec.CommandContext(runtimeContext, pythonPath, args...) 
//nolint:gosec // expected subprocess launched with variable - cmd.Dir = m.cfg.WorkingDirectory - cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} env := mergeEnv(os.Environ(), m.cfg.EnvSet, m.cfg.EnvUnset) env = append(env, "TMPDIR="+tmpDir) @@ -316,8 +319,6 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { env = append(env, "LOG_LEVEL=debug") } - cmd.Env = env - // Read cog.yaml for runner configuration (capacity was already set in newManager) cogYaml, err := ReadCogYaml(workingDir) if err != nil { @@ -336,16 +337,28 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { tmpDir: tmpDir, uploader: uploader, } - runner, err := NewRunner(runtimeContext, runtimeCancel, runnerCtx, cmd, cogYaml.Concurrency.Max, m.cfg, m.baseLogger) - if err != nil { - return nil, err - } - runner.webhookSender = m.webhookSender - if err := runner.Start(ctx); err != nil { - return nil, fmt.Errorf("failed to start runner: %w", err) + // If there is an existing checkpoint, try to restore from the checkpoint + if cp.HasCheckpoint() { + runner, err := m.startRunnerFromCheckpoint(ctx, env, runnerCtx, cogYaml.Concurrency.Max, cp) + if err != nil { + cp.Disable() + } else { + m.runners[0] = runner + m.monitoringWG.Go(func() { + m.monitorRunnerSubprocess(m.ctx, DefaultRunnerName, runner) + }) + + return runner, nil + } } + // Derive the runtime context from the manager's context + runtimeContext, runtimeCancel := context.WithCancel(ctx) + + cmd := exec.CommandContext(runtimeContext, pythonPath, args...) //nolint:gosec // expected subprocess launched with variable + runner, err := m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, cogYaml.Concurrency.Max) + if err := runner.Config(ctx); err != nil { if stopErr := runner.Stop(); stopErr != nil { log.Errorw("failed to stop runner", "name", DefaultRunnerName, "error", stopErr) @@ -359,6 +372,56 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { m.monitorRunnerSubprocess(m.ctx, DefaultRunnerName, runner) }) + if !cp.HasCheckpoint() { + err := cp.Checkpoint(ctx, cmd) + var fatalCheckpointErr *checkpointer.FatalCheckpointErr + if errors.As(err, &fatalCheckpointErr) { + return nil, fmt.Errorf("fatal error while trying to checkpoint: %w", err) + } + // If the error is not fatal, we failed to create a checkpoint but are still + // running the original cog process, so we can just continue as if we did + // nothing + } + + return runner, nil +} + +func (m *Manager) setupRunner(runtimeContext context.Context, runtimeCancel context.CancelFunc, cmd *exec.Cmd, env []string, runnerCtx RunnerContext, maxConcurrency int) (*Runner, error) { + cmd.Dir = m.cfg.WorkingDirectory + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + cmd.Env = env + + runner, err := NewRunner(runtimeContext, runtimeCancel, runnerCtx, cmd, maxConcurrency, m.cfg, m.baseLogger) + if err != nil { + return nil, err + } + + runner.webhookSender = m.webhookSender + if err := runner.Start(runtimeContext); err != nil { + return nil, fmt.Errorf("failed to start runner: %w", err) + } + + return runner, nil +} + +func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, runnerCtx RunnerContext, maxConcurrency int, cp checkpointer.Checkpointer) (*Runner, error) { + // Derive the runtime context from the manager's context + runtimeContext, runtimeCancel := context.WithCancel(ctx) + + cmd, callback, err := cp.Restore(runtimeContext) + if err != nil { + return nil, err + } + + runner, err := 
m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, maxConcurrency) + + err = callback(runtimeContext) + if err != nil { + return nil, fmt.Errorf("failed callback function: %w", err) + } + + // TODO: Send ready signal somehow, can we SIGHUP ourselves? + return runner, nil } @@ -937,6 +1000,14 @@ func (m *Manager) HandleRunnerIPC(runnerName, status string) error { return runner.HandleIPC(status) } +func (m *Manager) HandleRunnerSignal(runnerName string, signal os.Signal) error { + runner, _, exists := m.findRunner(runnerName) + if !exists { + return fmt.Errorf("%w: %s", ErrRunnerNotFound, runnerName) + } + return runner.HandleSignal(signal) +} + func (m *Manager) cleanupInProgress() bool { if !m.cfg.OneShot { return false diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 89321cdb..b717e212 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -976,7 +976,7 @@ func (r *Runner) HandleIPC(status string) error { // HandleSignal does the exact same things as HandleIPC just using signals // instead of webhooks. This only can be used in non-pipeline use cases -func (r *Runner) HandleSignal(status syscall.Signal) error { +func (r *Runner) HandleSignal(status os.Signal) error { switch status { case SigReady: if r.status == StatusStarting { diff --git a/internal/server/server.go b/internal/server/server.go index 72b9ec20..502912d8 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,9 +9,11 @@ import ( "io" "net/http" "os" + "os/signal" "path" "path/filepath" "sync/atomic" + "syscall" "time" "github.com/replicate/go/httpclient" @@ -27,6 +29,10 @@ const ( IPCStatusReady IPCStatus = "READY" IPCStatusBUSY IPCStatus = "BUSY" IPCStatusOutput IPCStatus = "OUTPUT" + + SigOutput = syscall.SIGHUP + SigReady = syscall.SIGUSR1 + SigBusy = syscall.SIGUSR2 ) type IPC struct { @@ -98,11 +104,6 @@ func (h *Handler) healthCheck() (*HealthCheck, error) { return nil, err } - if err := writeCheckpointReadyFile(); err != nil { - log.Errorw("failed to write checkpoint ready file", "error", err) - return nil, err - } - runnerSetupResult := h.runnerManager.SetupResult() concurrency := h.runnerManager.Concurrency() runnerStatus := h.runnerManager.Status() @@ -155,6 +156,23 @@ func (h *Handler) Stop() error { return nil } +func (h *Handler) HandleSignals() { + log := h.logger.Sugar() + + ch := make(chan os.Signal, 1) + signal.Notify(ch, SigOutput, SigReady, SigBusy) + + for { + s := <-ch + err := h.runnerManager.HandleRunnerSignal(runner.DefaultRunnerName, s) + if err != nil { + log.Errorw("failed to handle IPC", "signal", s, "error", err) + // TODO: What do we do with this error? Put it on some error chan + // and ship it somewhere? 
+ } + } +} + func (h *Handler) HandleIPC(w http.ResponseWriter, r *http.Request) { log := h.logger.Sugar() @@ -403,16 +421,6 @@ func writeReadyFile() error { return writeFileIfNotExists(file) } -// If the checkpoint flow is turned on, write the ready file for checkpointing -func writeCheckpointReadyFile() error { - file := os.Getenv("ENTRYPOINT_CUDA_READY_LOCK_FILE") - if file == "" { - return nil - } - - return writeFileIfNotExists(file) -} - func writeFileIfNotExists(fpath string) error { dir := filepath.Dir(fpath) if _, err := os.Stat(fpath); os.IsNotExist(err) { diff --git a/internal/service/service.go b/internal/service/service.go index 810b2666..b9caa02d 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -173,6 +173,12 @@ func (s *Service) Run(ctx context.Context) error { return fmt.Errorf("failed to start handler: %w", err) } + if s.cfg.SignalMode { + // This runs an infinite loop for handling signals, so we explicitly + // do not want to put it in a wait group of any kind + go s.handler.HandleSignals() + } + eg.Go(func() error { log.Info("starting HTTP server") if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { From 2f5cf996dc0dfbe90b7e5acf5266727aa5867919 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 7 Oct 2025 12:46:58 -0700 Subject: [PATCH 07/39] Linting --- internal/checkpointer/checkpointer.go | 4 +--- internal/checkpointer/utils.go | 8 +++----- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index b3c21c2f..4def922c 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -31,9 +31,7 @@ const ( checkpointSubdirName = "checkpoint" ) -var ( - errNoCheckpointDir = errors.New("Could not find checkpoint directory environment variable") -) +var errNoCheckpointDir = errors.New("Could not find checkpoint directory environment variable") type FatalCheckpointErr struct { err error diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go index c27195f6..c46b34e0 100644 --- a/internal/checkpointer/utils.go +++ b/internal/checkpointer/utils.go @@ -12,9 +12,7 @@ import ( "time" ) -var ( - errTimedOutPolling = errors.New("Timed out while polling for file") -) +var errTimedOutPolling = errors.New("Timed out while polling for file") // updateEnvVar updates an environment variable in-place, adding an item to it // if it exists or creating it if it doesn't exist @@ -99,7 +97,7 @@ func downloadAndUntar(ctx context.Context, url, path string) error { // pollForFileDeletion waits for a file to be deleted, up until a timeout. It returns an error if the // timeout is hit -func pollForFileDeletion(target string, timeout time.Duration, pollInterval time.Duration) error { +func pollForFileDeletion(target string, timeout, pollInterval time.Duration) error { deadline := time.After(timeout) for { @@ -120,7 +118,7 @@ func pollForFileDeletion(target string, timeout time.Duration, pollInterval time // pollForFileExistance waits for a file to exist, up until a timeout. 
It returns an error if the // timeout is hit -func pollForFileExistance(target string, timeout time.Duration, pollInterval time.Duration) error { +func pollForFileExistance(target string, timeout, pollInterval time.Duration) error { deadline := time.After(timeout) for { From 36b57c7aa30b4452bf3fd3f55d166dfe06d0f62d Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 7 Oct 2025 13:40:56 -0700 Subject: [PATCH 08/39] Lint fixes part 1 --- internal/checkpointer/checkpointer.go | 25 +++++++------ internal/checkpointer/utils.go | 52 +++++++++------------------ internal/runner/manager.go | 15 +++++--- python/coglet/__main__.py | 2 +- 4 files changed, 40 insertions(+), 54 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 4def922c..011bb15e 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -31,14 +31,14 @@ const ( checkpointSubdirName = "checkpoint" ) -var errNoCheckpointDir = errors.New("Could not find checkpoint directory environment variable") +var errNoCheckpointDir = errors.New("could not find checkpoint directory environment variable") -type FatalCheckpointErr struct { +type FatalCheckpointError struct { err error } -func (e *FatalCheckpointErr) Error() string { - return e.Error() +func (e *FatalCheckpointError) Error() string { + return e.err.Error() } type Checkpointer interface { @@ -113,7 +113,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro return errNoCheckpointDir } - err := os.MkdirAll(filepath.Join(c.checkpointDir, checkpointSubdirName), 0o666) + err := os.MkdirAll(filepath.Join(c.checkpointDir, checkpointSubdirName), 0o666) //nolint:gosec // coglet needs to write here if err != nil { return err } @@ -137,7 +137,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro cudaCmd := strings.TrimSpace(string(data)) // Write said command to a file for later - err = os.WriteFile(filepath.Join(c.checkpointDir, cudaCmdFileName), []byte(cudaCmd), 0o666) + err = os.WriteFile(filepath.Join(c.checkpointDir, cudaCmdFileName), []byte(cudaCmd), 0o644) //nolint:gosec if err != nil { return err } @@ -157,7 +157,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) if cudaErr := cmd.Run(); cudaErr != nil { // Return a fatal error so upstream knows we cannot continue in the current state - return &FatalCheckpointErr{ + return &FatalCheckpointError{ err: cudaErr, } } @@ -171,7 +171,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) if err := cmd.Run(); err != nil { // Return a fatal error so upstream knows we cannot continue in the current state - return &FatalCheckpointErr{ + return &FatalCheckpointError{ err: err, } } @@ -200,7 +200,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con if err != nil { // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one - restoreCmd.Process.Kill() + restoreCmd.Process.Kill() //nolint:errcheck return err } @@ -210,7 +210,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con if err := cmd.Run(); err != nil { // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one - 
restoreCmd.Process.Kill() + restoreCmd.Process.Kill() //nolint:errcheck return err } @@ -219,7 +219,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con if err != nil { // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one - restoreCmd.Process.Kill() + restoreCmd.Process.Kill() //nolint:errcheck return err } @@ -246,6 +246,5 @@ func downloadCUDACheckpointBinaries(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to download and untar CRIU: %w", err) } - updateEnvVar("LD_LIBRARY_PATH", filepath.Join(dir, "criu-lib")) - return nil + return updateEnvVar("LD_LIBRARY_PATH", filepath.Join(dir, "criu-lib")) } diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go index c46b34e0..05541563 100644 --- a/internal/checkpointer/utils.go +++ b/internal/checkpointer/utils.go @@ -12,24 +12,26 @@ import ( "time" ) -var errTimedOutPolling = errors.New("Timed out while polling for file") +var errTimedOutPolling = errors.New("timed out while polling for file") // updateEnvVar updates an environment variable in-place, adding an item to it // if it exists or creating it if it doesn't exist -func updateEnvVar(envVarName, newItem string) { +func updateEnvVar(envVarName, newItem string) error { old := os.Getenv(envVarName) if old == "" { - os.Setenv(envVarName, newItem) - return + return os.Setenv(envVarName, newItem) } path := newItem + string(os.PathListSeparator) + os.Getenv(envVarName) - os.Setenv(envVarName, path) + return os.Setenv(envVarName, path) } // downloadFile downloads a file from the URL provided to the path provided func downloadFile(url, path string) error { filename := filepath.Base(path) - os.MkdirAll(filepath.Dir(path), 0o755) + err := os.MkdirAll(filepath.Dir(path), 0o600) + if err != nil { + return err + } resp, err := http.Get(url) if err != nil { @@ -41,13 +43,13 @@ func downloadFile(url, path string) error { return fmt.Errorf("failed to download %s: %w", filename, err) } - binary, err := os.Create(path) + file, err := os.Create(path) if err != nil { return fmt.Errorf("failed to touch file: %w", err) } - defer binary.Close() + defer file.Close() //nolint: errcheck - _, err = io.Copy(binary, resp.Body) + _, err = io.Copy(file, resp.Body) if err != nil { return fmt.Errorf("failed to save %s: %w", filename, err) } @@ -63,7 +65,7 @@ func downloadAndChmod(url, path string) error { return err } - if err := os.Chmod(path, 0o755); err != nil { + if err := os.Chmod(path, 0o600); err != nil { return fmt.Errorf("failed to chmod file: %w", err) } return nil @@ -116,37 +118,16 @@ func pollForFileDeletion(target string, timeout, pollInterval time.Duration) err } } -// pollForFileExistance waits for a file to exist, up until a timeout. 
It returns an error if the -// timeout is hit -func pollForFileExistance(target string, timeout, pollInterval time.Duration) error { - deadline := time.After(timeout) - - for { - // Check if the file exists, if it doesn't keep looping - if _, err := os.Stat(target); err == nil { - return nil - } - - // Check for timeout before sleeping for the polling interval - select { - case <-deadline: - return errTimedOutPolling - default: - time.Sleep(pollInterval) - } - } -} - // https://stackoverflow.com/a/30708914/30548878 func isDirEmpty(name string) (bool, error) { f, err := os.Open(name) if err != nil { return false, err } - defer f.Close() + defer f.Close() //nolint:errcheck _, err = f.Readdirnames(1) - if err == io.EOF { + if errors.Is(err, io.EOF) { return true, nil } return false, err @@ -155,7 +136,7 @@ func isDirEmpty(name string) (bool, error) { // Touch a file if it doesn't exist, otherwise wipes the contents of the file func touchFile(name string) error { // Ensure upstream directory exists for file - err := os.MkdirAll(filepath.Dir(name), 0o755) + err := os.MkdirAll(filepath.Dir(name), 0o644) if err != nil { return err } @@ -164,8 +145,7 @@ func touchFile(name string) error { if err != nil { return err } - f.Close() - return nil + return f.Close() } // setStatusReady ensures the ready files exist diff --git a/internal/runner/manager.go b/internal/runner/manager.go index eb6dbced..9c43cab0 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -356,6 +356,7 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { // Derive the runtime context from the manager's context runtimeContext, runtimeCancel := context.WithCancel(ctx) +commandSetup: cmd := exec.CommandContext(runtimeContext, pythonPath, args...) //nolint:gosec // expected subprocess launched with variable runner, err := m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, cogYaml.Concurrency.Max) @@ -374,12 +375,18 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { if !cp.HasCheckpoint() { err := cp.Checkpoint(ctx, cmd) - var fatalCheckpointErr *checkpointer.FatalCheckpointErr - if errors.As(err, &fatalCheckpointErr) { - return nil, fmt.Errorf("fatal error while trying to checkpoint: %w", err) + var FatalCheckpointError *checkpointer.FatalCheckpointError + // If we saw an error that would leave the runner unusable, turn off the + // checkpointer and recreate the command and runner + if errors.As(err, &FatalCheckpointError) { + // TODO: Is this bad? Should we just return the error back up? + // The main concern is what `runner.Config` does leaving artifacts + // between runs, although I think that should be fine? 
+ cp.Disable() + goto commandSetup } // If the error is not fatal, we failed to create a checkpoint but are still - // running the original cog process, so we can just continue as if we did + // running the cog process successfully, so we can just continue as if we did // nothing } diff --git a/python/coglet/__main__.py b/python/coglet/__main__.py index b057d5d5..d705a853 100644 --- a/python/coglet/__main__.py +++ b/python/coglet/__main__.py @@ -54,7 +54,7 @@ def main() -> int: parser = argparse.ArgumentParser() parser.add_argument('--name', metavar='NAME', required=True, help='name') group = parser.add_mutually_exclusive_group() - group.add_argument('--ipc-url', metavar='URL', required=True, help='IPC URL') + group.add_argument('--ipc-url', metavar='URL', help='IPC URL') group.add_argument('--signal_mode', action='store_true') parser.add_argument( '--working-dir', metavar='DIR', required=True, help='working directory' From 636fdd87aab7b2f9d744c0ffb86cc2fc5bc35af0 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 7 Oct 2025 14:47:48 -0700 Subject: [PATCH 09/39] nolint for the checkpointer file --- internal/checkpointer/checkpointer.go | 15 ++++++++++----- internal/checkpointer/utils.go | 4 ++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 011bb15e..fe1a0018 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -1,3 +1,8 @@ +// There are some commands in here that are susceptible to injection. However, cog +// is a vehicle to let people run their own code... so why go through the hassle of +// injection? Cog is not run with any more permissions than the user code. +// +//nolint:gosec package checkpointer import ( @@ -113,7 +118,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro return errNoCheckpointDir } - err := os.MkdirAll(filepath.Join(c.checkpointDir, checkpointSubdirName), 0o666) //nolint:gosec // coglet needs to write here + err := os.MkdirAll(filepath.Join(c.checkpointDir, checkpointSubdirName), 0o666) if err != nil { return err } @@ -137,7 +142,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro cudaCmd := strings.TrimSpace(string(data)) // Write said command to a file for later - err = os.WriteFile(filepath.Join(c.checkpointDir, cudaCmdFileName), []byte(cudaCmd), 0o644) //nolint:gosec + err = os.WriteFile(filepath.Join(c.checkpointDir, cudaCmdFileName), []byte(cudaCmd), 0o644) if err != nil { return err } @@ -200,7 +205,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con if err != nil { // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one - restoreCmd.Process.Kill() //nolint:errcheck + restoreCmd.Process.Kill() //nolint:errcheck // This is just best effort return err } @@ -210,7 +215,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con if err := cmd.Run(); err != nil { // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one - restoreCmd.Process.Kill() //nolint:errcheck + restoreCmd.Process.Kill() //nolint:errcheck // This is just best effort return err } @@ -219,7 +224,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con if err != nil { // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one - 
restoreCmd.Process.Kill() //nolint:errcheck + restoreCmd.Process.Kill() //nolint:errcheck // This is just best effort return err } diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go index 05541563..ff36148a 100644 --- a/internal/checkpointer/utils.go +++ b/internal/checkpointer/utils.go @@ -47,7 +47,7 @@ func downloadFile(url, path string) error { if err != nil { return fmt.Errorf("failed to touch file: %w", err) } - defer file.Close() //nolint: errcheck + defer file.Close() //nolint:errcheck // nothing to do with this error _, err = io.Copy(file, resp.Body) if err != nil { @@ -124,7 +124,7 @@ func isDirEmpty(name string) (bool, error) { if err != nil { return false, err } - defer f.Close() //nolint:errcheck + defer f.Close() //nolint:errcheck // nothing to do with this error _, err = f.Readdirnames(1) if errors.Is(err, io.EOF) { From 3075bad133b4ca9ae92ce0b36b95c6881ecccbb1 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 7 Oct 2025 15:02:12 -0700 Subject: [PATCH 10/39] Lint fixes part 2 --- internal/checkpointer/checkpointer.go | 8 ++++---- internal/checkpointer/utils.go | 5 +++++ internal/runner/manager.go | 13 ++++++++++--- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index fe1a0018..6c6839f2 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -2,7 +2,7 @@ // is a vehicle to let people run their own code... so why go through the hassle of // injection? Cog is not run with any more permissions than the user code. // -//nolint:gosec +//nolint:gosec // See above package checkpointer import ( @@ -148,7 +148,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro } // Toggle CUDA off - cmd := exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) + cmd := exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) if err := cmd.Run(); err != nil { return err } @@ -159,7 +159,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro // Try to toggle CUDA back on. If we aren't able to restart CUDA, the process // will hang indefinitely, so we should kill it and try to start a new one // without checkpointing - cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) + cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) if cudaErr := cmd.Run(); cudaErr != nil { // Return a fatal error so upstream knows we cannot continue in the current state return &FatalCheckpointError{ @@ -173,7 +173,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro // Toggle CUDA back on. If we aren't able to restart CUDA, the process // will hang indefinitely, so we should kill it and try to start a new // one without checkpointing - cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) + cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) if err := cmd.Run(); err != nil { // Return a fatal error so upstream knows we cannot continue in the current state return &FatalCheckpointError{ diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go index ff36148a..0ea76f59 100644 --- a/internal/checkpointer/utils.go +++ b/internal/checkpointer/utils.go @@ -1,3 +1,8 @@ +// There are some commands in here that are susceptible to injection. 
However, cog +// is a vehicle to let people run their own code... so why go through the hassle of +// injection? Cog is not run with any more permissions than the user code. +// +//nolint:gosec // See above package checkpointer import ( diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 9c43cab0..d6787aec 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -341,9 +341,7 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { // If there is an existing checkpoint, try to restore from the checkpoint if cp.HasCheckpoint() { runner, err := m.startRunnerFromCheckpoint(ctx, env, runnerCtx, cogYaml.Concurrency.Max, cp) - if err != nil { - cp.Disable() - } else { + if err == nil { m.runners[0] = runner m.monitoringWG.Go(func() { m.monitorRunnerSubprocess(m.ctx, DefaultRunnerName, runner) @@ -351,6 +349,8 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { return runner, nil } + // If the error was non-nil, disable the checkpointer and continue + cp.Disable() } // Derive the runtime context from the manager's context @@ -359,6 +359,9 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { commandSetup: cmd := exec.CommandContext(runtimeContext, pythonPath, args...) //nolint:gosec // expected subprocess launched with variable runner, err := m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, cogYaml.Concurrency.Max) + if err != nil { + return nil, fmt.Errorf("failed to set up runner: %w", err) + } if err := runner.Config(ctx); err != nil { if stopErr := runner.Stop(); stopErr != nil { @@ -417,10 +420,14 @@ func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, r cmd, callback, err := cp.Restore(runtimeContext) if err != nil { + runtimeCancel() return nil, err } runner, err := m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, maxConcurrency) + if err != nil { + return nil, fmt.Errorf("failed to set up runner: %w", err) + } err = callback(runtimeContext) if err != nil { From 4aaf4bb530b637b525991e9c446a97fee34a0035 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 7 Oct 2025 16:02:35 -0700 Subject: [PATCH 11/39] Write ready file when ready --- internal/checkpointer/checkpointer.go | 20 ++++++++++---------- internal/checkpointer/utils.go | 4 ++-- internal/runner/manager.go | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 6c6839f2..999478d8 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -52,6 +52,7 @@ type Checkpointer interface { Prepare(ctx context.Context) error Checkpoint(ctx context.Context, cmd *exec.Cmd) error Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error) + WriteReadyFile() error } type checkpointer struct { @@ -181,7 +182,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro } } - return setStatusReady() + return nil } func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error) { @@ -220,15 +221,6 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con return err } - err = setStatusReady() - if err != nil { - // If this command failed, we want to best effort try to kill the started process, - // since we'll start a new one - restoreCmd.Process.Kill() //nolint:errcheck // This is just best effort - - return err - } - return nil } 
@@ -236,6 +228,14 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con return restoreCmd, callback, nil } +func (c *checkpointer) WriteReadyFile() error { + // If it isn't expected, make this a no-op + if os.Getenv(shouldCheckpointEnvVar) != "true" { + return nil + } + return writeCudaReadyFile() +} + func downloadCUDACheckpointBinaries(ctx context.Context) error { location := os.Getenv("R8_LOCATION") diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go index 0ea76f59..d2c85dd1 100644 --- a/internal/checkpointer/utils.go +++ b/internal/checkpointer/utils.go @@ -153,8 +153,8 @@ func touchFile(name string) error { return f.Close() } -// setStatusReady ensures the ready files exist -func setStatusReady() error { +// writeCudaReadyFile ensures the ready files exist +func writeCudaReadyFile() error { cudaReadyFilePath := os.Getenv(cudaReadyFileEnvVar) // Touch CUDA ready file diff --git a/internal/runner/manager.go b/internal/runner/manager.go index d6787aec..9ac387e5 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -347,7 +347,7 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { m.monitorRunnerSubprocess(m.ctx, DefaultRunnerName, runner) }) - return runner, nil + return runner, cp.WriteReadyFile() } // If the error was non-nil, disable the checkpointer and continue cp.Disable() @@ -393,7 +393,7 @@ commandSetup: // nothing } - return runner, nil + return runner, cp.WriteReadyFile() } func (m *Manager) setupRunner(runtimeContext context.Context, runtimeCancel context.CancelFunc, cmd *exec.Cmd, env []string, runnerCtx RunnerContext, maxConcurrency int) (*Runner, error) { From d44231ea21dbc138ca7fbc02b905dd157a596d4e Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Tue, 7 Oct 2025 16:21:22 -0700 Subject: [PATCH 12/39] Send ready signal --- internal/checkpointer/checkpointer.go | 8 ++++++-- internal/runner/manager.go | 9 ++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 999478d8..1193ccc7 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -50,7 +50,7 @@ type Checkpointer interface { Disable() HasCheckpoint() bool Prepare(ctx context.Context) error - Checkpoint(ctx context.Context, cmd *exec.Cmd) error + Checkpoint(ctx context.Context, cmd *exec.Cmd, waitFunc func() error) error Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error) WriteReadyFile() error } @@ -110,7 +110,7 @@ func (c *checkpointer) Prepare(ctx context.Context) error { return nil } -func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) error { +func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, waitFunc func() error) error { if !c.enabled { return nil } @@ -119,6 +119,10 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd) erro return errNoCheckpointDir } + if err := waitFunc(); err != nil { + return err + } + err := os.MkdirAll(filepath.Join(c.checkpointDir, checkpointSubdirName), 0o666) if err != nil { return err diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 9ac387e5..4246b28f 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -377,7 +377,7 @@ commandSetup: }) if !cp.HasCheckpoint() { - err := cp.Checkpoint(ctx, cmd) + err = cp.Checkpoint(ctx, cmd, func() error { return waitForRunnerSetup(ctx, runner) }) var 
FatalCheckpointError *checkpointer.FatalCheckpointError // If we saw an error that would leave the runner unusable, turn off the // checkpointer and recreate the command and runner @@ -434,9 +434,12 @@ func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, r return nil, fmt.Errorf("failed callback function: %w", err) } - // TODO: Send ready signal somehow, can we SIGHUP ourselves? + // We checkpointed the model after it ran setup, so we need to manually send the ready signal + // to the runner. We can do this by sending the SigReady signal to the current PID, as signal + // mode should be on if the checkpoint exists + err = syscall.Kill(syscall.Getpid(), SigReady) - return runner, nil + return runner, err } // allocatePrediction reserves a slot in the runner for the prediction From ecbc90b8f5c5155f5c00a4bfdef01c56493b09e1 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Wed, 8 Oct 2025 16:54:51 -0700 Subject: [PATCH 13/39] Standardize flag casing --- python/coglet/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/coglet/__main__.py b/python/coglet/__main__.py index d705a853..7499f422 100644 --- a/python/coglet/__main__.py +++ b/python/coglet/__main__.py @@ -55,7 +55,7 @@ def main() -> int: parser.add_argument('--name', metavar='NAME', required=True, help='name') group = parser.add_mutually_exclusive_group() group.add_argument('--ipc-url', metavar='URL', help='IPC URL') - group.add_argument('--signal_mode', action='store_true') + group.add_argument('--signal-mode', action='store_true') parser.add_argument( '--working-dir', metavar='DIR', required=True, help='working directory' ) From 55073f334933dbe23eade4d2e4919d8e313e4535 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 12:16:53 -0700 Subject: [PATCH 14/39] Testing --- cmd/cog/main.go | 6 ++++++ internal/service/service.go | 16 ++++++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/cmd/cog/main.go b/cmd/cog/main.go index 90f4109c..7ceb533b 100644 --- a/cmd/cog/main.go +++ b/cmd/cog/main.go @@ -115,15 +115,21 @@ func (s *ServerCmd) Run() error { // Create service with base logger svc := service.New(cfg, baseLogger) + log.Infow("created service") + // Create root context for the entire service ctx, cancel := context.WithCancel(context.Background()) defer cancel() + log.Infow("created context") + // Initialize service components if err := svc.Initialize(ctx); err != nil { return err } + log.Infow("initialized service") + return svc.Run(ctx) } diff --git a/internal/service/service.go b/internal/service/service.go index b9caa02d..8dcd854e 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -103,6 +103,12 @@ func (s *Service) Initialize(ctx context.Context) error { return err } + if s.cfg.SignalMode { + // This runs an infinite loop for handling signals, so we explicitly + // do not want to put it in a wait group of any kind + go s.handler.HandleSignals() + } + return nil } @@ -150,6 +156,8 @@ func (s *Service) initializeHTTPServer(ctx context.Context) error { func (s *Service) Run(ctx context.Context) error { log := s.logger.Sugar() + log.Infow("started running") + select { case <-s.started: log.Errorw("service already started") @@ -157,6 +165,8 @@ func (s *Service) Run(ctx context.Context) error { default: } + log.Infow("channel did not return error") + if s.httpServer == nil { return fmt.Errorf("service not initialized - call Initialize() first") } @@ -173,12 +183,6 @@ func (s *Service) Run(ctx context.Context) 
error { return fmt.Errorf("failed to start handler: %w", err) } - if s.cfg.SignalMode { - // This runs an infinite loop for handling signals, so we explicitly - // do not want to put it in a wait group of any kind - go s.handler.HandleSignals() - } - eg.Go(func() error { log.Info("starting HTTP server") if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { From f14484b908a0414da68034aeff71b58588e78b5a Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 12:39:39 -0700 Subject: [PATCH 15/39] Testing further --- internal/service/service.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/internal/service/service.go b/internal/service/service.go index 8dcd854e..1337b356 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -95,6 +95,12 @@ func (s *Service) Initialize(ctx context.Context) error { s.cfg.ForceShutdown = s.forceShutdown } + if s.cfg.SignalMode { + // This runs an infinite loop for handling signals, so we explicitly + // do not want to put it in a wait group of any kind + go s.handler.HandleSignals() + } + if err := s.initializeHandler(ctx); err != nil { return err } @@ -103,12 +109,6 @@ func (s *Service) Initialize(ctx context.Context) error { return err } - if s.cfg.SignalMode { - // This runs an infinite loop for handling signals, so we explicitly - // do not want to put it in a wait group of any kind - go s.handler.HandleSignals() - } - return nil } From 99c0bb0060172c23a693bdae386d1c15aad7788e Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 13:12:16 -0700 Subject: [PATCH 16/39] Ordering --- internal/runner/manager.go | 22 ++++++++++++++++++++++ internal/server/server.go | 18 ------------------ internal/service/service.go | 6 ------ 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 4246b28f..c71fc3d8 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -8,6 +8,7 @@ import ( "io/fs" "os" "os/exec" + "os/signal" "runtime" "sync" "syscall" @@ -291,6 +292,10 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { if m.cfg.SignalMode { args = append(args, "--signal-mode") + // Make sure the signal handling is running + // This runs an infinite loop for handling signals, so we explicitly + // do not want to put it in a wait group of any kind + go m.HandleSignals() } else { args = append(args, "--ipc-url", m.cfg.IPCUrl) } @@ -1025,6 +1030,23 @@ func (m *Manager) HandleRunnerSignal(runnerName string, signal os.Signal) error return runner.HandleSignal(signal) } +func (m *Manager) HandleSignals() { + log := m.logger.Sugar() + + ch := make(chan os.Signal, 1) + signal.Notify(ch, SigOutput, SigReady, SigBusy) + + for { + s := <-ch + err := m.HandleRunnerSignal(DefaultRunnerName, s) + if err != nil { + log.Errorw("failed to handle IPC", "signal", s, "error", err) + // TODO: What do we do with this error? Put it on some error chan + // and ship it somewhere? 
+ } + } +} + func (m *Manager) cleanupInProgress() bool { if !m.cfg.OneShot { return false diff --git a/internal/server/server.go b/internal/server/server.go index 502912d8..2f0796ac 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -9,7 +9,6 @@ import ( "io" "net/http" "os" - "os/signal" "path" "path/filepath" "sync/atomic" @@ -156,23 +155,6 @@ func (h *Handler) Stop() error { return nil } -func (h *Handler) HandleSignals() { - log := h.logger.Sugar() - - ch := make(chan os.Signal, 1) - signal.Notify(ch, SigOutput, SigReady, SigBusy) - - for { - s := <-ch - err := h.runnerManager.HandleRunnerSignal(runner.DefaultRunnerName, s) - if err != nil { - log.Errorw("failed to handle IPC", "signal", s, "error", err) - // TODO: What do we do with this error? Put it on some error chan - // and ship it somewhere? - } - } -} - func (h *Handler) HandleIPC(w http.ResponseWriter, r *http.Request) { log := h.logger.Sugar() diff --git a/internal/service/service.go b/internal/service/service.go index 1337b356..96b3cee0 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -95,12 +95,6 @@ func (s *Service) Initialize(ctx context.Context) error { s.cfg.ForceShutdown = s.forceShutdown } - if s.cfg.SignalMode { - // This runs an infinite loop for handling signals, so we explicitly - // do not want to put it in a wait group of any kind - go s.handler.HandleSignals() - } - if err := s.initializeHandler(ctx); err != nil { return err } From c0b9ff4974c95627dfddff6e3e5c3643e0ff502b Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 13:17:20 -0700 Subject: [PATCH 17/39] No shadowing --- internal/runner/manager.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index c71fc3d8..36d66641 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -1022,12 +1022,12 @@ func (m *Manager) HandleRunnerIPC(runnerName, status string) error { return runner.HandleIPC(status) } -func (m *Manager) HandleRunnerSignal(runnerName string, signal os.Signal) error { +func (m *Manager) HandleRunnerSignal(runnerName string, s os.Signal) error { runner, _, exists := m.findRunner(runnerName) if !exists { return fmt.Errorf("%w: %s", ErrRunnerNotFound, runnerName) } - return runner.HandleSignal(signal) + return runner.HandleSignal(s) } func (m *Manager) HandleSignals() { From ae1163db3b94700e9784f30ed941dcfee289e0e7 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 14:11:52 -0700 Subject: [PATCH 18/39] Close TCP connections --- cmd/cog/main.go | 6 ------ internal/checkpointer/checkpointer.go | 4 ++-- internal/runner/runner.go | 16 ---------------- internal/service/service.go | 4 ---- 4 files changed, 2 insertions(+), 28 deletions(-) diff --git a/cmd/cog/main.go b/cmd/cog/main.go index 7ceb533b..90f4109c 100644 --- a/cmd/cog/main.go +++ b/cmd/cog/main.go @@ -115,21 +115,15 @@ func (s *ServerCmd) Run() error { // Create service with base logger svc := service.New(cfg, baseLogger) - log.Infow("created service") - // Create root context for the entire service ctx, cancel := context.WithCancel(context.Background()) defer cancel() - log.Infow("created context") - // Initialize service components if err := svc.Initialize(ctx); err != nil { return err } - log.Infow("initialized service") - return svc.Run(ctx) } diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 1193ccc7..a6e156a1 100644 --- a/internal/checkpointer/checkpointer.go +++ 
b/internal/checkpointer/checkpointer.go @@ -159,7 +159,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, wait } // CRIU checkpoint (leaving process running) - cmd = exec.CommandContext(ctx, criuPath, "dump", "--leave-running", "--shell-job", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) + cmd = exec.CommandContext(ctx, criuPath, "dump", "--leave-running", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) if err := cmd.Run(); err != nil { // Try to toggle CUDA back on. If we aren't able to restart CUDA, the process // will hang indefinitely, so we should kill it and try to start a new one @@ -201,7 +201,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con } // Set up restore command - restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) + restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) // Set up callback function once restore is started callback := func(con context.Context) error { diff --git a/internal/runner/runner.go b/internal/runner/runner.go index b717e212..df353c2d 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -875,22 +875,6 @@ func (r *Runner) predict(reqID string) (chan PredictionResponse, *PredictionResp log.Tracew("wrote prediction request file", "prediction_id", reqID, "path", requestPath, "working_dir", r.runnerCtx.workingdir, "request_data", string(requestData)) - // Debug: Check if file actually exists and list directory contents - if _, err := os.Stat(requestPath); err != nil { - log.Tracew("ERROR: written request file does not exist", "prediction_id", reqID, "path", requestPath, "error", err) - } else { - log.Tracew("confirmed request file exists", "prediction_id", reqID, "path", requestPath) - } - - // Debug: List all files in working directory - if entries, err := os.ReadDir(r.runnerCtx.workingdir); err == nil { - fileNames := make([]string, len(entries)) - for i, entry := range entries { - fileNames[i] = entry.Name() - } - log.Tracew("working directory contents after write", "prediction_id", reqID, "working_dir", r.runnerCtx.workingdir, "files", fileNames) - } - log.Tracew("returning prediction channel", "prediction_id", reqID) initialResponse := &PredictionResponse{ Status: PredictionStarting, diff --git a/internal/service/service.go b/internal/service/service.go index 96b3cee0..810b2666 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -150,8 +150,6 @@ func (s *Service) initializeHTTPServer(ctx context.Context) error { func (s *Service) Run(ctx context.Context) error { log := s.logger.Sugar() - log.Infow("started running") - select { case <-s.started: log.Errorw("service already started") @@ -159,8 +157,6 @@ func (s *Service) Run(ctx context.Context) error { default: } - log.Infow("channel did not return error") - if s.httpServer == nil { return fmt.Errorf("service not initialized - call Initialize() first") } From 87eeac8fb00413f4342bf50dcc0ad72cb6c5cdf6 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 14:47:22 -0700 Subject: [PATCH 19/39] Correct perms, as well as some testing --- internal/checkpointer/checkpointer.go | 10 ++++++++++ internal/checkpointer/utils.go | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git 
a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index a6e156a1..6da6b78c 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -154,17 +154,23 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, wait // Toggle CUDA off cmd := exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { return err } // CRIU checkpoint (leaving process running) cmd = exec.CommandContext(ctx, criuPath, "dump", "--leave-running", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { // Try to toggle CUDA back on. If we aren't able to restart CUDA, the process // will hang indefinitely, so we should kill it and try to start a new one // without checkpointing cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout if cudaErr := cmd.Run(); cudaErr != nil { // Return a fatal error so upstream knows we cannot continue in the current state return &FatalCheckpointError{ @@ -179,6 +185,8 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, wait // will hang indefinitely, so we should kill it and try to start a new // one without checkpointing cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { // Return a fatal error so upstream knows we cannot continue in the current state return &FatalCheckpointError{ @@ -217,6 +225,8 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con // Toggle CUDA on for the restored process cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go index d2c85dd1..2c2b7ce7 100644 --- a/internal/checkpointer/utils.go +++ b/internal/checkpointer/utils.go @@ -70,7 +70,7 @@ func downloadAndChmod(url, path string) error { return err } - if err := os.Chmod(path, 0o600); err != nil { + if err := os.Chmod(path, 0o700); err != nil { return fmt.Errorf("failed to chmod file: %w", err) } return nil From 96666c3f03efefebbc22ca819b18279b505789c1 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 16:04:27 -0700 Subject: [PATCH 20/39] Remove test logging --- internal/checkpointer/checkpointer.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 6da6b78c..a6e156a1 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -154,23 +154,17 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, wait // Toggle CUDA off cmd := exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { return err } // CRIU checkpoint (leaving process running) cmd = exec.CommandContext(ctx, criuPath, "dump", "--leave-running", "--shell-job", "--tcp-close", "--images-dir", 
filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { // Try to toggle CUDA back on. If we aren't able to restart CUDA, the process // will hang indefinitely, so we should kill it and try to start a new one // without checkpointing cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout if cudaErr := cmd.Run(); cudaErr != nil { // Return a fatal error so upstream knows we cannot continue in the current state return &FatalCheckpointError{ @@ -185,8 +179,6 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, wait // will hang indefinitely, so we should kill it and try to start a new // one without checkpointing cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { // Return a fatal error so upstream knows we cannot continue in the current state return &FatalCheckpointError{ @@ -225,8 +217,6 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con // Toggle CUDA on for the restored process cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) - cmd.Stderr = os.Stderr - cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one From c10f2bad3685a935a5db9c221768b2847e7ddffd Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 16:51:14 -0700 Subject: [PATCH 21/39] Comments --- internal/runner/manager.go | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 36d66641..c468b6a1 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -376,11 +376,6 @@ commandSetup: return nil, fmt.Errorf("failed to config runner: %w", err) } - m.runners[0] = runner - m.monitoringWG.Go(func() { - m.monitorRunnerSubprocess(m.ctx, DefaultRunnerName, runner) - }) - if !cp.HasCheckpoint() { err = cp.Checkpoint(ctx, cmd, func() error { return waitForRunnerSetup(ctx, runner) }) var FatalCheckpointError *checkpointer.FatalCheckpointError @@ -398,6 +393,11 @@ commandSetup: // nothing } + m.runners[0] = runner + m.monitoringWG.Go(func() { + m.monitorRunnerSubprocess(m.ctx, DefaultRunnerName, runner) + }) + return runner, cp.WriteReadyFile() } @@ -423,7 +423,7 @@ func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, r // Derive the runtime context from the manager's context runtimeContext, runtimeCancel := context.WithCancel(ctx) - cmd, callback, err := cp.Restore(runtimeContext) + cmd, postSetupCallback, err := cp.Restore(runtimeContext) if err != nil { runtimeCancel() return nil, err @@ -434,7 +434,7 @@ func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, r return nil, fmt.Errorf("failed to set up runner: %w", err) } - err = callback(runtimeContext) + err = postSetupCallback(runtimeContext) if err != nil { return nil, fmt.Errorf("failed callback function: %w", err) } @@ -1030,19 +1030,23 @@ func (m *Manager) HandleRunnerSignal(runnerName string, s os.Signal) error { return runner.HandleSignal(s) } -func (m *Manager) HandleSignals() { +func (m *Manager) HandleSignals(ctx context.Context) { log := m.logger.Sugar() ch := make(chan os.Signal, 1) signal.Notify(ch, 
SigOutput, SigReady, SigBusy) for { - s := <-ch - err := m.HandleRunnerSignal(DefaultRunnerName, s) - if err != nil { - log.Errorw("failed to handle IPC", "signal", s, "error", err) - // TODO: What do we do with this error? Put it on some error chan - // and ship it somewhere? + select { + case s := <-ch: + err := m.HandleRunnerSignal(DefaultRunnerName, s) + if err != nil { + log.Errorw("failed to handle IPC", "signal", s, "error", err) + // TODO: What do we do with this error? Put it on some error chan + // and ship it somewhere? + } + case <-ctx.Done(): + return } } } From 056b883857ef8935b7b028cbd431ebc50da841d6 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 16:52:04 -0700 Subject: [PATCH 22/39] Comments --- internal/checkpointer/utils.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go index 2c2b7ce7..822de32c 100644 --- a/internal/checkpointer/utils.go +++ b/internal/checkpointer/utils.go @@ -92,6 +92,7 @@ func downloadAndUntar(ctx context.Context, url, path string) error { if err != nil { return err } + defer devnull.Close() cmd.Stdout = devnull cmd.Stderr = devnull From 4a15dcdec823ae6125cec33d55ada40a8d2752da Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Thu, 9 Oct 2025 17:01:16 -0700 Subject: [PATCH 23/39] Pass context to function --- internal/runner/manager.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index c468b6a1..451936bd 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -295,7 +295,7 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { // Make sure the signal handling is running // This runs an infinite loop for handling signals, so we explicitly // do not want to put it in a wait group of any kind - go m.HandleSignals() + go m.HandleSignals(m.ctx) } else { args = append(args, "--ipc-url", m.cfg.IPCUrl) } From 64171ab682d7f50ed96cb55c06c79d2ecc44098e Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Fri, 10 Oct 2025 12:06:22 -0700 Subject: [PATCH 24/39] Linter --- internal/checkpointer/utils.go | 2 +- internal/runner/manager.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/checkpointer/utils.go b/internal/checkpointer/utils.go index 822de32c..a63b5a9d 100644 --- a/internal/checkpointer/utils.go +++ b/internal/checkpointer/utils.go @@ -92,7 +92,7 @@ func downloadAndUntar(ctx context.Context, url, path string) error { if err != nil { return err } - defer devnull.Close() + defer devnull.Close() //nolint:errcheck // What would we do with this error anyways cmd.Stdout = devnull cmd.Stderr = devnull diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 451936bd..fe387e94 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -295,7 +295,7 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { // Make sure the signal handling is running // This runs an infinite loop for handling signals, so we explicitly // do not want to put it in a wait group of any kind - go m.HandleSignals(m.ctx) + go m.HandleSignals(m.ctx) //nolint:contextcheck // We want this to live for the lifetime of the manager } else { args = append(args, "--ipc-url", m.cfg.IPCUrl) } From 788ccbe65bfbe616f6296cd7b8e98d592bfc1510 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Fri, 10 Oct 2025 13:09:36 -0700 Subject: [PATCH 25/39] Testing --- internal/runner/manager.go | 2 ++ 1 file changed, 2 
insertions(+) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index fe387e94..1374f857 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -845,7 +845,9 @@ func (m *Manager) Runners() []*Runner { // findRunner finds a runner by name in the slice func (m *Manager) findRunner(name string) (*Runner, int, bool) { + log := m.logger.Sugar() for i, runner := range m.runners { + log.Warnw("Runner", "name", runner.runnerCtx.id) if runner != nil && runner.runnerCtx.id == name { return runner, i, true } From 7dccc148a875591c82d86fee8c37f3c15a187775 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Fri, 10 Oct 2025 13:16:07 -0700 Subject: [PATCH 26/39] Testing --- internal/runner/manager.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 1374f857..85dd828f 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -847,7 +847,9 @@ func (m *Manager) Runners() []*Runner { func (m *Manager) findRunner(name string) (*Runner, int, bool) { log := m.logger.Sugar() for i, runner := range m.runners { - log.Warnw("Runner", "name", runner.runnerCtx.id) + if runner != nil { + log.Warnw("Runner", "name", runner.runnerCtx.id) + } if runner != nil && runner.runnerCtx.id == name { return runner, i, true } From 74a613412515dca50095a19ab09355f90e1a0aee Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Fri, 10 Oct 2025 13:37:14 -0700 Subject: [PATCH 27/39] Move assignment earlier --- internal/runner/manager.go | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 85dd828f..bc3cadb3 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -376,6 +376,8 @@ commandSetup: return nil, fmt.Errorf("failed to config runner: %w", err) } + m.runners[0] = runner + if !cp.HasCheckpoint() { err = cp.Checkpoint(ctx, cmd, func() error { return waitForRunnerSetup(ctx, runner) }) var FatalCheckpointError *checkpointer.FatalCheckpointError @@ -393,7 +395,6 @@ commandSetup: // nothing } - m.runners[0] = runner m.monitoringWG.Go(func() { m.monitorRunnerSubprocess(m.ctx, DefaultRunnerName, runner) }) @@ -845,11 +846,7 @@ func (m *Manager) Runners() []*Runner { // findRunner finds a runner by name in the slice func (m *Manager) findRunner(name string) (*Runner, int, bool) { - log := m.logger.Sugar() for i, runner := range m.runners { - if runner != nil { - log.Warnw("Runner", "name", runner.runnerCtx.id) - } if runner != nil && runner.runnerCtx.id == name { return runner, i, true } From bb7a47e6c7c16efa28753620cfd1a933850eeda3 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Fri, 10 Oct 2025 15:55:29 -0700 Subject: [PATCH 28/39] Testing --- internal/checkpointer/checkpointer.go | 8 +++++++- internal/runner/manager.go | 6 +++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index a6e156a1..d87132b4 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -15,6 +15,8 @@ import ( "strconv" "strings" "time" + + "github.com/replicate/cog-runtime/internal/logging" ) const ( @@ -60,13 +62,15 @@ type checkpointer struct { hasCheckpoint bool checkpointDir string leaseFile string + log *logging.SugaredLogger } -func NewCheckpointer(ctx context.Context) Checkpointer { +func NewCheckpointer(ctx context.Context, log *logging.SugaredLogger) Checkpointer { return &checkpointer{ 
enabled: os.Getenv(shouldCheckpointEnvVar) == "true", checkpointDir: os.Getenv(cudaCheckpointDirEnvVar), leaseFile: os.Getenv(leaseFileEnvVar), + log: log, } } @@ -208,6 +212,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con // Get the PID for the command cudaPID, err := exec.CommandContext(con, "pgrep", "-fx", string(cudaCmd)).Output() if err != nil { + c.log.Errorw("failed to pgrep the CUDA command", "error", err) // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one restoreCmd.Process.Kill() //nolint:errcheck // This is just best effort @@ -218,6 +223,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con // Toggle CUDA on for the restored process cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) if err := cmd.Run(); err != nil { + c.log.Errorw("failed to toggle CUDA on", "error", err) // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one restoreCmd.Process.Kill() //nolint:errcheck // This is just best effort diff --git a/internal/runner/manager.go b/internal/runner/manager.go index bc3cadb3..f5ff3c3a 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -301,7 +301,7 @@ func (m *Manager) createDefaultRunner(ctx context.Context) (*Runner, error) { } // This returns an object that does nothing if it is not enabled. - cp := checkpointer.NewCheckpointer(ctx) + cp := checkpointer.NewCheckpointer(ctx, m.logger.Sugar()) err := cp.Prepare(ctx) if err != nil { cp.Disable() @@ -432,6 +432,7 @@ func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, r runner, err := m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, maxConcurrency) if err != nil { + m.logger.Sugar().Errorw("failed to set up runner", "error", err) return nil, fmt.Errorf("failed to set up runner: %w", err) } @@ -444,6 +445,9 @@ func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, r // to the runner. We can do this by sending the SigReady signal to the current PID, as signal // mode should be on if the checkpoint exists err = syscall.Kill(syscall.Getpid(), SigReady) + if err != nil { + m.logger.Sugar().Errorw("failed to send SIGUSR1", "error", err) + } return runner, err } From a1366c4065bd58f774dc8ee64e4f3f525c0aef7b Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 09:34:50 -0700 Subject: [PATCH 29/39] Do we need shell-job? --- internal/checkpointer/checkpointer.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index d87132b4..064fc92d 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -163,7 +163,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, wait } // CRIU checkpoint (leaving process running) - cmd = exec.CommandContext(ctx, criuPath, "dump", "--leave-running", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) + cmd = exec.CommandContext(ctx, criuPath, "dump", "--leave-running", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) if err := cmd.Run(); err != nil { // Try to toggle CUDA back on. 
If we aren't able to restart CUDA, the process // will hang indefinitely, so we should kill it and try to start a new one @@ -205,7 +205,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con } // Set up restore command - restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) + restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) // Set up callback function once restore is started callback := func(con context.Context) error { From f85fd2e72ad03f58332859401198ccf5e10c0187 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 13:45:16 -0700 Subject: [PATCH 30/39] Remove stuff only needed for non-cog runtime --- internal/checkpointer/checkpointer.go | 54 ++++++++++----------------- 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 064fc92d..ee4b8a6d 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -34,7 +34,6 @@ const ( criuPath = "/tmp/criu" // Metadata storage paths - cudaCmdFileName = "cuda-cmd" checkpointSubdirName = "checkpoint" ) @@ -142,20 +141,6 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, wait cudaPID := strings.TrimSpace(string(cudaPIDBytes)) - // Get the command for this PID - it is _not_ always the root python process - data, err := exec.CommandContext(ctx, "ps", "-o", "cmd=", cudaPID).Output() - if err != nil { - return err - } - - cudaCmd := strings.TrimSpace(string(data)) - - // Write said command to a file for later - err = os.WriteFile(filepath.Join(c.checkpointDir, cudaCmdFileName), []byte(cudaCmd), 0o644) - if err != nil { - return err - } - // Toggle CUDA off cmd := exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID) if err := cmd.Run(); err != nil { @@ -198,35 +183,18 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con return nil, nil, nil } - // Read process from sentinel file - cudaCmd, err := os.ReadFile(filepath.Join(c.checkpointDir, cudaCmdFileName)) - if err != nil { - return nil, nil, err - } - // Set up restore command restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) // Set up callback function once restore is started callback := func(con context.Context) error { - // Get the PID for the command - cudaPID, err := exec.CommandContext(con, "pgrep", "-fx", string(cudaCmd)).Output() - if err != nil { - c.log.Errorw("failed to pgrep the CUDA command", "error", err) - // If this command failed, we want to best effort try to kill the started process, - // since we'll start a new one - restoreCmd.Process.Kill() //nolint:errcheck // This is just best effort - - return err - } - // Toggle CUDA on for the restored process - cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", string(cudaPID)) + cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", strconv.Itoa(restoreCmd.Process.Pid)) if err := cmd.Run(); err != nil { c.log.Errorw("failed to toggle CUDA on", "error", err) // If this command failed, we want to best effort try to kill the started process, // since we'll start a new one - restoreCmd.Process.Kill() //nolint:errcheck // This is just best effort + 
killProcess(restoreCmd) //nolint:errcheck // This is just best effort return err } @@ -238,6 +206,24 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con return restoreCmd, callback, nil } +func killProcess(cmd *exec.Cmd) error { + err := cmd.Process.Kill() + if err != nil { + return err + } + + // Wait for the process to exit with a 5 second timeout + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + + select { + case err = <-done: + return err + case <-time.After(5 * time.Second): + return nil + } +} + func (c *checkpointer) WriteReadyFile() error { // If it isn't expected, make this a no-op if os.Getenv(shouldCheckpointEnvVar) != "true" { From df9a17f33dd9f28cafcfc9e671aa8262c5351d53 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 15:02:21 -0700 Subject: [PATCH 31/39] Testing --- internal/runner/manager.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index f5ff3c3a..2f4c9a23 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -436,6 +436,8 @@ func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, r return nil, fmt.Errorf("failed to set up runner: %w", err) } + time.Sleep(10 * time.Minute) + err = postSetupCallback(runtimeContext) if err != nil { return nil, fmt.Errorf("failed callback function: %w", err) From 3b8fcfaa8336eed48e0b16f803abd347f3ae557b Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 15:05:10 -0700 Subject: [PATCH 32/39] Return shell job to command --- internal/checkpointer/checkpointer.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index ee4b8a6d..95982be0 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -148,7 +148,7 @@ func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, wait } // CRIU checkpoint (leaving process running) - cmd = exec.CommandContext(ctx, criuPath, "dump", "--leave-running", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) + cmd = exec.CommandContext(ctx, criuPath, "dump", "--shell-job", "--leave-running", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName), "--tree", pid) if err := cmd.Run(); err != nil { // Try to toggle CUDA back on. 
If we aren't able to restart CUDA, the process // will hang indefinitely, so we should kill it and try to start a new one @@ -184,7 +184,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con } // Set up restore command - restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) + restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) // Set up callback function once restore is started callback := func(con context.Context) error { From a79c0f556d8a4a87f574be0bb1cbe4c497e13122 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 16:00:02 -0700 Subject: [PATCH 33/39] Testing v2 --- internal/runner/manager.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 2f4c9a23..3548dbed 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -430,14 +430,14 @@ func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, r return nil, err } + time.Sleep(10 * time.Minute) + runner, err := m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, maxConcurrency) if err != nil { m.logger.Sugar().Errorw("failed to set up runner", "error", err) return nil, fmt.Errorf("failed to set up runner: %w", err) } - time.Sleep(10 * time.Minute) - err = postSetupCallback(runtimeContext) if err != nil { return nil, fmt.Errorf("failed callback function: %w", err) From bc1e6083e451017e8dd403ff850ee1bfa67b4d3a Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 17:02:33 -0700 Subject: [PATCH 34/39] Remove testing sleep --- internal/runner/manager.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/internal/runner/manager.go b/internal/runner/manager.go index 3548dbed..f5ff3c3a 100644 --- a/internal/runner/manager.go +++ b/internal/runner/manager.go @@ -430,8 +430,6 @@ func (m *Manager) startRunnerFromCheckpoint(ctx context.Context, env []string, r return nil, err } - time.Sleep(10 * time.Minute) - runner, err := m.setupRunner(runtimeContext, runtimeCancel, cmd, env, runnerCtx, maxConcurrency) if err != nil { m.logger.Sugar().Errorw("failed to set up runner", "error", err) From 9058499d3009f518d62e025efebe435974525d5c Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 17:10:52 -0700 Subject: [PATCH 35/39] Add verbose logging --- internal/checkpointer/checkpointer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 95982be0..643f067d 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -184,7 +184,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con } // Set up restore command - restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) + restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "-v4", "--log-file", "restore.log", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) // Set up callback function once restore is started callback := func(con context.Context) error { From 32c2e81fc46b922b7d7aba4e4e85b26f05002f37 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 17:40:17 -0700 
Subject: [PATCH 36/39] Different verbose logging cmd --- internal/checkpointer/checkpointer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 643f067d..1a62fad6 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -184,7 +184,7 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con } // Set up restore command - restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "-v4", "--log-file", "restore.log", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) + restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "-v4", "-o", "/tmp/restore.log", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) // Set up callback function once restore is started callback := func(con context.Context) error { From d77288b3982e4924df45f866a6306bc30eb90df6 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 17:47:41 -0700 Subject: [PATCH 37/39] Print logs from cuda checkpoint --- internal/checkpointer/checkpointer.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 1a62fad6..7ec7ad10 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -190,6 +190,8 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con callback := func(con context.Context) error { // Toggle CUDA on for the restored process cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", strconv.Itoa(restoreCmd.Process.Pid)) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { c.log.Errorw("failed to toggle CUDA on", "error", err) // If this command failed, we want to best effort try to kill the started process, From f0ebc3721ab2f11c407fb33cc761e56705dfd79b Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 18:06:58 -0700 Subject: [PATCH 38/39] More logging --- internal/checkpointer/checkpointer.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 7ec7ad10..5ecee259 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -184,10 +184,16 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con } // Set up restore command - restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "-v4", "-o", "/tmp/restore.log", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) + restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName)) // Set up callback function once restore is started callback := func(con context.Context) error { + out, err := exec.CommandContext(con, "ps", "aux").Output() + if err != nil { + fmt.Println(err.Error()) + } + fmt.Println(out) + fmt.Println(strconv.Itoa(restoreCmd.Process.Pid)) // Toggle CUDA on for the restored process cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", strconv.Itoa(restoreCmd.Process.Pid)) cmd.Stdout = os.Stdout From 44d0a1e8cb36a99645b4295a421c13858b098b6d Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Mon, 13 Oct 2025 18:09:35 -0700 Subject: [PATCH 39/39] log 
instead --- internal/checkpointer/checkpointer.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/checkpointer/checkpointer.go b/internal/checkpointer/checkpointer.go index 5ecee259..e850789a 100644 --- a/internal/checkpointer/checkpointer.go +++ b/internal/checkpointer/checkpointer.go @@ -190,10 +190,10 @@ func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Con callback := func(con context.Context) error { out, err := exec.CommandContext(con, "ps", "aux").Output() if err != nil { - fmt.Println(err.Error()) + c.log.Infow(err.Error()) } - fmt.Println(out) - fmt.Println(strconv.Itoa(restoreCmd.Process.Pid)) + c.log.Infow(string(out)) + c.log.Infow(strconv.Itoa(restoreCmd.Process.Pid)) // Toggle CUDA on for the restored process cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", strconv.Itoa(restoreCmd.Process.Pid)) cmd.Stdout = os.Stdout
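
For context on how the checkpoint-restore patches above fit together: after a successful CRIU restore the coglet process never re-runs setup, so the manager has to deliver the readiness notification to itself. A condensed sketch of that assumed path, using only names that already appear in internal/runner/manager.go (this is an illustration of the flow across the later patches, not code from any single commit):

    // startRunnerFromCheckpoint raises the ready signal against the
    // manager's own PID, since the restored runner will not send one:
    if err := syscall.Kill(syscall.Getpid(), SigReady); err != nil {
        // surfaced to the caller; without it the runner would never report READY
        return runner, err
    }

    // The HandleSignals goroutine (started in createDefaultRunner when
    // SignalMode is set) receives it via signal.Notify and forwards it:
    //   s := <-ch
    //   m.HandleRunnerSignal(DefaultRunnerName, s) // -> runner.HandleSignal -> status READY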