본문 바로가기
PyTorch

[PyTorch] torch.scatter_ 알아보자

by holy_jjjae 2023. 12. 14.

torch.scatter_()

https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_

Parameter

  • dim : scatter 할 기준이 되는 축. '0'이면 행 방향,  '1'이면 열 방향 
  • index (LongTensor) : 흩뿌릴 element들의 index. 즉, 어떤 숫자를 어떤 규칙으로 옮길지 결정하는 tensor  
  • src : 어떤 숫자들이 옮겨지는지 그 후보를 담은 소스 tensor

 

 

 torch.scatter는 scatter_의 out of place 버전이다. 흔히 inplace = False 옵션을 쓰는 것과 동일하다. 따라서 scatter로 tensor를 조작하고 다시 변수에 할당해주어야 한다.

 


예시 코드

src = torch.arange(1, 11).reshape((2, 5))
src # tensor([[ 1,  2,  3,  4,  5],
    #         [ 6,  7,  8,  9, 10]])
        
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
# tensor([[1, 0, 0, 4, 0],
#         [0, 2, 0, 0, 0],
#         [0, 0, 3, 0, 0]])
        
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
# tensor([[1, 2, 3, 0, 0],
#         [6, 7, 0, 0, 8],
#         [0, 0, 0, 0, 0]])

torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
           1.23, reduce='multiply')
# tensor([[2.0000, 2.0000, 2.4600, 2.0000],
#         [2.0000, 2.0000, 2.0000, 2.4600]])
        
torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
           1.23, reduce='add')
# tensor([[2.0000, 2.0000, 3.2300, 2.0000],
#         [2.0000, 2.0000, 2.0000, 3.2300]])

 

 

첫번째 코드를 그림으로 이해해보면 다음과 같다.

`

  • dim = 0
  • index = [0, 1, 2, 0] : 각 원소들은 어떤 행으로 보낼지 결정한다. 순서대로, 0행 1행 2행 0행
  • src  = [[ 1,  2,  3,  4,  5], [ 6,  7,  8,  9, 10]] : index가 1 by 4 배열이기 때문에 1, 2, 3, 4 만 옮겨지게 되고, 각 숫자의 열정보 자체가 어떤 열로 보내지는지를 결정한다. 즉, 0열 1열 2열 3열

 

정리해보면, src에서 1, 2, 3, 4 가 각각 (0행 0열), (1행 1열), (2행 2열), (0행 3열) 으로 옮겨지게 된다.