본문 바로가기

책/파이토치 트랜스포머를 활용한 자연어 처리와 컴퓨터비전 심층학습

07 트랜스포머 (4) BART

1. BART

BART(Bidirectional Auto Regressive Transformer)는 2019년 메타의 FAIR 연구소에서 발표한 트랜스포머 기반의 모델이다.

 

 

BART는 BERT의 인코더와 GPT의 디코더를 결합한 seq2seq 구조노이즈 제거 오토인코더(Denoising Autoencoder)로 사전 학습된다. 이는 입력 데이터에 노이즈를 추가하고, 노이즈가 없는 원본 데이터를 복원하도록 학습하는 방식으로 수행된다.

 

트랜스포머에서는 인코더의 모든 계층과 디코더의 모든 계층 사이의 어텐션 연산을 수행했다면, BART는 인코더의 마지막 계층과 디코더의 각 계층 사이에만 어텐션 연산을 수행한다.

 

 

2. 사전 학습 방법

사전 학습에 사용한 기법은 토큰 마스킹(Token Masking), 토큰 삭제(Token Deletion), 문장 순열(Sentence Permutation), 문서 회전(Document Rotation) 그리고 텍스트 채우기(Text Infilling)가 있다.

 

  1. 토큰 마스킹 : BERT와 동일
  2. 토큰 삭제 : 입력 문장의 일부 토큰을 치환하는 것이 아니라, 삭제하는 방법이다. 모델은 어떤 위치의 토큰이 삭제되었는지도 맞춰야 한다. 이 기법은 입력 문장에서 불필요한 정보나 중요하지 않은 정보를 자동으로 필터링해 처리할 수 있게 되며, 모델의 학습과 예측 시간을 줄이고, 모델의 일반화 성능을 향상시킬 수 있다. 이 기법으로 인해 문장 요약 작업을 수행할 수 있는 것이 아닌가 하는 생각이 든다.
  3. 문장 순열 : 마침표를 기준으로 문장을 나눈뒤, 문장의 순서를 섞는 방법이다. 모델은 원래의 문장 순서를 맞춰야 한다.
  4. 문서 회전 : 임의의 토큰으로 문서가 시작하도록 하되, 문장 순열과는 다르게 문장의 순서는 유지한다. 모델은 문서의 원래 시작 토큰을 맞춰야 한다.
  5. 텍스트 채우기 : $\lambda = 3$인 푸아송 분포를 따르는 text span들을 뽑는다. text span은 몇 개의 토큰을 하나의 구간(span)으로 묶은 것을 의미한다. 그리고 이 중 일부를 마스크 토큰으로 대체한다. 모델은 연속된 마스크 토큰을 복구하되, 실제로는 마스킹되지 않은 토큰도 구분해야 한다.

 

예시를 들어보면 다음과 같다.

 

원본 문장 : 그는 정문으로 발을 옮겼다. 그의 아내는 멀어지는 그를 바라보고 있었다. 그녀는 눈물을 흘렸다. 그도 마찬가지였다.

 

토큰 마스킹

그는 _ 발을 옮겼다. _ 아내는 멀어지는 그를 _ 있었다. 그녀는 눈물을 _. 그도 _였다.

 

토큰 삭제

그는 발을 옮겼다. 아내는 멀어지는 그를 있었다. 그녀는 눈물을 . 그도 였다.

 

문장 순열

그녀는 눈물을 흘렸다. 그는 정문으로 발을 옮겼다. 그의 아내는 멀어지는 그를 바라보고 있었다. 그도 마찬가지였다.

 

문서 회전

그를 바라보고 있었다. 그녀는 눈물을 흘렸다. 그도 마찬가지였다. 그는 정문으로 발을 옮겼다. 그의 아내는 멀어지는

 

텍스트 채우기

그는 옮겼다. 그의 아내는 멀어지는 . 그녀는 눈물을 흘렸다. 그도 .

 

 

3. Fine Tuning

인코더와 디코더를 모두 사용하는 구조를 가지고 있기 때문에 미세 조정 시 각 다운스트림 작업에 맞게 입력 문장을 구성해야 한다. 즉, 인코더와 디코더에 다른 문장 구조로 입력한다.

 

문장 분류 : 입력 문장을 인코더와 디코더에 동일하게 입력하고 디코더의 마지막 토큰이 위치한 은닉 상태를 선형 분류기의 입력값으로 사용한다.

 

토큰 분류(품사 태깅) : 입력 문장을 인코더와 디코더에 동일하게 입력한다. 디코더의 마지막 은닉 상태에서 각각의 토큰이 위치한 부분의 은닉 상태를 토큰 분류기의 입력값으로 사용한다.

 

문장 생성 : BART는 autoregressive decoder를 갖고 있으므로 abstractive question answering나 summarization와 같은 생성 작업에 바로 적용할 수 있다. 두 가지 모두 정보를 입력값을 조작해 출력을 생성하는 작업이다. 이는 BART의 사전 학습 방식과 유사하기 때문에 뛰어난 성능을 보인다.

 

기계 번역 : 사전 학습된 인코더에 기계 번역을 위한 인코더를 추가해 작업을 수행할 수 있다. 추가된 인코더는 기존의 단어 사전을 사용하지 않아도 되며, 디코더는 사전 학습된 가중치와 단어 사전을 사용한다. 먼저 기존에 있던 인코더와 디코더의 가중치를 동결하고 추가된 인코더의 가중치를 학습한다. 그런 다음 전체 가중치를 학습한다.

 

 

4. 모델 실습

허깅페이스 라이브러리의 BART 모델과 뉴스 요약 데이터세트를 활용해 문장 요약 모델을 미세 조정해 본다.

 

!pip install datasets portalocker -U

 

뉴스 요약 데이터세트는 미국의 인공지능 회사인 Argilla에서 공개한 데이터세트로 뉴스 본문과 본문의 요약 텍스트로 구성된다. 허깅페이스에서 제공하는 datasets 라이브러리를 통해 불러온다.

 

 

뉴스 요약 데이터세트 불러오기

import numpy as np
from datasets import load_dataset


news = load_dataset("argilla/news-summary", split="test")
df = news.to_pandas().sample(5000, random_state=42)[["text", "prediction"]]
df["prediction"] = df["prediction"].map(lambda x: x[0]["text"])
train, valid, test = np.split(
    df.sample(frac=1, random_state=42), [int(0.6 * len(df)), int(0.8 * len(df))]
)

print(f"Source News : {train.text.iloc[0][:200]}")
print(f"Summarization : {train.prediction.iloc[0][:50]}")
print(f"Training Data Size : {len(train)}")
print(f"Validation Data Size : {len(valid)}")
print(f"Testing Data Size : {len(test)}")
Source News : DANANG, Vietnam (Reuters) - Russian President Vladimir Putin said on Saturday he had a normal dialogue with U.S. leader Donald Trump at a summit in Vietnam, and described Trump as civil, well-educated
Summarization : Putin says had useful interaction with Trump at Vi
Training Data Size : 3000
Validation Data Size : 1000
Testing Data Size : 1000

 

 

BART 입력 텐서 생성

import torch
from transformers import BartTokenizer
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import RandomSampler, SequentialSampler
from torch.nn.utils.rnn import pad_sequence


def make_dataset(data, tokenizer, device):
    tokenized = tokenizer(
        text=data.text.tolist(),
        padding="longest",
        truncation=True,
        return_tensors="pt"
    )
    labels = []
    input_ids = tokenized["input_ids"].to(device)
    attention_mask = tokenized["attention_mask"].to(device)
    for target in data.prediction:
        labels.append(tokenizer.encode(target, return_tensors="pt").squeeze())
    labels = pad_sequence(labels, batch_first=True, padding_value=-100).to(device)
    return TensorDataset(input_ids, attention_mask, labels)



def get_datalodader(dataset, sampler, batch_size):
    data_sampler = sampler(dataset)
    dataloader = DataLoader(dataset, sampler=data_sampler, batch_size=batch_size)
    return dataloader


epochs = 5
batch_size = 8
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = BartTokenizer.from_pretrained(
    pretrained_model_name_or_path="facebook/bart-base"
)

train_dataset = make_dataset(train, tokenizer, device)
train_dataloader = get_datalodader(train_dataset, RandomSampler, batch_size)

valid_dataset = make_dataset(valid, tokenizer, device)
valid_dataloader = get_datalodader(valid_dataset, SequentialSampler, batch_size)

test_dataset = make_dataset(test, tokenizer, device)
test_dataloader = get_datalodader(test_dataset, SequentialSampler, batch_size)

print(train_dataset[0])
(tensor([   0,  495, 1889,  ...,    1,    1,    1], device='cuda:0'), tensor([1, 1, 1,  ..., 0, 0, 0], device='cuda:0'), tensor([    0, 35891,   161,    56,  5616, 10405,    19,   140,    23,  5490,
         3564,     2,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
       device='cuda:0'))

 

이전 포스팅인 BERT를 실습할 때는 텍스트 감성 분류였기 때문에 label이 정수였지만, 이번에는 label 데이터 또한 텍스트이다. 따라서 tokenizer의 encode 메서드로 문장을 토큰화해준다. 문장의 길이가 각각 다르기 때문에 torch.nn.utils.rnn의 pad_sequence 함수로 label 데이터를 패딩해주어 길이를 맞춰준다.

 

이때 패딩 값은 -100을 사용하는데 이는 손실 함수에서 패딩된 토큰을 무시하기 위해 사용한다.

 

 

BART 모델 선언

from torch import optim
from transformers import BartForConditionalGeneration


model = BartForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path="facebook/bart-base"
).to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-5, eps=1e-8)

for main_name, main_module in model.named_children():
    print(main_name)
    for sub_name, sub_module in main_module.named_children():
        print("└", sub_name)
        for ssub_name, ssub_module in sub_module.named_children():
            print("│  └", ssub_name)
            for sssub_name, sssub_module in ssub_module.named_children():
                print("│  │  └", sssub_name)
model
└ shared
└ encoder
│  └ embed_tokens
│  └ embed_positions
│  └ layers
│  │  └ 0
│  │  └ 1
│  │  └ 2
│  │  └ 3
│  │  └ 4
│  │  └ 5
│  └ layernorm_embedding
└ decoder
│  └ embed_tokens
│  └ embed_positions
│  └ layers
│  │  └ 0
│  │  └ 1
│  │  └ 2
│  │  └ 3
│  │  └ 4
│  │  └ 5
│  └ layernorm_embedding
lm_head

 

BartForConditionalGeneration 클래스는 BART 모델의 변형 중 하나로 조건부 생성 작업에 특화된 모델이다. 예를 들어 문장 요약, 기계 번역, 질의응답 등과 같은 작업에 사용할 수 있다.

 

BART는 인코더와 디코더가 동일한 임베딩 계층을 사용한다. shared 계층은 인코더와 디코더가 공유하는 임베딩 계층을 의미하며, 이러한 공유로 인코더와 디코더 간의 연결을 강화한다.

 

 

ROUGE(Recall-Oriented Understudy for Gisting Evaluation)

평가는 루지 점수를 사용한다. 이는 생성된 요약문과 정답 요약문이 얼마나 유사한지를 평가하기 위해 토큰의 N-gram 정밀도와 재현율을 이용해 평가하는 지표다.

 

예를 들어, 유니그램을 사용하면 ROUGE-1, 바이그램은 ROUGE-2, N-gram을 사용하면 ROUGE-N이라고 한다.

 

계산 방법의 예시를 보자.

 

정답 문장 유니그램 : [대한민국, 은, 16, 강, 에, 진출, 했다]

생성된 문장 유니그램 : [대한민국, 은, 8, 강, 에, 진출, 하지, 못, 했다]

ROUGE-1 재현율 : (둘 다 등장한 토큰 / 정답 문장에 등장한 토큰) = 6/7

ROUGE-1 정밀도 : (둘 다 등장한 토큰 / 생성된 문장에 등장한 토큰) = 6/9

 

이외에도 ROUGE-L, ROUGE-LSUM, ROUGE-W 등이 있다.

 

 

ROUGE-L

ROUGE-L은 생성된 요약문과 정답 요약문 사이에서 최장 공통부분 수열(Longest Common, Subseqeunce, LCS) 기반의 통계 방식이다.

 

위의 예시에서 "대한민국 은 강 에 진출" 이렇게 5개의 토큰이 LCS이기 때문에 5/7로 계산한다.

 

 

ROUGE-LSUM

ROUGE-LSUM은 ROUGE-L의 변형으로 텍스트 내의 개행 문자(\n)를 문장 경계로 인식하고, 각 문장 쌍에 대해 LCS를 계산한 후, union-LCS라는 값을 계산한다. union-LCS는 각 문장 쌍의 LCS를 합집합 한 것으로, 중복된 부분을 제거한 후 길이를 계산한다.

 

정답 문장 : "John is a talented musician.\n He has band called (The Forest Rangers) Recently"

생성 문장 : "John is an accomplished artist.\n He is part of band called (The Forest Rangers)"

 

첫 번째 문장 : "John is"가 LCS이므로 점수는 2/5이다.두 번째 문장 : "He band called (The Forest Rangers)"가 LCS이므로 점수는 6/8이다.

 

그런 다음 위의 두 점수를 평균내어 최종 점수를 계산한다. ROUGE-LSUM은 요약의 정확성과 완전성을 모두 반영할 수 있는 지표로 사용된다.

 

 

ROUGE-W

ROUGE-W는 가중치가 적용된 LCS 방법으로 연속된 LCS에 가중치를 부여해 계산한다. 이 방법은 공통부분 문자열의 길이뿐만 아니라 해당 부분 문자열 내의 단어에 가중치를 부여해 평가하는 방식이다. 이 방법은 단어 간 유사도를 고려해 요약 문장의 의미를 더욱 잘 전달하는지 평가하는데 유용하다.

 

 

허깅페이스의 evaluate 라이브러리로 루지 점수를 계산할 수 있다. 이를 위해 rouge_score 라이브러리와 absl-py도 함께 설치해줘야 한다.

 

!pip install evaluate rouge_score absl-py

 

 

모델 학습

import evaluate


def calc_rouge(preds, labels):
    preds = preds.argmax(axis=-1)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge2 = rouge_score.compute(
        predictions=decoded_preds,
        references=decoded_labels
    )
    return rouge2["rouge2"]

def train(model, optimizer, dataloader):
    model.train()
    train_loss = 0.0

    for input_ids, attention_mask, labels in dataloader:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

        loss = outputs.loss
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss = train_loss / len(dataloader)
    return train_loss

def evaluation(model, dataloader):
    with torch.no_grad():
        model.eval()
        val_loss, val_rouge = 0.0, 0.0

        for input_ids, attention_mask, labels in dataloader:
            outputs = model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
            logits = outputs.logits
            loss = outputs.loss

            logits = logits.detach().cpu().numpy()
            label_ids = labels.to("cpu").numpy()
            rouge = calc_rouge(logits, label_ids)

            val_loss += loss
            val_rouge += rouge

    val_loss = val_loss / len(dataloader)
    val_rouge = val_rouge / len(dataloader)
    return val_loss, val_rouge

rouge_score = evaluate.load("rouge", tokenizer=tokenizer)
best_loss = 10000
for epoch in range(epochs):
    train_loss = train(model, optimizer, train_dataloader)
    val_loss, val_accuracy = evaluation(model, valid_dataloader)
    print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f} Val Loss: {val_loss:.4f} Val Rouge {val_accuracy:.4f}")

    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), "../BartForConditionalGeneration.pt")
        print("Saved the model weights")

 

 

모델 평가

model = BartForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path="facebook/bart-base"
).to(device)
model.load_state_dict(torch.load("../BartForConditionalGeneration.pt"))

test_loss, test_rouge_score = evaluation(model, test_dataloader)
print(f"Test Loss : {test_loss:.4f}")
print(f"Test ROUGE-2 Score : {test_rouge_score:.4f}")

 

 

문장 요약문 비교

from transformers import pipeline


summarizer = pipeline(
    task="summarization",
    model=model,
    tokenizer=tokenizer,
    max_length=54,
    device="cpu"
)

for index in range(5):
    news_text = test.text.iloc[index]
    summarization = test.prediction.iloc[index]
    predicted_summarization = summarizer(news_text)[0]["summary_text"]
    print(f"정답 요약문 : {summarization}")
    print(f"모델 요약문 : {predicted_summarization}\n")
정답 요약문 : Clinton leads Trump by 4 points in Washington Post: ABC News poll
모델 요약문 : Clinton leads Trump by 4 percentage points in Washington Post-ABC News poll

정답 요약문 : Democrats question independence of Trump Supreme Court nominee
모델 요약문 : Democratic senators question whether Gorsuch will be independent as Supreme Court justice

정답 요약문 : In push for Yemen aid, U.S. warned Saudis of threats in Congress
모델 요약문 : U.S. warns Saudi over humanitarian situation in Yemen

정답 요약문 : Romanian ruling party leader investigated over 'criminal group'
모델 요약문 : Romania investigates leader of ruling Social Democrat Party over graft

정답 요약문 : Billionaire environmental activist Tom Steyer endorses Clinton
모델 요약문 : Steyer backs Clinton for U.S. president