[논문 리뷰] Tokens-to-Token ViT

2022. 5. 5. 22:57논문리뷰/Vision Transformer

이번에 소개할 논문은 Tokens-to-Token ViT라는 논문이다. 

https://arxiv.org/abs/2101.11986

 

Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

Transformers, which are popular for language modeling, have been explored for solving vision tasks recently, e.g., the Vision Transformer (ViT) for image classification. The ViT model splits each image into a sequence of tokens with fixed length and then a

arxiv.org

기존 ViT는 좋은 performance를 보여주지만 매우 큰 데이터셋(JFT)이 아닌 중간 크기의 데이터셋(ImageNet)을 학습시키면 CNN보다 성능이 더 떨어지는 문제가 있었다. 논문에서는 이러한 문제의 원인으로 simple tokenization과 attention backbone을 뽑았고 이를 중심으로 Tokens-to-Token Vision Transformer를 새롭게 제안한다. 한번 살펴보도록 하자.


1. Vision Transformer(ViT)

이번 논문은 ViT의 문제점을 기반으로 mechanism을 소개하기 때문에 ViT를 먼저 읽고 오기를 권장한다.

https://aistudy9314.tistory.com/67?category=1044949 

 

[논문 리뷰] An Image is Worth 16x16 Words: Transformers for image Recognition at scale(VIT)

최근들어 필자가 가장 많은 관심을 가지고 있는 vision transformer가 처음 등장하는 논문이다. 물론 이 논문 이전에도 transformer를 vision 분야에 적용한 시도들이 있었지만, 실제로 vision분야에서 transfo

aistudy9314.tistory.com

 

간단하게 설명해서 ViT는 이미지를 patch 단위로 쪼개어서 각 patch를 하나의 token으로 두고 attention을 적용시키는 방법을 사용한다. ViTCNN을 outperform하는 성능을 보여주었는데 아쉽게도 매우 큰 데이터셋을 사용했을 때만 좋은 결과가 나오고 ImageNet과 같은 midsize 데이터셋에서는 CNN보다 떨어지는 결과를 내보냈다.

 

Why does ViT have inferior performance in midsize dataset?

왜 midsize dataset에서는 CNN보다 성능이 떨어질까? 그 원인으로 많은 문제들이 추측되고 있지만 논문에서는 2가지를 main limitation으로 지적한다.

  1. ViT는 tokenization을 할 때 overlap없이 patch를 sliding하는 "hard split"을 사용하는데, 이러한 방식이 ViT가 image local structure(ex. edges and lines)를 modeling할 수 없도록 만들고 따라서 더 많은 training samples를 요구한다.
  2. ViT의 Attention backbone은 vision tasks에 잘 맞도록 design 되어있지 않다

저자들은 위 가정에 대해서 ViT와 ResNet의 learned features 차이를 조사하기 위한 pilot study도 진행하였는데, 

 

 

ResNet의 경우 edges, lines, textures와 같은 local structure(green box)가 bottom layer(conv1)부터 middle layer(conv25)까지 다 잘 잡히는 것을 볼 수 있고.

 

반면에 ViT는 global relations(ex. 강아지 자체)는 잘 잡아내지만 structure information은 거의 modeling 되지 않은 것을 볼 수 있는데, 이러한 결과는 ViT가 고정된 크기로 이미지를 tokenization했을 때 local structure를 무시한다는 것을 보여준다. 또한 많은 channels가 zero-value(red box)를 가지고 있는데, 이는 ViT의 backbone이 ResNet만큼 efficient하지 않고 따라서 training samples가 충분하지 않을 때 제한된 feature richness를 제공한다는 것을 말한다.

 

2. Tokens-to-Token ViT

 

 

논문에서는 위에서 제기된 simple tokenization과 inefficient backbone의 limitation을 해결하기 위해서 Tokens-to-Token Vision Transformer(T2T-ViT)를 제안한다. T2T-ViT는 두 main components로 이루어져있는데

 

첫번 째는 tokens의 길이를 progressive하게 줄이고 이미지의 local structure information을 담을 수 있도록 하는 layer-wise "Tokens-to-Token module"이고, 두번 째는 T2T module로부터 tokens의 global attention relation을 끌어내는 효율적인 "T2T-ViT backbone"이다. 

 

2.1 Tokens-to-Token module: Progressive Tokenization

T2T module은 Re-structurization stepsoft split(SS) step으로 나누어져 있다. 이를 하나씩 살펴볼 것인데 그림과 같이 보는 것이 이해가 더 빠르고 쉽기 때문에 추천한다.

Re-structurization

먼저 input은 이전 transformer layer의 tokens $T$이고 이는 T2T-transformers를 거쳐 $T'$으로 transform된다.

$$T' = MLP(MSA(T))$$

Transformer는 별 다른 것은 없고 똑같이 Multihead Self-Attention(MSA)을 거치고 Multi-layer Perceptron(MLP)을 적용 해준다. 이 때 모두 Layer Norm을 사용하며 vanila Transformer가 아닌 Performer를 사용해도 된다. 

 

다음으로 이렇게 나온 tokens $T'$을 이미지 형태로 reshape해주면 된다.

$$I = Reshape(T')$$

여기서 Reshape는 tokens $T' \in \mathbb{R}^{l \times c}$를 $I \in \mathbb{R}^{h \times w \times c}$로 reorganize하는 함수이다.

 

Soft Split(SS)

Re-structurization에서 얻은re-structurized image $I$를 다시 tokenization해주는 부분이다. $I$로부터 tokens을 만들 때 information loss가 생길 수 있기 때문에 patch를 overlap하면서 split을 해주고, 이로 인해 각 patch는 surrounding patches와 correlation을 가지게 되면서 surrounding tokens 간의 강한 correlation을 갖도록 하는 prior가 생긴다. 그리고 각 split patches를 하나의 token으로 concat함으로써 local information이 surrounding pixels과 patches로부터 aggregate될 수 있도록 한다.

 

soft split을 할 때 각 patch size를 $k \times k$, overlapping size를 s, padding을 p라고 하면 output tokens $T_o$의 길이는 다음과 같다.

**k-s는 convolution에서 stride와 비슷한 개념이라고 보면 된다.

이렇게 나온 token output $T_o$는 $\mathbb{R}^{l_o \times ck^2}$ 차원을 가질 것이다. 

T2T module

위의 Re-structurization과 Soft Split을 반복적으로 수행함으로써, T2T module은 점진적으로 tokens의 길이를 줄일 수 있고 이미지의 spatial structure를 바꿀 수 있다(patch를 통해 surrounding tokens을 concat하면서 spatial structure가 바뀌는 것을 말하는 듯하다).

$$T'_i = MLP(MSA(T_i))$$

$$I_i = Reshape(T'_i)$$

$$T_{i+1} = SS(I_i), \ \ \ i = 1, \dots , (n-1)$$

**input image $I_0$에 대해서는 바로 SS를 적용한다. 마지막 iteration후에 T2T-ViT backbone이 $T_f$의 global relation을 modeling할 수 있도록 output tokens $T_f$는 고정된 길이를 가진다.

 

추가적으로 T2T module의 tokens 길이가 ViT(16x16)보다 크기 때문에 Multiply-Accumulate(MAC)와 memory usage가 크다고 한다. 저자들은 이를 줄이기 위해서 channel dimension을 줄이고 옵션으로 Performer와 같은 efficient Transformer를 사용하였다고 한다. 

 

2.2 T2T-ViT Backbone

저자들은 vanila ViT에서 사용되었던 backbone이 well-design되지 않았다고 생각했기 때문에 CNN의 여러 architecture designs을  기반으로 ViT를 위한 efficeint architecture designs를 고려하였다. 

 

  1. feature richness와 connectivity를 향상시키기 위한 DenseNet의 dense connection
  2. channel dimension과 head number를 바꾸기 위한 Wide-ResNets 또는 ResNeXt
  3. Squeeze-an-Excitation Networks
  4. GhostNet

ViT에서 위 structure를 사용한 details는 논문 appendix에 있으니 참고하길 바란다. 저자들은 위 구조들을 사용하여 여러 실험을 해보았는데, 최종적으로 deep-narrow structure를 사용하는 것이, 간단히 channel dimmensions을 줄일 수 있어서 channels간 redundancy를 많이 감소시킬 수 있고 layer depth를 높여주어 ViT의 feature richness도 향상시킬 수 있다고 한다.

** SE block과 같은 channel attention도 성능 향상이 있었지만 deep-narrow structure가 더 효과적이라고 한다.

 

Wide-resnet block

**Wide-resnet에서 사용되었던 shallow-wide는 convolution channels을 기존보다 더 크게 주고 layer depth를 낮추는 방법인데, deep-narrow는 그와 반대로 channels을 낮추고 layer depth를 높이는 방법이다.

3. T2T-ViT Architecture

t

T2T-ViT는 Tokens-to-Token moduleT2T-ViT backbone 두 part로 나누어져있다. 위 이미지를 예를 들면 n=2로써 두 개의 Transformer layers를 두었고 3번의 soft splits에 사용되는 patch size는 $P=[7, 3, 3]$, overlapping size는 $S=[3, 1, 1]$을 사용하였다. 이 결과 $224 \times 224$ 크기의 이미지를 $14 \times 14$ 크기로 줄일 수 있다고 한다. 

**이미지 크기를 바탕으로 계산해보았을 때 패딩은 [3, 1, 1]을 주었다.

 

Tokens-to-token module의 최종 layer에서는 fixed tokens $T_f$가 나오며 이를 class token과 concat하고 Sinusoidal Position Embedding(PE)를 더해준 후에 T2T-ViT backbone으로 보낸다.

**E: Sinusoidal Position Embedding, LN: Layer Norm, fc: fully-connected layer

 

4. Experiments

이번 논문은 필자가 논문 발표를 해야하기 때문에 experiments부분도 자세히 다룰 것이다 ㅎㅎ.

4.1 T2T-ViT on ImageNet

(왼쪽) T2T-ViT vs ViT, (오른쪽) T2T-ViT vs Resnet

위는 ImageNet에서 T2T-ViT를 ViT랑 ResNet과 비교한 결과이다.

 

먼저 ViT와의 비교에서는 T2T-ViT가 훨씬 적은 parameter수와 MAC를 가지면서 더 높은 performance를 보여주었고, DeiT는 teacher model로써 큰 CNN models가 필요한데 반해 T2T-ViT는 그렇지 않음에도 불구하고 비슷한 MAC로 더 높은 accuracy를 보여주었다. 

 

두번째로 Resnet과의 비교에서는 T2T-ViT가 비슷한 model size와 MACs에서 1.4%~2.7%의 performance gain을 얻은 것을 볼 수 있다. 

 

T2T-ViT vs Mobilenet

MobileNet과도 비교를 하였는데, 비슷하거나 더 높은 performance를 보여주고 있다. 하지만 MobileNet보다 더 큰 MACs를 필요로 하는데, 저자가 말하길 T2T-ViT는 오직 hidden dimension과 MLP ratio, depth of layers만 줄여서 model size만 작게 만들었고 efficient convolution과 같은 special operations을 사용하지 않았다는 점에서 T2T-ViT 또한 lite model로써 매우 훌륭하다고 주장한다. 추가로 Distillation도 적용하였을 때 더 성능이 높아지는 것을 볼 수 있다.

 

Transfer Learning

 

 

또한 저자들은 학습한 T2T-ViT를 CIFAR-10과 CIFAR-100과 같은 downstream datasets에 transfer learning을 해보았는데, ViT와 비교했을 때 더 높은 performance를 보여주는 것을 알 수 있다.

 

4.2 From CNN to ViT

 

 

저자들은 Vision Transformer를 위한 efficient backbone을 찾기 위해서 위에서 말한 4가지의 model structure designs를 ViT에 적용해보았다. 결과 Deep-narrow structure와 SE-Network가 효과적이었는데, deep-narrow structure가 ViT에 비해서 0.9% 더 높은 성능을 가지면서 model size와 MACs를 거의 2배 가까이 줄임으로써 가장 좋았다고 한다.

 

다른 structure를 ViT로 바꾸었을 때의 결과들도 살펴보자면, 먼저 Wide-Resnet의 Shallow-wise structure는 performance에서 8%의 큰 성능 저하가 있었고 DenseNet의 dense connection 또한 performance가 저하되는 결과를 초래하였다고 한다. SEblock의 경우 baseline보다 높은 성능을 보여주었고 이는 channels attention이 CNN과 ViT 모두에 도움이 된다는 것을 말해준다. 그리고 ResNeXt 구조는 어떻게 보면 multi-head attention과 같은데, 따라서 저자들은 더 많은 heads(ex.32)를 주어보았지만 약간의 성능 향상이 있는 것에 비해 GPU memory를 크게 증가시킨다고 한다. Ghost operation은 중복되는 feature maps을 cheap operation을 통해 만들어주는 mechanism인데 parameter 수와 MACs는 줄여주지만 accuracy도 같이 줄어든다고 한다(Resnet과 T2T-ViT 공통). 

 

4.3 Ablation Study

 

 

마지막으로 T2T-module의 effectiveness를 증명하기 위한 실험을 보여주는데, T2T-ViT-$14_{wo}$는 T2T-module을 제거한 모델이고 T2T-ViT-$14_c$는 Soft Split이 Convolution과 비슷하기 때문에 이를 convolution으로 대체한 모델이다. 참고로 baseline T2T-ViT는 performer를 사용한 모델이고 T2T-ViT-$14_t$가 transformer를 사용한 모델이다.

 

결과를 보면, T2T-ViT-$14_c$가 T2T-ViT와 T2T-ViT-$14_t$보다 0.5%~1.0% 정도로 성능이 더 낮았고 T2T-ViT-$14_{wo}$이 가장 낮은 accurcay를 보여주었다. 

 

추가적으로 Deep-Narrow structure를 Shallow-wide structure와 비교한 결과도 보여주고 있다.