torch.scatter_()
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_
Parameter
- dim : scatter 할 기준이 되는 축. '0'이면 행 방향, '1'이면 열 방향
- index (LongTensor) : 흩뿌릴 element들의 index. 즉, 어떤 숫자를 어떤 규칙으로 옮길지 결정하는 tensor
- src : 어떤 숫자들이 옮겨지는지 그 후보를 담은 소스 tensor
torch.scatter는 scatter_의 out of place 버전이다. 흔히 inplace = False 옵션을 쓰는 것과 동일하다. 따라서 scatter로 tensor를 조작하고 다시 변수에 할당해주어야 한다.
예시 코드
src = torch.arange(1, 11).reshape((2, 5))
src # tensor([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10]])
index = torch.tensor([[0, 1, 2, 0]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
# tensor([[1, 0, 0, 4, 0],
# [0, 2, 0, 0, 0],
# [0, 0, 3, 0, 0]])
index = torch.tensor([[0, 1, 2], [0, 1, 4]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
# tensor([[1, 2, 3, 0, 0],
# [6, 7, 0, 0, 8],
# [0, 0, 0, 0, 0]])
torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
1.23, reduce='multiply')
# tensor([[2.0000, 2.0000, 2.4600, 2.0000],
# [2.0000, 2.0000, 2.0000, 2.4600]])
torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
1.23, reduce='add')
# tensor([[2.0000, 2.0000, 3.2300, 2.0000],
# [2.0000, 2.0000, 2.0000, 3.2300]])
첫번째 코드를 그림으로 이해해보면 다음과 같다.
↓
- dim = 0
- index = [0, 1, 2, 0] : 각 원소들은 어떤 행으로 보낼지 결정한다. 순서대로, 0행 1행 2행 0행
- src = [[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10]] : index가 1 by 4 배열이기 때문에 1, 2, 3, 4 만 옮겨지게 되고, 각 숫자의 열정보 자체가 어떤 열로 보내지는지를 결정한다. 즉, 0열 1열 2열 3열
정리해보면, src에서 1, 2, 3, 4 가 각각 (0행 0열), (1행 1열), (2행 2열), (0행 3열) 으로 옮겨지게 된다.
'PyTorch' 카테고리의 다른 글
[PyTorch] torch.tensor.detach() (0) | 2024.07.17 |
---|---|
[PyTorch] make_grid() 사용하는 방법 (0) | 2023.12.13 |
[PyTorch] model.zero_grad()와 optimizer.zero_grad()의 차이 (1) | 2023.10.15 |
[PyTorch] torch.no_grad()와 model.eval()의 차이 (2) | 2023.09.18 |