torchvision.utils.make_grid()
(https://pytorch.org/vision/stable/generated/torchvision.utils.make_grid.html)
이미지 텐서들을 모아서 grid 형태로 만들어주는 함수로, 이미지 시각화에 유용하다.
공식문서를 참고해서 설명해보면 다음과 같다.
Parameters
- tensor (Tensor or list) : 4D mini-batch Tensor (Batch, Channel, Height, Width) 또는 같은 크기의 이미지 리스트
- nrow (int, optional) : 한 행에 표시될 이미지의 개수. 최종 그리드의 형태는 ( Batch / nrow, nrow )가 됩니다. (Default : 8)
- padding (int, optional) : 전체 텐서의 좌우상하 각각 padding + 이미지 사이 간격 pdding (Default : 2)
- normalize (bool, optional) : True 일 경우, image 를 0~1 값으로 변환. (아래 value_range의 min, max 값을 기준) (Default : False)
- value_range (tuple, optional) : 튜플 (min, max) 을 따로 설정가능. (Default : tensor에서 자동으로 min, max 계산)
- scale_each (bool, optional) : True 일 경우, 전체 이미지에서 min, max를 찾는 것이 아니라 각 이미지 별로 scaling 진행. (Default : False)
- pad_value (float, optional) : 패딩 되는 픽셀의 값 (Default : 0)
Returns
- tensor : 이미지의 그리드를 담은 텐서 반환. (type : grid(Tensor))
- 쉽게말해 여러 이미지 텐서를 타일처럼 합쳐서 하나의 텐서로 만들어 준다.
예시
현재 구현중인 "Semi-supervised Learning with Deep Generative Models" 에서 사용한 코드의 일부를 이용해보겠다.
- 좌표공간 생성하는 사용자 정의함수
def latent_img(grid_size, grid_range = (-5, 5), latent_size = 2):
grid = []
for _ in range(latent_size):
axis = np.linspace(grid_range[0], grid_range[1], grid_size) # -5부터 5까지 10개
grid.append(axis)
grid_points = np.meshgrid(*grid) # 한개의 axis로 좌표공간 생성
grid_points = np.column_stack([point.ravel() for point in grid_points])
return grid_points
미리 학습시켜서 저장한 M2 model 을 이용하여 논문 7p 있는 이미지를 구현하고자 한다.
model = torch.load('models/M2.pt') # 기존에 학습시킨 M2 모델을 로드
save_path = 'imgs/'
for i in range(10):
grids = latent_img(10)
latent_image = [model.M2Decoder(torch.cat([torch.FloatTensor(grid),
F.one_hot(torch.LongTensor([i]), num_classes = 10).reshape(-1)])).reshape(-1, 28, 28)
for grid in grids] # latent_size + label_size = 12 : M2Decoder input
# latent_image[1].size() # 1x28x28
latent_grid_img = torchvision.utils.make_grid(latent_image, nrow = 10)
# Channel : 3 (RGB)
# Height : 302
# Width : 302
# 3x302x302 (280+4+2x9) 4:좌우끝, 2:이미지 사이 간격
# grid shape (Channel, Height, Width)
plt.imshow(latent_grid_img.permute(1, 2, 0), cmap = 'gray') # CxHxW -> HxWxC
plt.axis('off') # 축 정보 제거
plt.savefig(f'{save_path}latent_image_{i}.png', bbox_inches = 'tight', pad_inches = 0)
plt.show()
모델에 대한 이해가 없으면 나머지 차원을 이해하기 어렵겠지만, 한가지 이해하고 갈 수 있는 부분은
make_grid를 이용하여 grid를 생성하면 다음과 같이 차원이 설정된다.
- Channel : 3 (RGB)
- Height : 302 = 28x10(세로 이미지 10 개) + 4 (위아래 끝) + 2(두 이미지 사이 간격 패딩) x9
- Wdith : 302 = 28x10(가로 이미지 2 개) + 4 (좌우 끝) + 2(두 이미지 사이 간격 패딩) x9
이런식으로 한번에 100개의 이미지를 한번에 그려주는 grid를 만들어주는 역할을 하는 함수인 make_grid()에 대해 알아보았다.
'PyTorch' 카테고리의 다른 글
[PyTorch] torch.tensor.detach() (0) | 2024.07.17 |
---|---|
[PyTorch] torch.scatter_ 알아보자 (1) | 2023.12.14 |
[PyTorch] model.zero_grad()와 optimizer.zero_grad()의 차이 (1) | 2023.10.15 |
[PyTorch] torch.no_grad()와 model.eval()의 차이 (2) | 2023.09.18 |