2021. 7. 8. 14:03ㆍ코딩연습장/Keras
모듈로 제공되는 데이터셋을 사용하는 것은 슬슬 질릴 때가 왔다. 회사나 팀 차원에서 데이터셋을 구축하거나, 또는 Open Datasets을 다운로드하여 사용할 때, 우리는 그 데이터셋을 전처리하여 모델에 input으로 넣어주어야하는데 체계적인 어떠한 모듈이 있으면 좋을 것 같다.
출처: https://towardsdatascience.com/keras-data-generators-and-how-to-use-them-b69129ed779c
그것이 바로 Generator!
오늘은 그 Generator를 만들어 tiny-imagenet 데이터를 가져오고, 전처리하는 코드를 만들 것이다. 추후에 Data Augmentation을 할 때에도 Generator를 사용할 것이니 꼭 준비하도록 하자!
1. Class 정의
먼저 Generator 클래스를 만들고, kears.utils.Sequence 모듈을 상속받아야한다.
class Class_Generator(keras.utils.Sequence):
2. __init__()
다음으로는 __init__함수에서 하이퍼파라메터들을 정의해주고 데이터를 받아오자.
데이터 로드는 tiny-imagenet 기준으로 작성이 되었으며, 데이터셋에 따라서 x(데이터), y(label)을 나누어주는 코드를 작성하면 된다.
def __init__(self, src_path, input_shape, class_num, batch_size, is_train=False):
folder_list = os.listdir(src_path)
self.is_train = is_train
self.batch_size = batch_size
self.input_shape = input_shape
self.class_num = class_num
self.x, self.y = [], []
with open("./datasets/wnids.txt", "r") as f:
cls_list = f.readlines()
cls_list = [cls_object.replace("\n", "") for cls_object in cls_list]
for _, folder in enumerate(folder_list):
imgs = glob.glob(src_path+folder+"/*.JPEG")
cls = cls_list.index(folder)
self.x += imgs
self.y += list(np.full((len(imgs)), fill_value=cls))
self.on_epoch_end()
3. on_epoch_end()
상속받은 모듈의 하위 함수인 on_epoch_end()함수를 작성해주자. 이 함수는 epoch가 끝날 때마다 실행되는 특성을 가졌다. 이 generator에서는 이미지들(x)를 epoch마다 shuffle해주는 기능을 한다.
def on_epoch_end(self):
self.index = np.arange(len(self.x))
if self.is_train:
np.random.shuffle(self.index)
4. __len__(self)
이 함수는 step size(length of data / batch size)를 리턴하는 함수이다.
def __len__(self):
return round(len(self.x) / self.batch_size)
5. __getitem__(self, idx)
원래는 __getitem__함수는 객체를 indexing하기 위해 쓰인다.
여기에서는 fit_generator에서, 한 epoch 동안 0부터 __len__함수에서 얻은 step size까지를 순서대로, idx 인자로, 받아서 데이터를 batch단위로 나누고, 전처리와 증강 작업 후 실제 모델에 input이 될 데이터가 return 되는 부분이다.
def __getitem__(self, idx):
batch_x, batch_y = [], []
batch_index = self.index[idx*self.batch_size:(idx+1)*self.batch_size]
for i in batch_index:
batch_x.append(self.x[i])
batch_y.append(self.y[i])
out_x, out_y = self.data_gen(batch_x, batch_y)
return out_x, out_y
6. data_gen()
데이터의 전처리와 증강 처리를 해주는 부분이다. 여기서 증강은 다루지 않았다(다음 시간에..)
전처리로서 이미지를 255.0으로 나누어 정규화하고, one_hot encoding을 쓰기 위해 keras.utils.to_categorical함수를 사용하였다.
def data_gen(self, x, y):
input_x = np.zeros((self.batch_size, self.input_shape[0], self.input_shape[1], self.input_shape[2]), dtype=np.float32)
imgs = []
for idx in range(len(x)):
img = cv2.imread(x[idx])
input_x[idx] = cv2.resize(img, (self.input_shape[0], self.input_shape[1])) / 255.0
input_y = to_categorical(y, num_classes=self.class_num)
return input_x, input_y
7. train.py 코드 수정
위 작업으로써 Generator는 완성이 되었다. 이제 trian code를 수정하여 Generator로 학습을 시켜보자.
먼저 train에 쓰일 generator와 valid로써 쓸 generator를 각 선언해준다. 그 후에 fit_generator를 통해서 다음과 같이 입력해주면 된다.
**max_queue_size나 workers의 경우, 컴퓨터의 사양에 따라 바꾸어주어야한다. 잘 모른다면 default로 정의되어있으니 그냥 쓰자.
train_gen = Class_Generator("./datasets/train/", input_shape, class_num, batch_size, augs=augs, is_train=True)
valid_gen = Class_Generator("./datasets/val/", input_shape, class_num, batch_size, augs= [], is_train=False)
step_size = train_gen.__len__()
print("Step size: ", step_size)
'''train'''
rs = Resnet(input_shape=input_shape, class_num=class_num, layer_num=50, weight_decay=weight_decay)
model = rs.resnet()
model.compile(optimizer=optimizer, loss=loss, metrics=[metrics])
model.fit_generator(train_gen, validation_data=valid_gen, epochs=epochs,
max_queue_size=20, workers=4,
callbacks=[TensorBoard(log_dir),
# ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=10),
CyclicLR(base_lr=1e-5, max_lr=1e-3, step_size=step_size*8, mode="triangular2"),
ModelCheckpoint(weight_save_file, monitor="val_loss", save_best_only=True)])
이제 학습이 되는 것을 볼 수 있을 것이다. 이와 같이 Generator를 사용하면 데이터 로드 및 전처리, 증강을 적용하는 것이 체계화 되어, 대용량의 데이터 처리가 편해진다는 장점이 있다. 크기가 큰 데이터셋을 사용한다면 꼭 만드는 것을 추천한다.
'코딩연습장 > Keras' 카테고리의 다른 글
[데이터 증강] IMGAUG 모듈 (0) | 2021.07.08 |
---|---|
[데이터 준비] Tiny-Imagenet (3) | 2021.06.29 |
[Classification] ResNet 코딩 (0) | 2021.06.29 |
[Classificaiton] VGGNET 모델 코딩 (2) | 2021.06.24 |
[데이터] 데이터 준비 단계(CIFAR-100) (0) | 2021.06.24 |