본문 바로가기

책/파이토치 트랜스포머를 활용한 자연어 처리와 컴퓨터비전 심층학습

10 비전 트랜스포머 (1) ViT

이전 포스팅에서 ViT 논문을 리뷰한 적이 있어서 여기서는 책의 내용을 몇 가지만 요약하고 파이토치 실습을 진행한다.

2024.01.23 - [논문 리뷰] - An Image Is Worth 16x16 Words : Transformers For Image Recognition At Scale

 

 

1. 합성곱 모델과 ViT 모델 비교

source : https://ogre51.medium.com/vit-vision-transformers-an-introduction-dee8161f2caa

 

ViT는 이미지 시퀀스의 순서를 왼쪽에서 오른쪽 그리고 위에서 아래로 정하기 때문에 왼쪽 위의 패치를 1, 오른쪽 아래를 9라고 할 수 있다. 예를 들어 강아지의 오른쪽 귀에 대한 특징을 얻고 싶을 때 ViT는 셀프 어텐션으로 모든 이미지 패치가 서로에게 주는 영향을 고려해 이미지의 특징을 추출한다. 

 

반면 CNN은 오른쪽 귀에 대한 특징을 2, 3, 5, 9 패치만 관여한다. 따라서 좁은 수용 영역을 가진 CNN은 전체 이미지 정보를 표현하는 데 수많은 게층이 필요하지만, ViT는 어텐션 거리(attention distance)를 계산하여 오직 한 개의 ViT 레이어로 전체 이미지 정보를 표현할 수 있다. 어텐션 거리는 query와 value 벡터 사이의 유사도를 내적으로 계산한 것을 의미한다.

 

다만 ViT는 입력 이미지의 크기가 고정되어 있어 이에 대한 전처리가 필요하며, CNN이 공간적인 위치 정보를 고려하는 데 비해 ViT는 패치간의 상대적인 위치 정보만 고려하기 때문에 이미지 변환에 취약할 수 있다.

 

 

2. ViT의 귀납적 편향

딥러닝 모델의 귀납적 편향(inductive bias)은 일반화 성능 향상을 위한 모델의 가정을 의미한다. 예를 들어 지역적 편향을 가진 CNN은 공간적 관계를 잘 표현하는 이미지 데이터에 많이 활용된다.

 

반면 시계열 데이터는 시간적 관계를 잘 표현하므로 sequential 편향을 가진 순환 신경망 모델이 많이 사용된다.

 

ViT는 입력 데이터의 다양한 QKV의 임베딩 형태로 일반화된 관계를 학습하기 때문에 귀납적 편향이 거의 없다. 귀납적 편향은 해당 모델이 가지는 구조와 매개변수들이 데이터에 적합한 가정을 하고 있음을 나타낸다. 이러한 가정이 올바르다면 높은 일반화 성능을 보일 수 있지만, 너무 강한 귀납적 편향은 다른 유형의 데이터나 관계를 표현하는 데 어려움을 초래할 수 있다. 따라서 다양한 관계를 표현하기 위해 귀납적 편향이 약한 모델을 선호하게 된다.

 

 

3. ViT 모델 실습

Colab에서 허깅 페이스 라이브러리와 FashionMNIST 데이터셋을 활용해 ViT 모델을 미세조정해 본다.

 

!pip install evaluate accelerate transformers -U
import evaluate
import numpy as np
import matplotlib.pyplot as plt
from itertools import chain
from collections import defaultdict
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import torch
from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms

from transformers import AutoImageProcessor, ViTForImageClassification, TrainingArguments, Trainer

 

 

FashionMNIST 다운로드

  • torchvision.datasets의 FashionMNIST 클래스로 지정된 경로에 다운로드 할 수 있다.
  • label2id와 id2label은 모델을 초기화할 때 사용된다.
train_dataset = datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
)
test_dataset = datasets.FashionMNIST(
    root="./data",
    train=False,
    download=True,
)

classes = train_dataset.classes
label2id = train_dataset.class_to_idx
id2label = {idx:label for label, idx in label2id.items()}

print(classes)
print(label2id)
['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
{'T-shirt/top': 0, 'Trouser': 1, 'Pullover': 2, 'Dress': 3, 'Coat': 4, 'Sandal': 5, 'Shirt': 6, 'Sneaker': 7, 'Bag': 8, 'Ankle boot': 9}

 

  • 간단한 실습을 위해 전체 데이터가 아닌 일부 데이터만 학습에 사용할 것이기 때문에 subset_sampler 함수를 정의한다.
  • 딕셔너리 target_idx에 label별로 데이터의 index를 저장한다.
  • 여러개의 iterable을 하나의 iterable로 결합해주는 chain 함수를 사용해 label별로 max_len개 만큼 샘플링하여 indices 변수에 해당하는 데이터의 index를 저장한다.
  • Subset 클래스에 dataset과 indices를 전달하여 리턴한다.
def subset_sampler(dataset, classes, max_len):
  target_idx = defaultdict(list)
  for idx, label in enumerate(dataset.targets):
    target_idx[int(label)].append(idx)
  
  indices = list(
      chain.from_iterable(
          [target_idx[label][:max_len] for label in range(len(classes))]
      )
  )
  return Subset(dataset, indices)
subset_train_dataset = subset_sampler(
    dataset=train_dataset, classes=train_dataset.classes, max_len=1000
)
subset_test_dataset = subset_sampler(
    dataset=test_dataset, classes=test_dataset.classes, max_len=100
)

print(f"Training Data Size : {len(subset_train_dataset)}")
print(f"Testing Data Size : {len(subset_test_dataset)}")
Training Data Size : 10000
Testing Data Size : 1000

 

 

이미지 전처리

  • AutoImageProcessor는 주어진 모델에 대한 사전 학습된 전처리기를 쉽게 로드할 수 있게 도와주는 클래스다.
  • "google/vit-base-patch16-224-in21k"은 ImageNet-21k 데이터셋으로 사전 훈련되었으며, 224x224 크기의 이미지를 입력 받고 패치 해상도 16인 모델이다. 
  • ViT는 채널 수가 3인 데이터를 입력받기 때문에 Lambda로 단일 채널을 복제해 다중 채널 이미지로 변환해준다.
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((image_processor.size["height"], image_processor.size["width"])),
        transforms.Lambda(lambda x: torch.cat([x, x, x], 0)),
        transforms.Normalize(
            mean=image_processor.image_mean,
            std=image_processor.image_std,
        )
    ]
)
vars(image_processor)
{'_processor_class': None,
 'do_resize': True,
 'do_rescale': True,
 'do_normalize': True,
 'size': {'height': 224, 'width': 224},
 'resample': <Resampling.BILINEAR: 2>,
 'rescale_factor': 0.00392156862745098,
 'image_mean': [0.5, 0.5, 0.5],
 'image_std': [0.5, 0.5, 0.5]}

 

 

데이터로더 적용

ViT 모델의 입력은 {"pixel_values": pixel_values, "labels":labels} 형태를 입력 받기 때문에 batch를 적절한 형태로 변환해주는 collator 함수를 정의한다.

def collator(batch, transform):
  images, labels = zip(*batch)
  pixel_values = torch.stack([transform(image) for image in images])
  labels = torch.tensor([label for label in labels])
  return {"pixel_values": pixel_values, "labels": labels}
train_dataloader = DataLoader(
    dataset=subset_train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=lambda x: collator(x, transform),
    drop_last=True,
)
valid_dataloader = DataLoader(
    dataset=subset_test_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=lambda x: collator(x, transform),
    drop_last=True,
)

 

 

사전 학습된 ViT 모델 불러오기

model = ViTForImageClassification.from_pretrained(
  pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
  num_labels=len(classes),
  id2label=id2label,
  label2id=label2id,
  ignore_mismatched_sizes=True,
)

 

모델 구조는 다음과 같다.

for name, module in model.named_children():
  print(name)
  for name2, module2 in module.named_children():
    print('L', name2)
    for name3, module3 in module2.named_children():
      print('   L', name3)
vit
L embeddings
   L patch_embeddings
   L dropout
L encoder
   L layer
L layernorm
classifier

 

표준 트랜스포머는 1차원 시퀀스를 입력으로 받기 때문에 2D 이미지를 처리할 필요가 있다. 2D 이미지 x에 대해서 다음과 같이 조정한다.

 

$(H, W, C) \longrightarrow (N, P^2 {\cdot}C)$

 

$(P, P)$는 각 이미지 패치의 해상도를 의미하며, $N=HW/P^2$로 계산되는 패치의 개수이다. 트랜스포머는 모든 레이어에서 일정한 잠재 벡터 크기 D를 사용하므로 패치를 평탄화하고 학습 가능한 선형 투영을 사용하여 D차원으로 매핑한다. 이 투영의 출력을 패치 임베딩이라고 한다.

 

현재 실습에서 H, W, P는 각각 224, 224, 16이기 때문에 N은 196(=224*224/16^2)이다. 또한 P^2*C는 768으로 계산된다. 이것이 정확한지 확인해보자.

 

batch = next(iter(train_dataloader))
print("image_shape : ", batch["pixel_values"].shape)
print("patch embeddings shape :",
    model.vit.embeddings.patch_embeddings(batch["pixel_values"]).shape
)
print("[CLS] + patch embeddings shape :",
    model.vit.embeddings(batch["pixel_values"]).shape
)
image_shape :  torch.Size([32, 3, 224, 224])
patch embeddings shape : torch.Size([32, 196, 768])
[CLS] + patch embeddings shape : torch.Size([32, 197, 768])

 

 

하이퍼파라미터 설정

TrainingArguments 클래스는 모델 학습에 필요한 다양한 인자들을 저장하고 관리할 수 있다.

주요 매개변수를 표로 정리하면 다음과 같다.

매개변수 의미
output_dir 체크포인트 저장 경로
save_strategy 체크포인트 저장 간격 설정
no : 저장 안함
steps : 스텝마다
epoch : 에폭마다
evaluation_strategy 체크포인트 평가 간격 설정
no : 저장 안함
steps : 스텝마다
epoch : 에폭마다
learning_rate 초기 학습률
per_device_train_batch_size 학습 배치 크기
load_best_model_at_end 모델 불러오기 시 최상의 모델 선택 여부
metric_for_best_model 최상의 모델 선정 기준이 되는 평가 방식 설정
accuracy, precision, f1, mae, mse, rmse
logging_dir 로그 저장 경로
logging_steps 로그 출력 간격

 

args = TrainingArguments(
    output_dir="../models/ViT-FashionMNIST",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    logging_dir="logs",
    logging_steps=10,
    remove_unused_columns=False,
    seed=42,
)

 

평가 함수는 허깅페이스의 evaluate 라이브러리를 활용하여 정의한다.

def compute_metrics(eval_pred):
  metric = evaluate.load("f1")
  predictions, labels = eval_pred
  predictions = np.argmax(predictions, axis=1)
  macro_f1 = metric.compute(
      predictions=predictions,
      references=labels,
      average="macro"
  )
  return macro_f1

 

 

학습

Trainer 클래스로 학습을 수행하면 따로 데이터로더를 만들 필요없이 학습 및 평가 데이터셋과 collator 함수를 전달할 수 있다.

def model_init(classes, label2id, id2label):
  model = ViTForImageClassification.from_pretrained(
      pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
      num_labels=len(classes),
      id2label=id2label,
      label2id=label2id,
  )
  return model
trainer = Trainer(
    model_init=lambda x: model_init(classes, label2id, id2label),
    args=args,
    train_dataset=subset_train_dataset,
    eval_dataset=subset_test_dataset,
    data_collator=lambda x: collator(x, transform),
    compute_metrics=compute_metrics,
    tokenizer=image_processor,
)
trainer.train()

 

학습 결과는 다음과 같다.

 

 

평가

confusion matrix를 활용해 모델의 성능을 평가해본다. 우선 모델의 출력 형식은 다음과 같다.

outputs = trainer.predict(subset_test_dataset)
print(outputs)
PredictionOutput(predictions=array([[ 2.620061  , -0.75976956, -0.22009146, ..., -0.43575624,
        -0.26031983, -0.3763049 ],
       [ 1.7169125 , -0.325412  ,  0.16614406, ..., -0.7760629 ,
        -0.4124354 , -0.75789213],
       [ 2.4755652 , -0.70636165, -0.18435624, ..., -0.50174546,
        -0.37349728, -0.44415036],
       ...,
       [-0.5555925 , -0.46106488, -0.5794496 , ...,  0.8493119 ,
        -0.07648225,  2.379642  ],
       [-0.3728562 , -0.44994363, -0.44079086, ...,  0.53375655,
        -0.40643337,  2.8023272 ],
       [-0.4295409 , -0.30554977, -0.4783806 , ..., -0.14260468,
        -0.42149764,  2.8321877 ]], dtype=float32), 
       label_ids=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       ...,
       9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
       9, 9, 9, 9, 9, 9, 9, 9, 9, 9]), 
       metrics={'test_loss': 0.5899580717086792, 'test_f1': 0.9075374774369266, 'test_runtime': 13.3799, 'test_samples_per_second': 74.739, 'test_steps_per_second': 2.392})

 

다음은 confusion matrix를 시각화하는 방법이다.

y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)

labels = list(classes)
matrix = confusion_matrix(y_true, y_pred)
display = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=labels)
_, ax = plt.subplots(figsize=(10, 10))
display.plot(xticks_rotation=45, ax=ax)
plt.show()

 

결과를 보면 Shirt를 T-shirt, Pullover, Coat로 오분류한 경우가 많은 것을 확인할 수 있다. 이를 보고 오분류된 클래스에 대한 문제점을 파악하고 모델을 개선해 정확도를 향상시킬 수 있다.

 

모델 개선 방법으로는 하이퍼파라미터 조정, 데이터 증강, 전처리 과정 개선, 모델 구조 변경 등이 있다.