diff --git a/slurm_usage.py b/slurm_usage.py
index bf1b6de..ce4cf34 100755
--- a/slurm_usage.py
+++ b/slurm_usage.py
@@ -31,7 +31,7 @@
 from datetime import datetime, timedelta, timezone
 from getpass import getuser
 from pathlib import Path
-from typing import Annotated, Any, NamedTuple
+from typing import Annotated, Any, Callable, Literal, Mapping, NamedTuple, cast
 
 import polars as pl
 import typer
@@ -290,7 +290,7 @@ def run_squeue() -> CommandResult:
         CommandResult with stdout, stderr, and return code
 
     """
-    cmd = ["squeue", "-ro", "%u/%t/%D/%P/%C/%N/%h"]
+    cmd = ["squeue", "-ro", "%u/%t/%D/%P/%C/%N/%h/%m"]
     return _run(cmd)
 
 
@@ -930,11 +930,12 @@ class SlurmJob(NamedTuple):
     cores: int
     node: str
     oversubscribe: str
+    memory_mb: float
 
     @classmethod
     def from_line(cls, line: str) -> SlurmJob:
         """Create a SlurmJob from a squeue output line."""
-        user, status, nnodes, partition, cores, node, oversubscribe = line.split("/")
+        user, status, nnodes, partition, cores, node, oversubscribe, memory = line.split("/")
         return cls(
             user,
             status,
@@ -943,6 +944,7 @@ def from_line(cls, line: str) -> SlurmJob:
             int(cores),
             node,
             oversubscribe,
+            _parse_memory_mb(memory),
         )
 
 
@@ -968,53 +970,78 @@ def get_total_cores(node_name: str) -> int:
     return 0  # Return 0 if not found
 
 
+ResourceMetric = Literal["cores", "nodes", "memory"]
+STATUS_ROUNDING_EPSILON = 1e-9
+RESOURCE_OPTION = cast(
+    list[str] | None,
+    typer.Option(
+        None,
+        "--resource",
+        "-r",
+        help="Resources to summarize (repeatable). Choose from cores, nodes, memory.",
+    ),
+)
+
+
+class ResourceAggregation(NamedTuple):
+    """Aggregated resource counts for users, partitions, and overall totals."""
+
+    per_user: defaultdict[str, defaultdict[str, defaultdict[str, float]]]
+    per_partition: defaultdict[str, defaultdict[str, float]]
+    totals: defaultdict[str, float]
+
+
 def process_data(
     output: list[SlurmJob],
-    cores_or_nodes: str,
-) -> tuple[
-    defaultdict[str, defaultdict[str, defaultdict[str, int]]],
-    defaultdict[str, defaultdict[str, int]],
-    defaultdict[str, int],
-]:
-    """Process SLURM job data and aggregate statistics."""
-    data: defaultdict[str, defaultdict[str, defaultdict[str, int]]] = defaultdict(
-        lambda: defaultdict(lambda: defaultdict(int)),
+    metric: ResourceMetric,
+) -> ResourceAggregation:
+    """Process SLURM job data and aggregate statistics by resource metric."""
+    data: defaultdict[str, defaultdict[str, defaultdict[str, float]]] = defaultdict(
+        lambda: defaultdict(lambda: defaultdict(float)),
     )
-    total_partition: defaultdict[str, defaultdict[str, int]] = defaultdict(
-        lambda: defaultdict(int),
+    total_partition: defaultdict[str, defaultdict[str, float]] = defaultdict(
+        lambda: defaultdict(float),
     )
-    totals: defaultdict[str, int] = defaultdict(int)
+    totals: defaultdict[str, float] = defaultdict(float)
 
-    # Track which nodes have been counted for each user
+    # Track which nodes have been counted for each user when resources are exclusive
    counted_nodes: defaultdict[str, set[str]] = defaultdict(set)
 
     for s in output:
-        if s.oversubscribe in ["NO", "USER"]:
+        if metric == "memory":
+            value = s.memory_mb / 1024 if s.memory_mb > 0 else 0.0  # Convert to GB
+        elif s.oversubscribe in ["NO", "USER"]:
             if s.node not in counted_nodes[s.user]:
-                n = get_total_cores(s.node)  # Get total cores in the node
-                # Mark this node as counted for this user
+                value = float(get_total_cores(s.node))
                 counted_nodes[s.user].add(s.node)
             else:
-                continue  # Skip this job to prevent double-counting
+                continue  # Skip to prevent double-counting exclusive nodes
         else:
-            n = s.nnodes if cores_or_nodes == "nodes" else s.cores
+            value = float(s.nnodes) if metric == "nodes" else float(s.cores)
 
-        # Update the data structures with the correct values
-        data[s.user][s.partition][s.status] += n
-        total_partition[s.partition][s.status] += n
-        totals[s.status] += n
+        data[s.user][s.partition][s.status] += value
+        total_partition[s.partition][s.status] += value
+        totals[s.status] += value
 
-    return data, total_partition, totals
+    return ResourceAggregation(data, total_partition, totals)
 
 
-def summarize_status(d: dict[str, int]) -> str:
+def summarize_status(d: Mapping[str, float], formatter: Callable[[float], str] | None = None) -> str:
     """Summarize status dictionary into a readable string."""
-    return " / ".join([f"{status}={n}" for status, n in d.items()])
+    def _format(value: float) -> str:
+        if formatter is not None:
+            return formatter(value)
+        if abs(value - round(value)) < STATUS_ROUNDING_EPSILON:
+            return str(int(round(value)))
+        return f"{value:.2f}".rstrip("0").rstrip(".")
+
+    return " / ".join([f"{status}={_format(n)}" for status, n in d.items()])
 
 
-def combine_statuses(d: dict[str, Any]) -> dict[str, int]:
+
+def combine_statuses(d: dict[str, Any]) -> dict[str, float]:
     """Combine multiple status dictionaries into one."""
-    tot: defaultdict[str, int] = defaultdict(int)
+    tot: defaultdict[str, float] = defaultdict(float)
     for dct in d.values():
         for status, n in dct.items():
             tot[status] += n
@@ -2626,25 +2653,64 @@ def status(
     console.print(f"\n[bold]Disk Usage:[/bold] {total_size / (1024**2):.1f} MB")
 
 
+def _resource_formatter(metric: ResourceMetric) -> Callable[[float], str]:
+    if metric == "memory":
+        return lambda value: f"{value:.1f} GB"
+    return lambda value: str(int(round(value)))
+
+
 @app.command()
-def current() -> None:
+def current(
+    resources: list[str] | None = RESOURCE_OPTION,
+) -> None:
     """Display current cluster usage statistics from squeue."""
     output = squeue_output()
     me = getuser()
-    for which in ["cores", "nodes"]:
-        data, total_partition, totals = process_data(output, which)
+
+    if isinstance(resources, list):
+        resources_list: list[str] | None = resources if resources else None
+    else:
+        resources_list = None
+
+    allowed_metrics: tuple[ResourceMetric, ...] = ("cores", "nodes", "memory")
+    seen: set[str] = set()
+    ordered_resources: list[ResourceMetric] = []
+
+    if resources_list is None:
+        ordered_resources = ["cores", "nodes"]
+    else:
+        for entry in resources_list:
+            normalized = entry.lower()
+            if normalized not in allowed_metrics:
+                error_message = "Invalid resource '{entry}'. Choose from cores, nodes, memory."
+                raise typer.BadParameter(error_message.format(entry=entry))
+            if normalized in seen:
+                continue
+            seen.add(normalized)
+            ordered_resources.append(cast(ResourceMetric, normalized))
+
+    if not ordered_resources:
+        ordered_resources = ["cores", "nodes"]
+
+    for which in ordered_resources:
+        aggregated = process_data(output, which)
+        data = aggregated.per_user
+        total_partition = aggregated.per_partition
+        totals = aggregated.totals
+        formatter = _resource_formatter(which)
         table = Table(title=f"SLURM statistics [b]{which}[/]", show_footer=True)
         partitions = sorted(total_partition.keys())
         table.add_column("User", f"{len(data)} users", style="cyan")
         for partition in partitions:
-            tot = summarize_status(total_partition[partition])
+            tot = summarize_status(total_partition[partition], formatter)
             table.add_column(partition, tot, style="magenta")
-        table.add_column("Total", summarize_status(totals), style="magenta")
+        table.add_column("Total", summarize_status(totals, formatter), style="magenta")
         for user, _stats in sorted(data.items()):
             kw = {"style": "bold italic"} if user == me else {}
-            partition_stats = [summarize_status(_stats[p]) if p in _stats else "-" for p in partitions]
-            table.add_row(user, *partition_stats, summarize_status(combine_statuses(_stats)), **kw)
+            partition_stats = [summarize_status(_stats[p], formatter) if p in _stats else "-" for p in partitions]
+            total_summary = summarize_status(combine_statuses(_stats), formatter)
+            table.add_row(user, *partition_stats, total_summary, **kw)
         console.print(table, justify="center")
diff --git a/tests/snapshots/command_map.json b/tests/snapshots/command_map.json
index 6392ba9..ed4040b 100644
--- a/tests/snapshots/command_map.json
+++ b/tests/snapshots/command_map.json
@@ -1,5 +1,5 @@
 {
-  "squeue -ro %u/%t/%D/%P/%C/%N/%h": "squeue",
+  "squeue -ro %u/%t/%D/%P/%C/%N/%h/%m": "squeue",
   "sinfo -h -N --format='%N,%c'": "sinfo_cpus",
   "sinfo -h -N --format='%N,%G'": "sinfo_gpus",
   "sacct -a -S 2025-08-21T00:00:00 -E 2025-08-21T23:59:59 --format=JobID,JobIDRaw,JobName,User,UID,Group,GID,Account,Partition,QOS,State,ExitCode,Submit,Eligible,Start,End,Elapsed,ElapsedRaw,CPUTime,CPUTimeRAW,TotalCPU,UserCPU,SystemCPU,AllocCPUS,AllocNodes,NodeList,ReqCPUS,ReqMem,ReqNodes,Timelimit,TimelimitRaw,MaxRSS,MaxVMSize,MaxDiskRead,MaxDiskWrite,AveRSS,AveCPU,AveVMSize,ConsumedEnergy,ConsumedEnergyRaw,Priority,Reservation,ReservationId,WorkDir,Cluster,ReqTRES,AllocTRES,Comment,Constraints,Container,DerivedExitCode,Flags,Layout,MaxRSSNode,MaxVMSizeNode,MinCPU,NCPUS,NNodes,NTasks,Reason,SubmitLine -P -n": "sacct_day_0",
diff --git a/tests/snapshots/squeue_output.txt b/tests/snapshots/squeue_output.txt
index 16598dc..507a593 100644
--- a/tests/snapshots/squeue_output.txt
+++ b/tests/snapshots/squeue_output.txt
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:55a30b738d8d3d4e694a77390200cf59ec310cbba4aca249bf63472e72cad6fc
-size 10591
+oid sha256:ce806e379dfd5f8cf025f57b483f4a226b04f3b5ca66d747609b9c7ea6d4d464
+size 479
diff --git a/tests/test_cli_commands.py b/tests/test_cli_commands.py
index b6551c0..9c4b5d3 100644
--- a/tests/test_cli_commands.py
+++ b/tests/test_cli_commands.py
@@ -163,7 +163,10 @@
     def test_process_data(self) -> None:
        """Test process_data function."""
         jobs = slurm_usage.squeue_output()
-        data, total_partition, totals = slurm_usage.process_data(jobs, "cores")
+        aggregated = slurm_usage.process_data(jobs, "cores")
+        data = aggregated.per_user
+        total_partition = aggregated.per_partition
+        totals = aggregated.totals
 
         assert isinstance(data, dict)
         assert isinstance(total_partition, dict)
diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py
index 33bb4fb..2f54540 100644
--- a/tests/test_data_processing.py
+++ b/tests/test_data_processing.py
@@ -324,7 +324,7 @@ class TestSqueueParsing:
     def test_slurm_job_from_line(self) -> None:
         """Test creating SlurmJob from squeue line."""
-        line = "alice/R/1/partition-01/4/node-001/OK"
+        line = "alice/R/1/partition-01/4/node-001/OK/8G"
         job = slurm_usage.SlurmJob.from_line(line)
 
         assert job.user == "alice"
@@ -335,6 +335,7 @@ def test_slurm_job_from_line(self) -> None:
         assert job.cores == expected_cores
         assert job.node == "node-001"
         assert job.oversubscribe == "OK"
+        assert job.memory_mb == pytest.approx(8192.0)
 
     def test_squeue_output(self) -> None:
         """Test parsing full squeue output."""
diff --git a/tests/test_slurm_commands.py b/tests/test_slurm_commands.py
index 1329339..9b9e131 100644
--- a/tests/test_slurm_commands.py
+++ b/tests/test_slurm_commands.py
@@ -30,8 +30,8 @@ def test_run_squeue(self) -> None:
         """Test squeue command."""
         result = slurm_usage.run_squeue()
         assert result.returncode == 0
-        assert "USER/ST/NODES/PARTITION" in result.stdout
-        assert result.command == "squeue -ro %u/%t/%D/%P/%C/%N/%h"
+        assert "USER/ST/NODES/PARTITION/CPUS/NODELIST/OVER_SUBSCRIBE/MEMORY" in result.stdout
+        assert result.command == "squeue -ro %u/%t/%D/%P/%C/%N/%h/%m"
 
         # Parse the output
         lines = result.stdout.strip().split("\n")
@@ -40,7 +40,7 @@
         # Check first data line format
         if len(lines) > 1:
             parts = lines[1].split("/")
-            expected_parts = 7  # user/status/nodes/partition/cpus/nodelist/oversubscribe
+            expected_parts = 8  # user/status/nodes/partition/cpus/nodelist/oversubscribe/memory
             assert len(parts) == expected_parts
 
     def test_run_sinfo_cpus(self) -> None:
diff --git a/tests/test_slurm_usage.py b/tests/test_slurm_usage.py
index 28368f7..d4b3623 100644
--- a/tests/test_slurm_usage.py
+++ b/tests/test_slurm_usage.py
@@ -28,10 +28,10 @@
 )
 
 # Mock squeue output
-squeue_mock_output = """USER/ST/NODES/PARTITION
-bas.nijholt/PD/1/mypartition-10
-bas.nijholt/PD/1/mypartition-20
-bas.nijholt/PD/1/mypartition-20"""
+squeue_mock_output = """USER/ST/NODES/PARTITION/CPUS/NODELIST/OVER_SUBSCRIBE/MEMORY
+bas.nijholt/PD/1/mypartition-10/8/node-001/OK/8000M
+bas.nijholt/PD/1/mypartition-20/16/node-002/OK/16000M
+bas.nijholt/PD/1/mypartition-20/16/node-003/OK/16000M"""
 
 
 @pytest.fixture
@@ -63,11 +63,13 @@ def test_process_data() -> None:
 
     # Create proper SlurmJob objects instead of strings
     output = [
-        SlurmJob("user1", "R", 2, "partition1", 10, "node1", "YES"),
-        SlurmJob("user2", "PD", 1, "partition2", 5, "node2", "YES"),
-        SlurmJob("user1", "PD", 1, "partition1", 5, "node3", "YES"),
+        SlurmJob("user1", "R", 2, "partition1", 10, "node1", "YES", 2048.0),
+        SlurmJob("user2", "PD", 1, "partition2", 5, "node2", "YES", 1024.0),
+        SlurmJob("user1", "PD", 1, "partition1", 5, "node3", "YES", 1024.0),
     ]
-    data, total_partition, totals = process_data(output, "nodes")
+    aggregated_nodes = process_data(output, "nodes")
+    data = aggregated_nodes.per_user
+    totals = aggregated_nodes.totals
     expected_r_count = 2
     expected_pd_single = 1
     expected_pd_total = 2
@@ -76,6 +78,11 @@
     assert totals["PD"] == expected_pd_total
     assert totals["R"] == expected_r_count
 
+    aggregated_memory = process_data(output, "memory")
+    totals_memory = aggregated_memory.totals
+    assert totals_memory["R"] == pytest.approx(2.0)
+    assert totals_memory["PD"] == pytest.approx(2.0)
+
 
 def test_summarize_status() -> None:
     """Test summarize_status function."""