torch.serialization.add_safe_globals([argparse.Namespace]) Error

위와 같이 학습 명령을 내리자마자, 오류가 발생했는데..

PyTorch 2.6 버전부터 torch.load() 기본 옵션이 weights_only=True 로 바뀌었다고 한다.

기존 .pt 파일은 모델 가중치(state_dict)와 함께 argspace.Namespace 도 같이 저장된다.

그런데 Namepsace는 pickle로 저장된 Python 객체이다. 즉, weights_only=True 옵션일 때는

객체를 로드할 수 없다고 에러가 발생할 수 밖에 없다.

(Problem)
check_point = torch.load(model)

(solve)
check_point = torch.load(model, weights_only=False)

이렇게 옵션 처리하고 실행하면 학습이 진행된다.

하지만 꼭 주의해야할 사항은, weights_only 옵션을 False로 하게 되면 weight 외의 다른 내용도 포함된다는 것이기 때문에 weight 파일이 같은 머신에서 학습된 파일이면 괜찮겠지만, 다운로드받은 인터넷 파일이면
보안 이슈가 발생할 우려가 되어, 이 옵션을 변경하지 말라고 한다.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다