논문리뷰/Vision Transformer

[논문 리뷰] GroupViT

인공지능스타터 2023. 4. 28. 02:59

정말 오랫만에 논문 리뷰글을 쓰려고 한다. 최근 블로그를 하지 못했는데 논문을 발표할 기회가 생겨서 겸사 블로그에도 글을 정리하려고 한다.

 

이번에 리뷰할 논문은 GroupViT: Semantic Segmentation Emerges from Text Supervision이다.

https://arxiv.org/abs/2202.11094

 

GroupViT: Semantic Segmentation Emerges from Text Supervision

Grouping and recognition are important components of visual scene understanding, e.g., for object detection and semantic segmentation. With end-to-end deep learning systems, grouping of image regions usually happens implicitly via top-down supervision from

arxiv.org

 


0. Summary

자세한 내용을 들어가기 앞서 간단하게 해당 논문이 어떠한 내용을 담고 있는지 알아보자. 논문 제목에서도 유추해볼 수가 있는데, semantic segmentation을 text supervision으로부터 배워보자! 가 핵심이다. 

 

기존 방식들은 주로 이미지를 pixel-level로 labeling한 segmentation label을 사용하여 학습을 하였는데, 이는 cost측면에서 매우 비싸다는 단점이 있다. 논문에서는 최근 powerful한 performance를 보여주고 있는 text supervision과 Vision Transformer를 활용하여 모델이 semantic한 grouping을 할 수 있도록 학습시킨다. 결과적으로 inference 단계에서도 이미지와 text 만으로 segmentation task를 풀어낼 수 있고, zero-shot prediction까지 좋은 성능을 보여준다.

 

위 그림은 GroupViT가 segmentation을 grouping하는 과정이다. 처음에는 여러 개의 작은 group region으로 시작해서 점진적으로 같은 group끼리 merging되는 것을 볼 수 있다.

1. Preliminary

1-1. Segmentation

먼저 segmentaiton이란 무엇일까? Vision 분야에 종사하는 사람들이라면 classification과 object detection task와 함께 많이 들어본 topic 중 하나에 해당될 것이다.

 

출처: http://machinelearningkorea.com/2019/07/13/%EC%9D%B4%EB%AF%B8%EC%A7%80%EB%B3%84-%EB%B6%84%EB%A5%98%EB%AC%B8%EC%A0%9C/

 

위 그림을 보면 쉽게 이해가 가능한데, 간단하게 정리해서

  1. Classification and Localization: Object의 위치와 그 class를 예측하는 것.
  2. Object Detection: 객체를 구분하여 각각의 위치와 class를 예측하는 것.
  3. Semantic Segmentation: pixel 단위로 class를 예측하는 것.
  4. Instance Segmenation: 객체를 구분하여 pixel 단위로 class를 예측하는 것.

이라 말할 수 있겠다. 이번 논문에서는 semantic semgentation을 다루므로 이를 조금 더 살펴보자면,

출처: https://www.jeremyjordan.me/semantic-segmentation/#dilated_convolutions

왼쪽 그림과 같이 pixel level로 (그림에서는 region으로 나오지만) class를 예측하되, 해당 class에 해당하는 object 간의 구별은 하지 않는다 (구별을 하는 것이 Instance Segmentation). 가장 simple하게 이를 딥러닝으로 푸는 방법은 간단하게 오른쪽 그림으로 볼 수 있는데, encoder 또는 encoder-decoder (주로 UNet) 구조를 사용하여 pixel-wise prediction을 하고 이를 ground truth인, pixel-level로 class를 일일이 labeling한, segmentation mask와의 차이를 줄이도록 학습을 하게 된다.

 

필자도 classification과 landmark에 대한 labeling을 해본 경험이 있는데, pixel-level로 class를 찍어주어야 한다는 것이 도대체 얼마나 힘든 작업일지 상상하지 않아도 알 것 같다. 물론, 최근에야 좋은 tools가 많아서 그 수고스러움이 덜할 수는 있겠지만 그런 점을 감안하더라도 segmenation labels을 만드는 일이 비용과 시간 측면에서 매우 비싸다는 것은 모두 동의할 것이다.

 

1-2. Grouping

출처: Combining Top-down and Bottom-up Segmentation

Grouping은 말 그대로 픽셀을 어떤 기준에 따라서 뭉치게 만드는 것을 말한다. Grouping을 사용해 segmentation을 풀어내는 것은 딥러닝이 나오기 전부터 사용되던 방법인데, 이는 크게 두 방법으로 구분 된다.

  • Bottom-up: 이미지를 먼저 여러 클래스를 가지는 작은 영역으로 쪼갠 후에, 이를 합쳐나가는 방식.
  • Top-down: 이미지로부터 큰, 대략적인 features를 추출한 뒤에 이를 통하여 세밀한 segmentation을 찾아가는 방식.

이러한 explicit하게 grouping을 하는 방식들은 딥러닝이 나오면서 자연스럽게 implicit하게 학습이 될 것이라 치부되고 크게 구분짓지 않게 되었다. 예를 들어, Mask-RCNN이나 FCN과 같은 모델의 prediction output과 ground truth labels의 차이를 줄여주는 방식으로 단순하게 학습함에도 불구하고 매우 좋은 성능을 보여주기 때문에, 모델이 알아서 위와 같은 grouping을 학습하겠구나 라고 생각하게 되는 것이다.

아니 그러면 그냥 딥러닝 쓰면 되는거 아닌가요!!!

물론 여전히 딥러닝을 통해 supervised로 학습시키는 방식이 성능 측면에서는 매우 우수하다. 하지만 위에서도 언급했듯이 큰 단점이 존재하는데 먼저 (1) per-pixel humal labels가 필요하다는 점(2) unseen categories에 대해서 generalize 성능이 떨어진다는 점이다. 

 

논문에서는 segmentaiton labels를 사용하지 않는 다는 점을 보완하기 위해서 explicti하게 grouping을 학습할 수 있도록 Vision Transformer와 논문에서 제안하는 grouping module을 사용하고 있다.

1-3. Text Supervision

https://arxiv.org/abs/2103.00020

 

Learning Transferable Visual Models From Natural Language Supervision

State-of-the-art computer vision systems are trained to predict a fixed set of predetermined object categories. This restricted form of supervision limits their generality and usability since additional labeled data is needed to specify any other visual co

arxiv.org

 

최근 CLIP: Learning Transferable Visual Models From Natural Language Supervision 논문이 나오면서 text supervision으로 부터 visual representation을 학습하는 방식이 transfer learning에서 매우 뛰어난 성능을 보이고 있다. 간단하게만 설명을 하자면, 이미지와 text를 같은 embedding space에 mapping시킨 후에, positive pair의 cosine similarity는 커지도록 negative pair의 cosine similarity는 작아지도록 constrative loss를 사용하여 학습을 진행하게 된다. 이렇게 학습된 모델은 zero-shot tranfer에서도 높은 성능을 보여준다.

 

해당 논문에서도 CLIP의 영향을 크게 받아, 학습 과정에서 비슷한 방식으로 constrative loss를 사용하고 있다. 따라서 높은 zero-shot transfer performance를 갖게 되고, unseen categories에 대해서도 segmentation prediction이 가능한 것이다.

 

1-4. Visual Grounding

출처: Revisiting Visual Grounding

visual grounding은 어떻게 보면 text supervision과 비슷한데, image region과 text 간의 correspondence를 배우는 것을 목표로 한다. 간단한 예로, pretrained object detection model로부터 얻어낸 bounding box들과 text 간의 비교를 통해서, 해당되는 object의 좌표만 예측하도록 학습하여 두 modal간의 correspondence를 배우게 된다.

 

위 방식도 image의 ground truth labels를 필요로 하지 않고 text supervision으로부터 정보를 학습하게 되지만, 본 논문과는 크게 2가지의 차이점이 있다. 먼저, 본 논문은 image-text pair를 web에서부터 무분별하게 가져와서 noisy한 데이터를 사용하지만 visual grounding은 사람에 의해 잘 선별된 데이터 pair를 필요로 한다. 두번 째로, GroupViT는 오직 text supervision으로부터 object segments 정보를 추출해내지만, visual grounding은 사전에 학습된 detector를 요구하기 때문에 간접적으로 다른 데이터셋을 참조했다고 볼 수 있다.

1-5. Vision Transformer

논문의 architecture가 Vision Transformer (ViT)의 변형이기 때문에, ViT를 살짝 보고 넘어가려고 한다. Transformer라는 architecture가 Natural Language Processing (NLP) 분야에서 큰 성공을 거두면서, Vision에서도 이를 활용해보려는 여러 시도들이 있었다. 그 중에서 가장 시초가 되는 (실제로 시초는 아니지만 근본) 논문이 Vision Transformer인데, 나온 이후로 그 variants가 쏟아져 나오고 있다.

https://arxiv.org/abs/2010.11929

 

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

While the Transformer architecture has become the de-facto standard for natural language processing tasks, its applications to computer vision remain limited. In vision, attention is either applied in conjunction with convolutional networks, or used to rep

arxiv.org

 

Transformer에 대해서는 다음 글에서 다루고 있으니 미리 읽어보면 더 이해가 빠를 것이다.

https://aistudy9314.tistory.com/63

 

[논문리뷰] Attention is All you need

Transformer는 최근 들어 자연어 처리와 비전 분야 모두에서 월등한 성능을 보이면서 발전하고 있다. 이러한 Transformer를 처음으로 제안한 논문이 바로 "Attention is all you need"이 되시겠다 ㅎㅎ. 자연

aistudy9314.tistory.com

 

모든 pixel에 대해 attention을 하게 되면 resolution의 제곱 수에 해당하는 computational complexity를 갖게 된다. 이는 연산 측면에서 너무나도 비싸기 때문에 ViT에서는 이미지를 16x16 size의 sub-images로 분할하여 이를 tokenizing하는 방식으로 사용한다. 이렇게 tokenized된 embedding vectors를 token이라고 부르고, 이 token들을 그대로 attention block에 적용하는 것이 ViT의 전부이다. ViT는 전체 이미지를 tokenize해서 한번에 attention block에서 처리하기 때문에 이미지를 global하게 볼 수 있다는 장점을 가지고 있다. 필자가 multi-modal 분야를 연구하지 않아서 자세히는 모르지만, modal 간의 aggregation을 할 때 cross-attention이 좋은 성능을 보여주고 있는 것 같다.

 

2. Method

method를 설명하기 앞서, GroupViT의 problem overview를 보도록 하자.

이미지는 GroupViT, text는 Text encoder를 통해 embedding이 되고 그 output들로 contrastive loss를 구하여 representation learning을 하는 점으로보아 전체적은 flow는 CLIP과 매우 유사하다고 볼 수 있다. 이렇게 학습된 모델을 zero-shot tranfer하여 unseen category에 해당하는 labels와 이미지에 대해 segmentation labels을 예측하게 된다.

 

2-1. Grouping Vision Transformer

GroupViT의 핵심이 되는 model architecture이다. 논문에서는 기존 Vision Transformer의 architecture에 group token과 grouping block을 추가하여 explicit하게 모델이 grouping을 학습하도록 유도하고 있다. 위 그림이 모든 것을 설명해주고 있지만 자세히 flow를 따라 살펴보도록 하겠다.

 

(1) 먼저 ViT와 똑같이 이미지를 patch 단위로 나누고 tokenizing을 해준다. 여기서 추가된 점은 group tokens인데, 이는 learnable parameters로써 image tokens들과 concat된다. 다음으로 (2) concat된 tokens를 transformer layers 입력으로 주게 되고, 여기서 image tokens들의 global information을 group tokens와 같이 aggreagte함으로써 group tokens이 유의미한 정보를 포함하게 된다. 그리고 이를 다시 image tokens과 group tokens로 분할하여 (3) Grouping block에 입력으로 준다. Grouping block 또한 결국 attention인데, group tokens와 image tokens 간의 cross attention을 하고 gumbel softmax를 통해 각 image tokens가 어느 group에 속할지에 대한 확률을 곱해준다고 생각하면 된다. 이러한 과정은 여러 개의 small groups를 더 큰 groups으로 merge하는 역할을 하며, output이 group tokens의 개수를 따라감으로 output tokens의 개수가 downsample된다. 그리고 (2)-(3)의 과정을 반복한다. 마지막 layer에서는 최종 output tokens에 대해서 transformer block을 한 후, grouping block 대신 average pooling을 하여 최종 representation을 내보낸다.

 

위 이미지 (a)를 보면, group tokens들이 regular-grid structure (e.g. rectangle)이 아닌 arbitrary shape의 image segment 정보를 포함하고 있는데, 이는 transformer의 flexible한 특성을 더 잘 이용할 수 있게 해주는 GroupViT의 장점 중 하나라 할 수 있다.

 

*Gumble softmax

여기서 gumble-softmax라는 것을 사용하는데, 이는 stochastic term을 stochatic과 deterministic term 두개로 나누어 미분 가능하도록 만드는 trick이다. 필자는 이 논문에서 단순히 $\gamma$ 를 gumber distribution으로 사용하는 거 외에 굳이 왜 쓰는지 잘 모르겠다...stochatic하게 sampling하는 부분이 없어 보이는데...

 

추가적으로 hard label인 one-hot을 사용하기 위해서 다음과 같은 trick을 사용한다.

**one-hot은 discrete하기 때문에 미분이 안되므로

 

$$ \tilde{A}^l = \text{one-hot}(A^l_{argmax})+A^l-\text{sg}(A^l)$$

그냥 discrete value에 continuous 값을 더하고 다시 빼주는데, $A^l$의 gradient를 가질 수 있기 때문에 grouping block이 미분 가능하도록 해준다.

**sg는 stop gradient operator

 

hard label을 사용하지 않고 그냥 softmax output (soft label)을 사용하는 방법도 있는데, 전자가 더 성능이 좋다고 한다.

 

2-2. Learning from Image-Text pairs

위에서 간단히 설명했듯이, Image-Text pairs는 constrastive loss를 통해 학습된다. Image input은 GroupViT를 통해서 representation을 뽑고, Text는 Transformer를 통해 mapping하여 common embedding space에 두 representation을 놓는다. 그 다음 positive image-text pairs를 similarity를 최소화하, negative pairs에 대해서는 최대화하도록 학습을 한다.

 

해당하는 loss는 다음과 같다.

여기서 두 loss term이 존재하는 것을 볼 수 있는데, 차이점은 $L_{T \rightarrow I}$의 경우 prompting engineering을 사용하여 생성한 추가 texts를 사용한다는 점이다. 위 이미지에서 기존 text에서 적절한 noun을 뽑아내어 prompted sentences를 만드는 것을 볼 수 있다. 여기서 $\tau$는 temparature로 높은 값을 가지면 uniform에 가까운 output을 내주고, 작아질 수록 discrete한 확률값을 내뱉는다.  

 

2-3. Zero-shot Transfer to Semantic Segmentation

전체적인 method에 대한 설명은 이걸로 끝이다. 다음으로 zero-shot tranfer에 이를 어떻게 사용할 지를 살펴보도록 하자. 

방법은 매우 간단한데, 학습된 각 encoder로부터 representation을 만들어주고 두 modal간의 similarity를 통해 해당되는 group의 class를 부여한다. 여기서 dataset classes는 사전에 정의되어있고 이를 prompt를 통해 training에 사용했던 prompted text와 같은 형태로 만들어준다. 또한 GroupVit의 output은 average pooling을 하기 전의 값을 사용한다.

 

근데 논문에서는 Decoder에 대한 설명이 없다.....도대체 segmentation output은 어떻게 만드는지 알려주지 않아서 official github code를 뒤져봐야했는데, 자세히 설명하기는 어렵지만 대략적으로 다음과 같이 동작한다.

 

먼저 GroupViT의 가장 마지막 Grouping Block에서 gumbel softmax output에 해당하는 attention map을 가져온다. 이 attention map을 one-hot으로 바꾸어주는 작업을 추가로 해준다. 이 때, attention map의 크기는 $H \times W \times G$의 크기를 가진다. 두번 째로 GroupViT의 output (GAP 이전) 값과 embedde text representation과 dot product를 통해 similarity matrix를 만들어준다. 여기도 추가로 후처리가 있는 것 같은데 그에 대한 내용은 생략하고 이 similarity matrix의 shape가 $G \times C$가 되므로, 이를 attention map과 dot product 시키면 $H \times W \times C$가 되고 이를 segmentation output으로 내보낸다. 더 자세한 알고리즘이 알고 싶다면 해당 Github의 코드를 분석해보길 바란다.

**G는 그룹 개수, C는 class 개수라고 보면 된다.

 

https://github.com/NVlabs/GroupViT

 

GitHub - NVlabs/GroupViT: Official PyTorch implementation of GroupViT: Semantic Segmentation Emerges from Text Supervision, CVPR

Official PyTorch implementation of GroupViT: Semantic Segmentation Emerges from Text Supervision, CVPR 2022. - GitHub - NVlabs/GroupViT: Official PyTorch implementation of GroupViT: Semantic Segmen...

github.com

 

3. Experiments

먼저 간단하게 실험 setting에 대해서 정리하겠다.

  • Architecture: ViT-S with 12 Transformer layers.
  • Datasets: CC and the filtered YFCC.
  • Zero-shot Transfer: PASCAL VOC and PASCAL Context.

더 자세한 내용은 논문을 참고하길 바란다.

 

3-1. Ablation Study

 

논문에서 사용된 방법 중에서 hyper parameter를 바꾸었을 때의 실험이다. 결과만 살펴보자면, 먼저 hard assignment를 사용하는 것이 성능 향상 폭이 높고, multi-label loss (prompted text를 추가한 loss)를 더해주었을 때 성능이 좋은 것을 확인할 수 있다.

 

다음으로는 group tokens의 개수와 output tokens의 개수에 대한 결과인데 일반적으로 group tokens의 개수가 높을 수록 성능이 좋은 것을 볼 수 있다. 저자의 말로는 각 group token은 disticnt한 semantic concepts을 나타내고 있기 때문에 그 수가 많으면 그만 큼 도움이 될 수 있다고 한다. output tokens은 heuristic하게 8개가 optimal이라고 한다.

 

마지막으로 single stage와 2 stage architecture에 대한 성능 비교인데, stage란 Grouping Block의 개수를 말한다. 즉, single stage는 한번의 grouping만 거친다는 의미고, 2 stage는 두번 그 과정을 수행한다는 의미이다. 직관적으로 한번에 grouping을 하게 되면 세밀한 조정이 그만큼 어려우므로 2-stage의 성능이 더 높은 것을 알 수 있다.

**Grouping block에서 gumbel softmax output을 그대로 사용하면 soft, straingt through trick을 사용하면 hard.

 

3-2. Comparison with Existing Methods

먼저 zero-shot에 대한 비교이다. ViT를 CLIP과 같이 학습시킨 후에 여러 method를 적용한 것인데 이게 existing methods와 비교했다고 할 수 있나?? 라는 생각이 들었다. GroupViT 외에도 zero-shot segmentation을 적용한 methods가 있을텐데...왜 저런 식으로 비교했는지는 알 수 없다.

 

두번 째로는 fully-supervised transfer에 대한 비교이다. 이번에는 제대로 다른 논문들의 모델과 비교를 하였는데, PASCAL VOC에서 완전 supervision인 모델의 성능과도 comparable하고 그 외 다른 self-supervision 모델보다는 확연히 성능 차이가 난다.

qualiatative results로 보았을 때에도 semnatic segmenation이 아주 잘 되는 것을 볼 수 있다. 물론 체리피킹이겠지만 segmenation lbaels 없이 이러한 결과가 나온다는 점이 매우 흥미로웠다.