
Commit a8bb462

jeffkbkim authored and facebook-github-bot committed
Cleaner MetricModule interface (#3554)
Summary: Utilities for CPUOffloadedRecMetricModule and RecMetricModule. Also raise exceptions in the main thread if any of the background threads raise. Added unit tests.

Simplify the core metric types:

- MetricsResult = Dict[str, MetricValue]: synchronous metrics computation
- MetricsFuture = concurrent.futures.Future[MetricsResult]: asynchronous computation
- MetricsOutput = Union[MetricsResult, MetricsFuture]: either a MetricsResult or a MetricsFuture

Introduce a metrics_output_util to handle the logic between futures and dicts. Users can schedule callbacks via `on_metrics_ready()`.

Introduce a `device` argument to the RecMetricModule constructor. It is a no-op for the standard metric module, but CPUOffloadedRecMetricModule requires it to determine whether to perform GPU-to-CPU transfers.

Differential Revision: D87110900
1 parent 32e5431 commit a8bb462
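
For readers skimming the diffs below: the new aliases make the sync/async split explicit, and the commit message's metrics_output_util is meant to hide the dict-vs-future branching. Its actual API is not shown on this page, so the helper below is only a hypothetical sketch of the idea; the aliases themselves match the ones added to metric_module.py.

```python
import concurrent.futures
from typing import Callable, Dict, Union

import torch

# Aliases as added to torchrec/metrics/metric_module.py in this commit.
MetricValue = Union[torch.Tensor, float]
MetricsResult = Dict[str, MetricValue]
MetricsFuture = concurrent.futures.Future  # i.e. Future[MetricsResult]
MetricsOutput = Union[MetricsResult, MetricsFuture]


def on_metrics_ready_sketch(
    output: MetricsOutput, callback: Callable[[MetricsResult], None]
) -> None:
    """Hypothetical stand-in for the utility's on_metrics_ready() behavior."""
    if isinstance(output, concurrent.futures.Future):
        # Async case (CPUOffloadedRecMetricModule): fire once the future resolves.
        output.add_done_callback(lambda fut: callback(fut.result()))
    else:
        # Sync case (RecMetricModule.compute() returned a plain dict).
        callback(output)
```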

8 files changed: +452 −77 lines changed

torchrec/metrics/cpu_comms_metric_module.py

Lines changed: 7 additions & 6 deletions

@@ -9,6 +9,8 @@
 import logging
 from typing import Any, cast, Dict

+import torch
+
 from torch import nn

 from torch.profiler import record_function
@@ -48,7 +50,8 @@ def __init__(
         """
         All arguments are the same as RecMetricModule
         """
-
+        # Ensure device is set to CPU
+        kwargs["device"] = torch.device("cpu")
         super().__init__(*args, **kwargs)

         rec_metrics_clone = self._clone_rec_metrics()
@@ -106,9 +109,6 @@ def _load_metric_states(
         Uses aggregated states.
         """

-        # All update() calls were done prior. Clear previous computed state.
-        # Otherwise, we get warnings that compute() was called before
-        # update() which is not the case.
         computation = cast(RecMetricComputation, computation)
         set_update_called(computation)
         computation._computed = None
@@ -157,8 +157,9 @@ def _clone_rec_metrics(self) -> RecMetricList:

 def set_update_called(computation: RecMetricComputation) -> None:
     """
-    Set _update_called to True for RecMetricComputation.
-    This is a workaround for torchmetrics 1.0.3+.
+    All update() calls were done prior. Clear previous computed state.
+    Otherwise, we get warnings that compute() was called before
+    update() which is not the case.
     """
     try:
         computation._update_called = True
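
For context on the docstring that moved here: when aggregated states are loaded, every update() has already happened on the worker side, so the computation is marked as updated and its cached result cleared before compute() runs. A minimal illustrative sketch of that pattern; a stand-in class is used here, while the real code operates on torchmetrics-backed RecMetricComputation objects.

```python
from typing import Any, Optional


class ComputationStub:
    """Stand-in for the torchmetrics-style fields touched by _load_metric_states."""

    def __init__(self) -> None:
        self._update_called = False          # torchmetrics tracks whether update() ran
        self._computed: Optional[Any] = 0.1  # stale result cached by a prior compute()


def load_aggregated_state(computation: ComputationStub) -> None:
    # All update() calls were done prior (on the update thread), so mark the
    # computation as updated to avoid the "compute() called before update()"
    # warning from torchmetrics ...
    computation._update_called = True
    # ... and drop the stale cached value so the next compute() recomputes from
    # the freshly loaded, aggregated states.
    computation._computed = None
```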

torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 48 additions & 36 deletions

@@ -22,7 +22,7 @@
     MetricUpdateJob,
     SynchronizationMarker,
 )
-from torchrec.metrics.metric_module import MetricValue, RecMetricModule
+from torchrec.metrics.metric_module import MetricsFuture, MetricsResult, RecMetricModule
 from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
 from torchrec.metrics.model_utils import parse_task_model_outputs
 from torchrec.metrics.rec_metric import RecMetricException
@@ -74,7 +74,9 @@ def __init__(
             - compute_queue_size: Maximum size of the update queue. Default is 100.
         """
         super().__init__(*args, **kwargs)
-        self._shutdown_event = threading.Event()
+        self._shutdown_event: threading.Event = threading.Event()
+        self._captured_exception_event: threading.Event = threading.Event()
+        self._captured_exception: Optional[Exception] = None

         self.update_queue: queue.Queue[
             Union[MetricUpdateJob, SynchronizationMarker]
@@ -132,8 +134,16 @@ def _update_rec_metrics(
         if self._shutdown_event.is_set():
             raise RecMetricException("metric processor thread is shut down.")

+        if self._captured_exception_event.is_set():
+            assert self._captured_exception is not None
+            raise self._captured_exception
+
         try:
-            cpu_model_out, transfer_completed_event = self._transfer_to_cpu(model_out)
+            cpu_model_out, transfer_completed_event = (
+                self._transfer_to_cpu(model_out)
+                if self.device == torch.device("cuda")
+                else (model_out, None)
+            )
             self.update_queue.put_nowait(
                 MetricUpdateJob(
                     model_out=cpu_model_out,
@@ -191,31 +201,25 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None
         """

         with record_function("## CPUOffloadedRecMetricModule:update ##"):
-            try:
+            if metric_update_job.transfer_completed_event is not None:
                 metric_update_job.transfer_completed_event.synchronize()
-                labels, predictions, weights, required_inputs = (
-                    parse_task_model_outputs(
-                        self.rec_tasks,
-                        metric_update_job.model_out,
-                        self.get_required_inputs(),
-                    )
-                )
-                if required_inputs:
-                    metric_update_job.kwargs["required_inputs"] = required_inputs
-
-                self.rec_metrics.update(
-                    predictions=predictions,
-                    labels=labels,
-                    weights=weights,
-                    **metric_update_job.kwargs,
-                )
-
-                if self.throughput_metric:
-                    self.throughput_metric.update()
+            labels, predictions, weights, required_inputs = parse_task_model_outputs(
+                self.rec_tasks,
+                metric_update_job.model_out,
+                self.get_required_inputs(),
+            )
+            if required_inputs:
+                metric_update_job.kwargs["required_inputs"] = required_inputs
+
+            self.rec_metrics.update(
+                predictions=predictions,
+                labels=labels,
+                weights=weights,
+                **metric_update_job.kwargs,
+            )

-            except Exception as e:
-                logger.exception("Error processing metric update: %s", e)
-                raise e
+            if self.throughput_metric:
+                self.throughput_metric.update()

     @override
     def shutdown(self) -> None:
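
The removal of the per-job try/except above pairs with the new behavior in the background loops (see the hunks that follow): exceptions are captured on the worker thread and re-raised on the main thread at the next update or compute call. A minimal, self-contained sketch of that pattern, not the torchrec code itself:

```python
import threading
from typing import Optional


class BackgroundWorker:
    """Sketch of the capture-and-re-raise pattern used by the metric loops."""

    def __init__(self) -> None:
        self._captured_exception: Optional[Exception] = None
        self._captured_exception_event = threading.Event()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def _loop(self) -> None:
        try:
            raise RuntimeError("boom in background thread")  # stand-in for _do_work()
        except Exception as e:
            # Record the failure so the next call on the main thread can surface it.
            self._captured_exception = e
            self._captured_exception_event.set()
            raise

    def submit(self) -> None:
        # Mirrors _update_rec_metrics()/async_compute(): fail fast on the caller side.
        if self._captured_exception_event.is_set():
            assert self._captured_exception is not None
            raise self._captured_exception


worker = BackgroundWorker()
worker._thread.join()  # let the background failure land before we submit
try:
    worker.submit()
except RuntimeError as e:
    print(f"re-raised in main thread: {e}")
```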
@@ -248,30 +252,34 @@ def shutdown(self) -> None:
         logger.info("CPUOffloadedRecMetricModule has been successfully shutdown.")

     @override
-    def compute(self) -> Dict[str, MetricValue]:
+    def compute(self) -> MetricsResult:
         raise RecMetricException(
-            "compute() is not supported in CPUOffloadedRecMetricModule. Use async_compute() instead."
+            "CPUOffloadedRecMetricModule does not support compute(). Use async_compute() instead."
         )

     @override
-    def async_compute(
-        self, future: concurrent.futures.Future[Dict[str, MetricValue]]
-    ) -> None:
+    def async_compute(self) -> MetricsFuture:
         """
         Entry point for asynchronous metric compute. It enqueues a synchronization marker
         to the update queue.

-        Args:
+        Returns:
             future: Pre-created future where the computed metrics will be set.
         """
+        metrics_future = concurrent.futures.Future()
         if self._shutdown_event.is_set():
-            future.set_exception(
+            metrics_future.set_exception(
                 RecMetricException("metric processor thread is shut down.")
             )
-            return
+            return metrics_future
+
+        if self._captured_exception_event.is_set():
+            assert self._captured_exception is not None
+            raise self._captured_exception

-        self.update_queue.put_nowait(SynchronizationMarker(future))
+        self.update_queue.put_nowait(SynchronizationMarker(metrics_future))
         self.update_queue_size_logger.add(self.update_queue.qsize())
+        return metrics_future

     def _process_synchronization_marker(
         self, synchronization_marker: SynchronizationMarker
@@ -304,7 +312,7 @@ def _process_synchronization_marker(

     def _process_metric_compute_job(
         self, metric_compute_job: MetricComputeJob
-    ) -> Dict[str, MetricValue]:
+    ) -> MetricsResult:
         """
         Process a metric compute job:
         1. Comms module performs all gather
@@ -355,6 +363,8 @@ def _update_loop(self) -> None:
                 self._do_work(self.update_queue)
             except Exception as e:
                 logger.exception(f"Exception in update loop: {e}")
+                self._captured_exception_event.set()
+                self._captured_exception = e
                 raise e

         remaining = self._flush_remaining_work(self.update_queue)
@@ -372,6 +382,8 @@ def _compute_loop(self) -> None:
                 self._do_work(self.compute_queue)
             except Exception as e:
                 logger.exception(f"Exception in compute loop: {e}")
+                self._captured_exception_event.set()
+                self._captured_exception = e
                 raise e

         remaining = self._flush_remaining_work(self.compute_queue)
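
With async_compute() now creating and returning the future itself, callers no longer pre-create one. Only the Future handling below is grounded in this diff; the stub stands in for a constructed CPUOffloadedRecMetricModule, whose real compute thread fulfills the future after processing the synchronization marker.

```python
import concurrent.futures
import threading
import time


def async_compute_stub() -> concurrent.futures.Future:
    """Illustrative stand-in for CPUOffloadedRecMetricModule.async_compute()."""
    future: concurrent.futures.Future = concurrent.futures.Future()

    def _background_compute() -> None:
        time.sleep(0.1)  # pretend queued updates and the all-gather happen here
        future.set_result({"dummy_metric": 0.5})  # illustrative key/value only

    threading.Thread(target=_background_compute, daemon=True).start()
    return future


metrics_future = async_compute_stub()

# Callers can attach a callback that fires once the metrics are ready ...
metrics_future.add_done_callback(lambda fut: print("metrics ready:", fut.result()))

# ... or block until the background thread fulfills (or fails) the future.
# Exceptions captured on the background loops surface here, or are re-raised
# directly by async_compute() on the next call.
print(metrics_future.result(timeout=5))
```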

torchrec/metrics/metric_job_types.py

Lines changed: 5 additions & 3 deletions

@@ -8,7 +8,7 @@
 # pyre-strict

 import concurrent
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 import torch
 from torchrec.metrics.metric_module import MetricValue
@@ -26,7 +26,7 @@ class MetricUpdateJob:
     def __init__(
         self,
         model_out: Dict[str, torch.Tensor],
-        transfer_completed_event: torch.cuda.Event,
+        transfer_completed_event: Optional[torch.cuda.Event],
         kwargs: Dict[str, Any],
     ) -> None:
         """
@@ -37,7 +37,9 @@
         """

         self.model_out: Dict[str, torch.Tensor] = model_out
-        self.transfer_completed_event: torch.cuda.Event = transfer_completed_event
+        self.transfer_completed_event: Optional[torch.cuda.Event] = (
+            transfer_completed_event
+        )
         self.kwargs: Dict[str, Any] = kwargs
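
Making transfer_completed_event optional lets the CPU path skip the CUDA event entirely, matching the _update_rec_metrics change above. A small hedged sketch of the two cases; the dict key and tensors are illustrative, and the GPU branch only runs when CUDA is available.

```python
import torch
from torchrec.metrics.metric_job_types import MetricUpdateJob

model_out = {"prediction": torch.rand(8)}  # illustrative key/tensor only

# CPU path: no device-to-host copy happened, so there is no CUDA event to wait on.
cpu_job = MetricUpdateJob(
    model_out=model_out,
    transfer_completed_event=None,
    kwargs={},
)

# GPU path: record an event after the non-blocking GPU-to-CPU copy; the update
# worker synchronizes on it before consuming the tensors.
if torch.cuda.is_available():
    transfer_done = torch.cuda.Event()
    transfer_done.record()
    gpu_job = MetricUpdateJob(
        model_out=model_out,
        transfer_completed_event=transfer_done,
        kwargs={},
    )
```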

torchrec/metrics/metric_module.py

Lines changed: 12 additions & 8 deletions

@@ -117,6 +117,9 @@


 MetricValue = Union[torch.Tensor, float]
+MetricsResult = Dict[str, MetricValue]
+MetricsFuture = concurrent.futures.Future[MetricsResult]
+MetricsOutput = Union[MetricsResult, MetricsFuture]


 class StateMetric(abc.ABC):
@@ -125,7 +128,7 @@ class StateMetric(abc.ABC):
     """

     @abc.abstractmethod
-    def get_metrics(self) -> Dict[str, MetricValue]:
+    def get_metrics(self) -> MetricsResult:
         pass


@@ -189,6 +192,7 @@ def __init__(
         self,
         batch_size: int,
         world_size: int,
+        device: torch.device,
         rec_tasks: Optional[List[RecTaskInfo]] = None,
         rec_metrics: Optional[RecMetricList] = None,
         throughput_metric: Optional[ThroughputMetric] = None,
@@ -205,6 +209,7 @@
         self.trained_batches: int = 0
         self.batch_size = batch_size
         self.world_size = world_size
+        self.device = device
         self.oom_count = 0
         self.compute_count = 0

@@ -315,12 +320,12 @@ def _adjust_compute_interval(self) -> None:
     def should_compute(self) -> bool:
         return self.trained_batches % self.compute_interval_steps == 0

-    def compute(self) -> Dict[str, MetricValue]:
+    def compute(self) -> MetricsResult:
         r"""compute() is called when the global metrics are required, usually
         right before logging the metrics results to the data sink.
         """
         self.compute_count += 1
-        ret: Dict[str, MetricValue] = {}
+        ret: MetricsResult = {}
         with record_function("## RecMetricModule:compute ##"):
             if self.rec_metrics:
                 self._adjust_compute_interval()
@@ -337,11 +342,11 @@ def compute(self) -> Dict[str, MetricValue]:
             )
         return ret

-    def local_compute(self) -> Dict[str, MetricValue]:
+    def local_compute(self) -> MetricsResult:
         r"""local_compute() is called when per-trainer metrics are required. It's
         can be used for debugging. Currently only rec_metrics is supported.
         """
-        ret: Dict[str, MetricValue] = {}
+        ret: MetricsResult = {}
         if self.rec_metrics:
             ret.update(self.rec_metrics.local_compute())
         return ret
@@ -492,9 +497,7 @@ def load_pre_compute_states(
     def shutdown(self) -> None:
         logger.info("Initiating graceful shutdown...")

-    def async_compute(
-        self, future: concurrent.futures.Future[Dict[str, MetricValue]]
-    ) -> None:
+    def async_compute(self) -> MetricsFuture:
         raise RecMetricException("async_compute is not supported in RecMetricModule")


@@ -610,6 +613,7 @@ def generate_metric_module(
     metrics = metric_class(
         batch_size=batch_size,
         world_size=world_size,
+        device=device,
         rec_tasks=metrics_config.rec_tasks,
         rec_metrics=rec_metrics,
         throughput_metric=throughput_metric,
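
The constructor change means callers now pass the device explicitly; generate_metric_module forwards it as shown above. A minimal hedged sketch of direct construction, with metric and task configuration omitted and the remaining constructor defaults assumed from the signature in the hunk above:

```python
import torch
from torchrec.metrics.metric_module import RecMetricModule

# For the plain RecMetricModule, `device` is stored but effectively a no-op;
# CPUOffloadedRecMetricModule uses it to decide whether a GPU-to-CPU transfer
# (and CUDA event synchronization) is needed before enqueuing metric updates.
metric_module = RecMetricModule(
    batch_size=256,
    world_size=1,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # rec_tasks / rec_metrics / throughput_metric are left at their None
    # defaults in this sketch; a real setup passes the configured metrics.
)
```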
