본문 바로가기
PyTorch

[PyTorch] torch.tensor.detach()

by holy_jjjae 2024. 7. 17.

https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html

 

torch.Tensor.detach — PyTorch 2.3 documentation

Shortcuts

pytorch.org

 

 파이토치에서 Tensor 객체의 detach() 메소드는 현재 Tensor 객체와 동일한 데이터를 가지지만 연산 그래프(Computational Graph) 에서 분리된 새로운 Tensor 객체를 생성한다. 즉, 이 메소드는 일반적으로 Tensor 객체를 다른 Tensor 객체로 변환하고자 할 때 사용된다.

 

 자세히 설명해보면, 주어진 Tensor 객체에 대한 연산의 결과로 생성된 새로운 Tensor 객체가 있을 때, 이 새로운 Tensor 객체를 사용하여 추가적인 계산을 수행하고자 할 때, 기존 Tensor 객체의 연산 그래프와의 의존성을 제거하여 메모리 사용량을 줄이고 계산 속도를 향상시키는 데 유용하다. 여기서 연산 그래프와의 의존성을 제거한다는 의미는 requires_grad 속성을 False로 설정하여 기존 Tensor 객체와 다르게 자동 미분 기능에서 제외된다는 뜻이다.

 

따라서 detach() 메소드를 사용하여 생성된 Tensor 객체는

 

(1) 그라디언트(gradient) 계산에 사용되지 않으며,

(2) 그라디언트를 계산하지 않는 모델에서 중간 출력 값을 얻을 때 특히 유용하다.

 

예를 들어, 다음과 같이 Tensor 객체 a와 b가 있을 때, detach() 메소드를 사용하여 Tensor 객체 b를 새로운 Tensor 객체 c로 분리하고자 할 때 다음과 같이 작성할 수 있다.

import torch 

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) 
b = 2 * a 
c = b.detach()

 

위 코드에서, Tensor 객체 b는 Tensor 객체 a를 사용하여 생성되었으며, requires_grad 속성이 True로 설정되어 있다.  그러나 detach() 메소드를 사용하여 생성된 Tensor 객체 c는 requires_grad 속성이 False로 설정되어있다.

 

따라서, c에 대한 연산은 a와의 의존성을 가지지 않으며, c로부터 추가적인 Tensor 객체를 생성할 때 메모리 사용량이 줄어들게 된다.

 

좀 더 직관적인 예제를 한번 살펴보자.

 

import torch
import torch.nn as nn

class Test(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)
        
    def forward(self, x):
        out1 = self.layer1(x)
        out2 = self.layer2(out1.detach())   # detach 사용
        return out2
        
model = Test()

 

 forward 함수를 살펴보면, layer1에서 나온 output이 detach되는 것을 볼 수 있다.  이 경우 역전파 때 gradient가 이전 layer인 layer1으로 흘러가지 않는다.