pytorch에서 loss 값이 nan이 나올 때!

2024. 10. 22. 16:28오류해결

 

이번에 준비하고 있는 논문을 위해서 아이디어가 나올 때마다 실험을 하고 있는데, 이번에 어째서인지 loss가 잘 떨어지다가 중간에 nan 값이 되어버리는 문제가 발생하였다...

 

처음부터 튀어버리면 loss 함수 자체에서 문제가 발생했다고 생각했겠지만 감소 중에 nan 값으로 바뀐 것이기 때문에 원인 파악이 어려웠다. 이럴 땐 뭐다? 바로 구글링을 하는 것이다.

*요즘엔 ChatGPT도 하나의 방법!

 

그래서 찾아보니,

torch.autograd.set_detect_anomaly(True)

 

자동으로 nan값이 나오는 곳을 찾아주는 명령어가 있지 않았겠는가!

 

이 명령어를 써놓기만 해도 nan값이 출력되는 순간 어느 backward에서 발생한건지 예쁘게 출력해준다!

 

내 경우에는 Pow와 관련있었는데, 최근 추가했던 functional.normalize가 떠올라 제거해주었더니 바로 문제가 해결되었다.