Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 71 additions & 77 deletions intermediate_source/per_sample_grads.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
"""
Per-sample-gradients
====================
표본별 변화도
==============

What is it?
-----------
표본별 변화도란
----------------

Per-sample-gradient computation is computing the gradient for each and every
sample in a batch of data. It is a useful quantity in differential privacy,
meta-learning, and optimization research.
표본별 변화도(per-sample-gradient) 계산은 데이터 배치에 있는 각 표본의 변화도를 하나씩
계산하는 작업입니다. 이는 차등 개인정보 보호(differential privacy), 메타 학습(meta-learning),
최적화 연구에서 유용하게 쓰이는 값입니다.

.. note::

This tutorial requires PyTorch 2.0.0 or later.
이 튜토리얼을 실행하려면 PyTorch 2.0.0 이상이 필요합니다.

"""

Expand All @@ -21,7 +21,7 @@
import torch.nn.functional as F
torch.manual_seed(0)

# Here's a simple CNN and loss function:
# 간단한 CNN과 손실 함수를 정의합니다.

class SimpleCNN(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -49,8 +49,8 @@ def loss_fn(predictions, targets):


######################################################################
# Let’s generate a batch of dummy data and pretend that we’re working with an MNIST dataset.
# The dummy images are 28 by 28 and we use a minibatch of size 64.
# 더미 데이터 배치를 만들고 MNIST 데이터셋으로 작업한다고 가정해 보겠습니다.
# 더미 이미지는 28 x 28 크기이며 크기가 64인 미니 배치를 사용합니다.

device = 'cuda'

Expand All @@ -61,25 +61,23 @@ def loss_fn(predictions, targets):
targets = torch.randint(10, (64,), device=device)

######################################################################
# In regular model training, one would forward the minibatch through the model,
# and then call .backward() to compute gradients. This would generate an
# 'average' gradient of the entire mini-batch:
# 일반적인 모델 학습에서는 미니 배치를 모델에 전달해 순전파를 수행한 다음 .backward()를
# 호출하여 변화도를 계산합니다. 그러면 전체 미니 배치에 대한 '평균' 변화도가 만들어집니다.

model = SimpleCNN().to(device=device)
predictions = model(data) # move the entire mini-batch through the model
predictions = model(data) # 전체 미니 배치를 모델에 전달합니다.

loss = loss_fn(predictions, targets)
loss.backward() # back propagate the 'average' gradient of this mini-batch
loss.backward() # 이 미니 배치의 '평균' 변화도를 역전파합니다.

######################################################################
# In contrast to the above approach, per-sample-gradient computation is
# equivalent to:
# 위 방식과 달리 표본별 변화도 계산은 다음 과정과 같습니다.
#
# - for each individual sample of the data, perform a forward and a backward
# pass to get an individual (per-sample) gradient.
# - 데이터의 각 표본에 대해 순전파와 역전파를 수행하여
# 개별 표본의 변화도, 즉 표본별 변화도를 얻습니다.

def compute_grad(sample, target):
sample = sample.unsqueeze(0) # prepend batch dimension for processing
sample = sample.unsqueeze(0) # 처리를 위해 배치 차원을 앞에 추가합니다.
target = target.unsqueeze(0)

prediction = model(sample)
Expand All @@ -89,7 +87,7 @@ def compute_grad(sample, target):


def compute_sample_grads(data, targets):
""" manually process each sample with per sample gradient """
"""각 표본을 직접 처리하여 표본별 변화도를 구합니다."""
sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]
sample_grads = zip(*sample_grads)
sample_grads = [torch.stack(shards) for shards in sample_grads]
Expand All @@ -98,41 +96,40 @@ def compute_sample_grads(data, targets):
per_sample_grads = compute_sample_grads(data, targets)

######################################################################
# ``sample_grads[0]`` is the per-sample-grad for model.conv1.weight.
# ``model.conv1.weight.shape`` is ``[32, 1, 3, 3]``; notice how there is one
# gradient, per sample, in the batch for a total of 64.
# ``sample_grads[0]`` model.conv1.weight에 대한 표본별 변화도입니다.
# ``model.conv1.weight.shape`` ``[32, 1, 3, 3]`` 입니다.
# 배치의 각 표본마다 변화도가 하나씩 있으므로 총 64개라는 점을 확인할 수 있습니다.

print(per_sample_grads[0].shape)

######################################################################
# Per-sample-grads, *the efficient way*, using function transforms
# ----------------------------------------------------------------
# We can compute per-sample-gradients efficiently by using function transforms.
# 함수 변환으로 표본별 변화도를 *효율적으로* 계산하기
# ------------------------------------------------------
# 함수 변환(function transform)을 사용하면 표본별 변화도를 효율적으로 계산할 수 있습니다.
#
# The ``torch.func`` function transform API transforms over functions.
# Our strategy is to define a function that computes the loss and then apply
# transforms to construct a function that computes per-sample-gradients.
# ``torch.func`` 함수 변환 API는 함수에 변환을 적용합니다.
# 여기서는 먼저 손실을 계산하는 함수를 정의한 다음
# 변환을 적용하여 표본별 변화도를 계산하는 함수를 구성합니다.
#
# We'll use the ``torch.func.functional_call`` function to treat an ``nn.Module``
# like a function.
# ``torch.func.functional_call`` 함수를 사용하여 ``nn.Module`` 을 함수처럼 다룹니다.
#
# First, let’s extract the state from ``model`` into two dictionaries,
# parameters and buffers. We'll be detaching them because we won't use
# regular PyTorch autograd (e.g. Tensor.backward(), torch.autograd.grad).
# 먼저 ``model`` 의 상태를 parameters와 buffers라는 두 딕셔너리로 추출합니다.
# 일반적인 PyTorch autograd(예: Tensor.backward(), torch.autograd.grad)는 사용하지 않으므로
# 이 값을 detach합니다.

from torch.func import functional_call, vmap, grad

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

######################################################################
# Next, let's define a function to compute the loss of the model given a
# single input rather than a batch of inputs. It is important that this
# function accepts the parameters, the input, and the target, because we will
# be transforming over them.
# 다음으로 입력 배치가 아니라 단일 입력이 주어졌을 때
# 모델의 손실을 계산하는 함수를 정의하겠습니다.
# 이 함수는 매개변수, 입력, target을 인자로 받아야 합니다.
# 변환을 이 인자들에 대해 적용할 예정이기 때문입니다.
#
# Note - because the model was originally written to handle batches, we’ll
# use ``torch.unsqueeze`` to add a batch dimension.
# 참고로 모델은 원래 배치를 처리하도록 작성되었으므로 ``torch.unsqueeze`` 로
# 배치 차원을 추가합니다.

def compute_loss(params, buffers, sample, target):
batch = sample.unsqueeze(0)
Expand All @@ -143,53 +140,50 @@ def compute_loss(params, buffers, sample, target):
return loss

######################################################################
# Now, let’s use the ``grad`` transform to create a new function that computes
# the gradient with respect to the first argument of ``compute_loss``
# (i.e. the ``params``).
# 이제 ``grad`` 변환을 사용하여 ``compute_loss`` 의 첫 번째 인자,
# 즉 ``params`` 에 대한 변화도를 계산하는 새 함수를 만듭니다.

ft_compute_grad = grad(compute_loss)

######################################################################
# The ``ft_compute_grad`` function computes the gradient for a single
# (sample, target) pair. We can use ``vmap`` to get it to compute the gradient
# over an entire batch of samples and targets. Note that
# ``in_dims=(None, None, 0, 0)`` because we wish to map ``ft_compute_grad`` over
# the 0th dimension of the data and targets, and use the same ``params`` and
# buffers for each.
# ``ft_compute_grad`` 함수는 단일 (sample, target) 쌍에 대한 변화도를 계산합니다.
# ``vmap`` 을 사용하면 표본과 target의 전체 배치에 대해 변화도를 계산하게 할 수 있습니다.
# ``ft_compute_grad`` 를 data와 targets의 0번째 차원에 매핑하면서
# 각 표본에는 같은 ``params`` 와 buffers를 사용하려 하므로
# ``in_dims=(None, None, 0, 0)`` 으로 지정합니다.

ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))

######################################################################
# Finally, let's used our transformed function to compute per-sample-gradients:
# 마지막으로 변환된 함수를 사용하여 표본별 변화도를 계산합니다.

ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)

######################################################################
# we can double check that the results using ``grad`` and ``vmap`` match the
# results of hand processing each one individually:
# ``grad`` ``vmap`` 을 사용한 결과가
# 각 표본을 직접 하나씩 처리한 결과와 일치하는지 다시 확인할 수 있습니다.

for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1.2e-1, rtol=1e-5)

######################################################################
# A quick note: there are limitations around what types of functions can be
# transformed by ``vmap``. The best functions to transform are ones that are pure
# functions: a function where the outputs are only determined by the inputs,
# and that have no side effects (e.g. mutation). ``vmap`` is unable to handle
# mutation of arbitrary Python data structures, but it is able to handle many
# in-place PyTorch operations.
# 간단히 덧붙이면 ``vmap`` 으로 변환할 수 있는 함수 유형에는 제한이 있습니다.
# 변환하기에 가장 좋은 함수는 순수 함수입니다.
# 순수 함수는 출력이 오직 입력으로만 결정되고 부수 효과(예: 변경)가 없는 함수입니다.
# ``vmap`` 은 임의의 Python 자료 구조 변경을 처리할 수는 없지만,
# 많은 제자리 PyTorch 연산은 처리할 수 있습니다.
#
# Performance comparison
# ----------------------
# 성능 비교
# ---------
#
# Curious about how the performance of ``vmap`` compares?
# ``vmap`` 의 성능이 어느 정도인지 궁금할 수 있습니다.
#
# Currently the best results are obtained on newer GPU's such as the A100
# (Ampere) where we've seen up to 25x speedups on this example, but here are
# some results on our build machines:
# 현재는 A100(Ampere) 같은 최신 GPU에서 가장 좋은 결과를 얻을 수 있으며,
# 이 예제에서는 최대 25배의 속도 향상을 확인했습니다.
# 아래는 빌드 머신에서 얻은 몇 가지 결과입니다.

def get_perf(first, first_descriptor, second, second_descriptor):
"""takes torch.benchmark objects and compares delta of second vs first."""
"""torch.benchmark 객체를 받아 첫 번째 결과와 두 번째 결과의 차이를 비교합니다."""
second_res = second.times[0]
first_res = first.times[0]

Expand All @@ -212,14 +206,14 @@ def get_perf(first, first_descriptor, second, second_descriptor):
get_perf(with_vmap_timing, "vmap", no_vmap_timing, "no vmap")

######################################################################
# There are other optimized solutions (like in https://github.com/pytorch/opacus)
# to computing per-sample-gradients in PyTorch that also perform better than
# the naive method. But it’s cool that composing ``vmap`` and ``grad`` give us a
# nice speedup.
# PyTorch에서 표본별 변화도를 계산하는 데에는
# https://github.com/pytorch/opacus 같은 다른 최적화된 해법도 있으며,
# 이 방법들 역시 단순한 방법보다 더 좋은 성능을 냅니다.
# 그래도 ``vmap`` 과 ``grad`` 를 조합하는 것만으로도
# 꽤 좋은 속도 향상을 얻을 수 있다는 점은 흥미롭습니다.
#
# In general, vectorization with ``vmap`` should be faster than running a function
# in a for-loop and competitive with manual batching. There are some exceptions
# though, like if we haven’t implemented the ``vmap`` rule for a particular
# operation or if the underlying kernels weren’t optimized for older hardware
# (GPUs). If you see any of these cases, please let us know by opening an issue
# at on GitHub.
# 일반적으로 ``vmap`` 을 이용한 벡터화는 함수를 for 루프에서 실행하는 것보다 빠르고,
# 수동 배치 처리와 비교해도 경쟁력 있는 성능을 냅니다.
# 다만 예외도 있습니다. 특정 연산에 대한 ``vmap`` 규칙이 아직 구현되지 않았거나,
# 하위 커널이 오래된 하드웨어(GPU)에 맞게 최적화되지 않은 경우가 그렇습니다.
# 이런 사례를 발견하면 GitHub에 이슈를 열어 알려주세요.