요즘 많이 겪는 문제가 nan loss이다. 커스텀 레이어와 loss를 쓰다 보니 미처 파악하지 못한 예외가 생긴다. 그래서 nan loss이 발생했을 때 원인을 찾고 해결하는 방법에 대해 짧게 적어보려고 한다.
원인이 되는 연산 찾기
먼저 torch.autograd 함수 중에 NaN loss가 발생했을 경우 원인을 찾아주는 함수가 있다.
autograd.set_detect_anomaly(True)
학습 코드에 위 코드를 추가해주고 실험을 하면, NaN loss가 발생하는 즉시 실행이 멈추고 NaN을 유발한 라인을 출력해준다. 주로 division by zero나 매우 작은 값에 대한 log 연산이 NaN loss를 유발한다. NaN은 loss 연산 뿐만 아니라 forward 연산, backward 연산에서도 발생할 수 있으므로 직접 찾으려면 힘든데, 위 코드를 쓰면 간편하다.
연산 수정하기
나누는 연산이 있는데 divisor가 0이 될 수 있는 경우라면, 예외 처리를 해주거나 divisor에 1e-6 등 연산에 영향을 끼치지 않는 작은 상수를 더해주면 된다. log도 마찬가지다. log(x)에서 x가 매우 작은 값이 될 수 있다면, x에 상수를 더해주면 된다. 또는 NaN 값을 0으로 바꾸어주는 torch 함수를 쓰자.
a = torch.nan_to_num(a)
주의할 점은 nan_to_num은 PyTorch 1.8.0 이후부터 지원된다.
원인이 되는 연산을 알았지만 이유를 모르겠다
- Gradient exploding / vanishing
원인이 되는 레이어의 weight과 grad를 출력해보면 알 수 있다.
torch.any(torch.isnan(weight)) # weight에 NaN 존재 여부
model.layer.grad # layer의 gradient
- Learning rate이 너무 높을 경우
- PyTorch 내장함수 중 나눗셈 연산이 있는 함수를 썼을 경우
이외에도
- input data에 nan이 있는 경우
df.isnull().sum() # 결측치 개수 확인
- data type을 float으로 변경해보자
- optimizer를 변경해보자
- 특성 스케일링 방법을 변경해보자 ex) (0, 1)에서 (-1,1)
- 데이터와 output_size가 일치하지 않는 경우 ex) 예측하고자 하는 클래스 수와 마지막 Dense layer 노드 수의 불일치
'Error' 카테고리의 다른 글
TypeError: unsupported format string passed to numpy.ndarray.__format__ (1) | 2023.09.09 |
---|---|
matplotlib 한글 깨지는 문제 해결 (1) | 2023.09.07 |