본문 바로가기

책/파이토치 트랜스포머를 활용한 자연어 처리와 컴퓨터비전 심층학습

10 비전 트랜스포머 (2) Swin Transformer pytorch 실습

Swin Transformer의 이론적인 내용은 이전 포스팅을 참고 바랍니다. 

2024.03.09 - [논문 리뷰] - Swin Transformer: Hierarchical Vision Transformer using Shifted Windows

 

 

허깅페이스 라이브러리로 사전학습된 Swin 모델을 FashionMNIST 데이터셋에 미세 조정하는 실습을 해본다. 학습 전에 patch embedding 출력과 patch merging 출력 등을 확인하여 논문에서 배운것과 동일한지 확인해보고 모델 학습을 진행한다.

 

 

사전 학습된 Swin 모델 불러오기

"microsoft/swin-tiny-patch4-window7-224" 모델을 사용한다. ImageNet-1K 데이터셋으로 224x224, 패치 해상도 4x4, local window 7x7로 학습됐다.

from transformers import SwinForImageClassification

model = SwinForImageClassification.from_pretrained(
    pretrained_model_name_or_path="microsoft/swin-tiny-patch4-window7-224",
    num_labels=len(classes),
    id2label={idx: label for label, idx in class_to_idx.items()},
    label2id=class_to_idx,
    ignore_mismatched_sizes=True
)

 

 

모델 구조는 크게 swin과 classifier로 구성된다. swin은 embeddings, encoder(stage 0, 1, 2, 3), layernorm, pooler로 구성된다. 

for main_name, main_module in model.named_children():
    print(main_name)
    for sub_name, sub_module in main_module.named_children():
        print("└", sub_name)
        for ssub_name, ssub_module in sub_module.named_children():
            print("│  └", ssub_name)
            for sssub_name, sssub_module in ssub_module.named_children():
                if sssub_name == "projection":
                    print("│  │  └", sssub_name, sssub_module)
                else:
                    print("│  │  └", sssub_name)
swin
└ embeddings
│  └ patch_embeddings
│  │  └ projection Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
│  └ norm
│  └ dropout
└ encoder
│  └ layers
│  │  └ 0
│  │  └ 1
│  │  └ 2
│  │  └ 3
└ layernorm
└ pooler
classifier

 

 

patch partition은 patch_embeddings에서 수행되며, 모듈은 Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))를 사용한다. 224x224 이미지가 이를 통과하면 (224 - 4(kernel_size)) / 4(stride) + 1 = 56이 되어 56x56 텐서가 생성된다. 그리고 이 텐서를 일렬로 나열하면 3136(=56x56)개의 패치가 생성된다. 모듈의 tiny 모델은 다음과 같이 특성 차원이 96으로 설정되어있다.

batch = next(iter(train_dataloader))
print("이미지 차원 : ",batch["pixel_values"].shape)

patch_emb_output, shape = model.swin.embeddings.patch_embeddings(batch["pixel_values"])
print("모듈:", model.swin.embeddings.patch_embeddings)
print("패치 임베딩 차원 :", patch_emb_output.shape)
이미지 차원 :  torch.Size([32, 3, 224, 224])
모듈: SwinPatchEmbeddings(
  (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
)
패치 임베딩 차원 : torch.Size([32, 3136, 96])

 

 

다음은 스테이지 1에 대한 구조를 보여준다. 2개의 swin transformer block과 패치 병합을 수행하는 downsample로 구성되어있다.

for main_name, main_module in model.swin.encoder.layers[0].named_children():
    print(main_name)
    for sub_name, sub_module in main_module.named_children():
        print("└", sub_name)
        for ssub_name, ssub_module in sub_module.named_children():
            print("│ └", ssub_name)
blocks
└ 0
│ └ layernorm_before
│ └ attention
│ └ drop_path
│ └ layernorm_after
│ └ intermediate
│ └ output
└ 1
│ └ layernorm_before
│ └ attention
│ └ drop_path
│ └ layernorm_after
│ └ intermediate
│ └ output
downsample
└ reduction
└ norm

 

 

block의 어텐션을 수행하여도 차원은 동일하게 유지된다.

print("패치 임베딩 차원 :", patch_emb_output.shape)

W_MSA = model.swin.encoder.layers[0].blocks[0]
SW_MSA = model.swin.encoder.layers[0].blocks[1]

W_MSA_output = W_MSA(patch_emb_output, W_MSA.input_resolution)[0]
SW_MSA_output = SW_MSA(W_MSA_output, SW_MSA.input_resolution)[0]

print("W-MSA 결과 차원 :", W_MSA_output.shape)
print("SW-MSA 결과 차원 :", SW_MSA_output.shape)
패치 임베딩 차원 : torch.Size([32, 3136, 96])
W-MSA 결과 차원 : torch.Size([32, 3136, 96])
SW-MSA 결과 차원 : torch.Size([32, 3136, 96])

 

 

그런 다음 패치 병합을 통해 3136개의 패치가 784(=28x28)개로 병합된다. 이 과정에서 384(=96x4) 특징 차원을 선형 변환을 통해 절반인 192로 감소시킨다.

patch_merge = model.swin.encoder.layers[0].downsample
print("patch_merge 모듈 :", patch_merge)

output = patch_merge(SW_MSA_output, patch_merge.input_resolution)
print("patch_merge 결과 차원 :", output.shape)
patch_merge 모듈 : SwinPatchMerging(
  (reduction): Linear(in_features=384, out_features=192, bias=False)
  (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
)
patch_merge 결과 차원 : torch.Size([32, 784, 192])

 

 

모델 학습

모델 초기화 함수와 평가 메트릭을 정의한다.

def model_init(classes, class_to_idx):
    model = SwinForImageClassification.from_pretrained(
        pretrained_model_name_or_path="microsoft/swin-tiny-patch4-window7-224",
        num_labels=len(classes),
        id2label={idx: label for label, idx in class_to_idx.items()},
        label2id=class_to_idx,
        ignore_mismatched_sizes=True
    )
    return model
import evaluate

def compute_metrics(eval_pred):
    metric = evaluate.load("f1")
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    macro_f1 = metric.compute(
        predictions=predictions, references=labels, average="macro"
    )
    return macro_f1

 

 

TrainingArguments를 정의하고 이를 Trainer 클래스에 전달한 뒤 모델을 학습한다.

from transformers import Trainer, TrainingArguments


args = TrainingArguments(
    output_dir="../models/Swin-FashionMNIST",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.001,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    logging_dir="logs",
    logging_steps=125,
    remove_unused_columns=False,
    seed=7
)

trainer = Trainer(
    model_init=lambda x: model_init(classes, class_to_idx),
    args=args,
    train_dataset=subset_train_dataset,
    eval_dataset=subset_test_dataset,
    data_collator=lambda x: collator(x, transform),
    compute_metrics=compute_metrics,
    tokenizer=image_processor,
)
trainer.train()

 

 

평가 결과는 다음과 같다. ViT와 비교하였을 때 Shirts에 대한 recall이 많이 증가하였다.

왼쪽 ViT, 오른쪽 Swin