본문 바로가기
PyTorch

[PyTorch] make_grid() 사용하는 방법

by holy_jjjae 2023. 12. 13.

torchvision.utils.make_grid()

(https://pytorch.org/vision/stable/generated/torchvision.utils.make_grid.html) 

 

make_grid — Torchvision 0.16 documentation

Shortcuts

pytorch.org

 

이미지 텐서들을 모아서 grid 형태로 만들어주는 함수로, 이미지 시각화에 유용하다.  

공식문서를 참고해서 설명해보면 다음과 같다.

 

Parameters 

  • tensor (Tensor or list) : 4D mini-batch Tensor (Batch, Channel, Height, Width) 또는 같은 크기의 이미지 리스트
  • nrow (int, optional) : 한 행에 표시될 이미지의 개수. 최종 그리드의 형태는 ( Batch / nrow, nrow )가 됩니다. (Default : 8)
  • padding (int, optional) : 전체 텐서의 좌우상하 각각 padding + 이미지 사이 간격 pdding (Default : 2)
  • normalize (bool, optional) : True 일 경우, image 를 0~1 값으로 변환. (아래 value_range의 min, max 값을 기준) (Default : False)
  • value_range (tuple, optional) : 튜플 (min, max) 을 따로 설정가능. (Default : tensor에서 자동으로 min, max 계산)
  • scale_each (bool, optional) : True 일 경우, 전체 이미지에서 min, max를 찾는 것이 아니라 각 이미지 별로 scaling 진행. (Default : False)
  • pad_value (float, optional) : 패딩 되는 픽셀의 값 (Default : 0)

 

Returns

  • tensor : 이미지의 그리드를 담은 텐서 반환. (type : grid(Tensor))
  • 쉽게말해 여러 이미지 텐서를 타일처럼 합쳐서 하나의 텐서로 만들어 준다.

 

 


예시

 현재 구현중인 "Semi-supervised Learning with Deep Generative Models" 에서 사용한 코드의 일부를 이용해보겠다.

(자세한 코드는 https://github.com/ImJaeSung/VAE/tree/main/Semisupervised%20learning%20with%20deep%20generative%20models)

 

 

- 좌표공간 생성하는 사용자 정의함수

 

def latent_img(grid_size, grid_range = (-5, 5), latent_size = 2):
    grid = []
    for _ in range(latent_size):
        axis = np.linspace(grid_range[0], grid_range[1], grid_size) # -5부터 5까지 10개 
        grid.append(axis)

    grid_points = np.meshgrid(*grid) # 한개의 axis로 좌표공간 생성 
    grid_points = np.column_stack([point.ravel() for point in grid_points])

    return grid_points

 

미리 학습시켜서 저장한 M2 model 을 이용하여 논문 7p 있는 이미지를 구현하고자 한다.

 

model = torch.load('models/M2.pt') # 기존에 학습시킨 M2 모델을 로드
save_path = 'imgs/'

for i in range(10):
    grids = latent_img(10)
    latent_image = [model.M2Decoder(torch.cat([torch.FloatTensor(grid), 
    	F.one_hot(torch.LongTensor([i]), num_classes = 10).reshape(-1)])).reshape(-1, 28, 28)
        for grid in grids] # latent_size + label_size = 12 : M2Decoder input
    
    # latent_image[1].size() # 1x28x28
    latent_grid_img = torchvision.utils.make_grid(latent_image, nrow = 10) 
    # Channel : 3 (RGB)
    # Height : 302 
    # Width : 302
    # 3x302x302 (280+4+2x9) 4:좌우끝, 2:이미지 사이 간격
    # grid shape (Channel, Height, Width)
    plt.imshow(latent_grid_img.permute(1, 2, 0), cmap = 'gray') # CxHxW -> HxWxC
    plt.axis('off')  # 축 정보 제거
    plt.savefig(f'{save_path}latent_image_{i}.png', bbox_inches = 'tight', pad_inches = 0)
    plt.show()

 

 모델에 대한 이해가 없으면 나머지 차원을 이해하기 어렵겠지만, 한가지 이해하고 갈 수 있는 부분은

make_grid를 이용하여 grid를 생성하면 다음과 같이 차원이 설정된다.

 

  • Channel : 3 (RGB)
  • Height :  302 = 28x10(세로 이미지 10 개) + 4 (위아래 끝) + 2(두 이미지 사이 간격 패딩) x9
  • Wdith : 302 = 28x10(가로 이미지 2 개) + 4 (좌우 끝) + 2(두 이미지 사이 간격 패딩) x9

 

 

이런식으로 한번에 100개의 이미지를 한번에 그려주는 grid를 만들어주는 역할을 하는 함수인 make_grid()에 대해 알아보았다.