[Pytorch] Dataset을 만들어보자!
Pytorch에 대한 모델 아키텍처 구현을 하기 전에 dataset 객체를 만드는 방법을 먼저 소개하려 한다. 방법은 어렵지 않지만 batch단위의 training과 보기 좋은 전처리를 하기 위해서 꼭 필요한 과정이다.
큰 틀은 다음과 같다.
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self):
'''init some parameters'''
def __len__(self):
'''return length of datas'''
def __getitem__(self, index):
'''return data with preprocessing if needed'''
먼저 Dataset 모듈을 import 하고, 이를 상속하여 Dataset class를 만들어주어야 한다. 기본적으로 구현해야하는 함수는 위 3가지 인데, __init__는 class가 선언되면 우선적으로 실행되는 함수로 보통 class 내에서 쓰일 여러 변수들을 선언해주는 곳이다.
다음으로 __len__은 data의 크기를 return해주는 함수로 선언한 dataset 변수에 len()함수를 적용하여 data 크기를 return받게 해준다.
마지막으로 __getitem__은 데이터를 전처리하고 return하는 부분이다. index 변수는 데이터를 indexing할 때 사용된다.
백문이 불여일견, 해보면서 알아보자!
여기서 사용할 데이터셋은 CIFAR100이다. 데이터셋마다 load하는 방법과 전처리가 다르므로 그에 맞게 Dataset 클래스를 바꾸어주어야한다.
class CifarTrainDataset(Dataset):
def __init__(self, path, transform=None):
with open(path, 'rb') as f:
data = pickle.load(f, encoding='bytes')
self.transform = transform
self.x = data[b'data']
self.y = data[b'fine_labels']
먼저 클래스와 그 __init__함수를 정의하였다. CIFAR100 데이터셋은 이미지 형태가 아닌 byte형식으로 저장되어있기 때문에 위와 같은 과정을 거쳐서 이미지와 label 데이터를 받아와야 한다.
transform은 일단 ToTensor정도만 사용할 것이다. 여러가지 이미지 Transformation 함수가 존재하는데 이에 대한 자세한 내용은 다음을 참고하길 바란다(솔직히 transform에는 너무 기본적인 함수밖에 없어서 albumentation같은 모듈을 사용하는 것을 권장한다...)
https://pytorch.org/vision/stable/transforms.html
def __len__(self):
return len(self.x)
ds = CifarTrainDataset("./datasets/cifar100/train")
print(len(ds))
50000
다음으로는 __len__함수이다. 위와 같이 단순하게 데이터의 len을 반환해주면 밑에 처럼 dataset 변수에 len함수를 취해주어 데이터의 size를 받을 수 있다.
def __getitem__(self, index):
label = self.y[index]
r = self.x[index, :1024].reshape(32, 32)
g = self.x[index, 1024:2048].reshape(32, 32)
b = self.x[index, 2048:].reshape(32, 32)
image = np.dstack((r, g, b))
if self.transform:
image = self.transform(image)
return image, label
마지막으로 __getitem__함수이다. 여기서 데이터 x가 1차원 벡터이고 1024까지는 R채널, 1024~2048은 G채널, 2048~3060까지는 B채널로 이루어져 있기 때문에 위와 같은 전처리를 해준다. batch단위로 indexing해주는 것은 Dataloader에서 해주기 때문에 여기서는 단순하게 하나의 index라 가정하고 코딩을 해주면 된다.
이제 결과가 잘 나오는지 확인만 한번 해주면 끝!!
label은 19로 정상적으로 소(cattle)이 출력되는 것을 확인했습니다! Dataset 클래스 만들기 끝!