[Pytorch] Multi-gpu 사용하기

2022. 3. 23. 15:43코딩연습장/Pytorch

이번에 pytorch로 multi-gpu를 사용할 일이 생겨서 알아보게 되었다. 방법은 매우 간단하니 한번 살펴보자! ㅎㅎ

 

        NGPU = torch.cuda.device_count()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if NGPU > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=list(range(NGPU)))
        torch.multiprocessing.set_start_method('spawn')
        self.model.to(device)

 

간단하게 살펴보면 torch.cuda.device_count()는 현재 할당된 gpu 개수를 의미한다. 따라서 이 개수가 1개를 넘는다면 multi-gpu라고 인식하여 DataParallel구문을 넣어준다. 

 

torch.multiprocessing.set_start_method('spawn') 구문은 num_workers를 1개 이상으로 두었을 때 나올 수 있는 에러문 때문에 넣어준다. 보너스로 num_workers는 number of gpu * 4로 설정해주는게 가장 효율적이라고 한다.

 

이렇게 하면 multi-gpu 세팅 완료!!! 이게 가장 간단한 세팅이고 추가적으로 loss를 parallel하게 계산한다거나 DataParallel이 아닌 더 복잡한 분산시스템 모듈을 사용하여 구성하는 방법도 존재한다.

 

하지만 필자는 매우 큰 데이터셋을 사용하지 않기 때문에 위와 같은 세팅으로도 충분하고 더 자세한 정보를 알고 싶으면....구글이 도와줄 것이다.

 

230119 업데이트

**torch.multiprocessing.set_start_method('spawn') 은 모든 torch 구문 보다 이전에 선언되어야 깔끔하게 실행된다. 그렇지 않은 경우, cuda와 multi-process 쪽에서 에러가 날 수 있음..