[Pytorch] VIT Pretrained model 사용하기
이번에 Vision Transformer를 사용할 기회가 생기면서 Pretrained model에 대해서 알아보았다. 방법은 매우 간단하니 누구든 따라할 수 있을 것이다!!
Official Github
바로 모델만 사용하여 학습을 해야하는 것이 아니라 모델이 어떠한 구조로 이루어져있는지 또는 pretrained model이 필요없다면 다음 official github를 참고하는 것을 추천한다.
https://github.com/lucidrains/vit-pytorch
vit 뿐만 아니라 여러가지 파생 모델들도 대부분 다 구현되어있기 때문에 매우 좋은 것 같다.
Easy Coded VIT
좀 더 쉽고 직관적인 vit github를 원한다면 다음 github도 추천하겠다! 다만, 이는 공부 목적으로 만들어졌기 때문에 위 github보다 성능면에서 떨어질 수 있다!
https://github.com/FrancescoSaverioZuppichini/ViT
Pretrained VIT
Official Github에서 timm을 사용하기를 권장하고 있다. 따라서 timm 모듈을 먼저 install 해주자!
pip install timm
timm은 pytorch로 구현된 모델들의 저장소라고 간단하게 생각하면 된다. 많은 모델들을 이 모듈을 통해 불러올 수 있고, 여기서 pretrained model도 사용할 수 있다!
간단하게 어떤 모델들이 있는지 아래와 같은 코드로 확인해 볼 수 있다.
model_names = timm.list_models(pretrained=True)
print(model_names)
'''
['adv_inception_v3',
'cspdarknet53',
'cspresnext50',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'densenetblur121d',
'dla34',
'dla46_c',
...
]
'''
이제 VIT 모델을 찾아서 만들어주기만 하면 끝! 매우 간단하다. vit baseline 코드는 다음과 같다.
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
model = timm.create_model('vit_base_patch16_224', pretrained=True)
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
model은 이제 여다른 pretrained model을 불러왔을 때와 같이 사용하면 되고, config와 transform의 경우 timm에서 기본적으로 사용하는 옵션들을 가지고 있다. 확인해봤을 때 별다른 augmentation은 추가되어있지 않으니 따로 transform을 만들어주는 것도 나쁘지 않다.
이렇게 VIT pretrained model을 사용하는 방법에 대해서 간단하게 살펴보았다. timm에는 정말로 많은 모델들이 있으므로 한번 살펴보기를 권장한다. 심지어 너무 간단해서 뭐...ㅎㅎ 별로 설명할게 없다.