본문 바로가기

책/파이토치 트랜스포머를 활용한 자연어 처리와 컴퓨터비전 심층학습

11 모델 배포 (3) 모델 서빙

모델 서빙(Model Serving)은 훈련된 머신러닝 모델을 실제 운영 환경에서 사용할 수 있도록 프로세스 또는 시스템을 구축하는 것을 의미한다. 이는 클라이언트의 요청에 따라 서버에서 모델을 호출하고, 예측 결과를 반환하는 인터페이스를 구축하는 과정을 포함한다.

 

  • 클라이언트 : 서버에 서비스를 요청하는 주체. 서비스를 요청하는 PC, 스마트폰, 웹브라우저 등
  • 서버 : 클라이언트에게 요청을 받고 요청에 따라 서비스를 제공하는 컴퓨터 시스템
  • 인터페이스 : 사용자와 서비스 사이의 상호작용을 가능하게 하는 방법을 의미. UI/UX, API, 데이터 형식 등 서비스에 접근하고 상호작용할 수 있는 모든 요소를 포함

 

 

1. 모델 서빙 웹 프레임워크

모델 서빙 웹 프레임워크는 머신러닝 모델을 웹 서비스로 제공하고 관리하기 위한 도구를 의미한다. 이 프레임워크는 실시간으로 요청에 따라 모델을 호출하고 결과를 반환하는 서버를 구성할 수 있으며, 모델의 배포, 스케일링, 모니터링, 버전 관리, 요청 처리 등과 같은 기능도 처리할 수 있다.

 

파이썬에서는 Flask, Django, Fast API가 있다. 세 프레임워크 모두 RESTful API를 개발할 수 있는 표준화된 방법을 제공한다. RESTful API는 서버와 클라이언트 간의 역할을 명확하게 분리하여, 서버와 클라이언트를 독립적으로 개발하고 유지할 수 있다.  

 

 

1.1 Flask

플라스크를 이용하여 BERT를 활용한 간단한 API를 구현해 본다.

 

먼저 BERT 모델 클래스 선언한다. 메서드 선언시 @classmethod 데코레이터를 사용하여 정의한다. 클래스 메서드로 정의하면 모든 인스턴스에서 모델의 상태가 공유된다.

# app_flask.py
import torch
from torch.nn import functional as F
from transformers import BertTokenizer, BertForSequenceClassification


class BertModel:
    device = "cuda" if torch.cuda.is_available() else "cpu"

    @classmethod
    def load_model(cls, weight_path):
        cls.tokenizer = BertTokenizer.from_pretrained(
            pretrained_model_name_or_path="bert-base-multilingual-cased",
            do_lower_case=False,
        )
        cls.model = BertForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path="bert-base-multilingual-cased",
            num_labels=2
        ).to(cls.device)
        cls.model.load_state_dict(torch.load(weight_path, map_location=cls.device))
        cls.model.eval()
        
    @classmethod
    def preprocessing(cls, data):
        input_data = cls.tokenizer(
            text=data,
            padding="longest",
            truncation=True,
            return_tensors="pt"
        ).to(cls.device)
        return input_data

    @classmethod
    @torch.no_grad()
    def predict(cls, input):
        input_data = cls.preprocessing(input)
        outputs = cls.model(**input_data).logits
        probs = F.softmax(outputs, dim=-1)
        
        index = int(probs[0].argmax(axis=-1))
        label = "긍정" if index == 1 else "부정"
        score = float(probs[0][index])

        return {
            "label": label,
            "score": score
        }

 

 

Flask 클래스로 애플리케이션을 생성한다. __name__은 현재 모듈의 이름을 나타내며, 이는 Flask에게 현재 모듈이나 패키지의 위치를 알려주는 역할을 한다.

  • @app.route("/predict", methods=["POST"])은 데코레이터로 특정 URL에 대한 요청을 어떻게 처리할지를 지정한다. "/predict"는 엔드포인트의 URL을 나타내며, 클라이언트가 서버에 데이터를 보낼 때 사용된다. 그리고 클라이언트는 이 URL로 POST 요청을 보낼 수 있다.
  • inference() 함수는 POST 요청을 처리하고 요청 받은 JSON 데이터에서 텍스트를 추출한다. 그런 다음 예측을 수행하고, 그 결과를 JSON 형식으로 반환한다.
  • if __name__ == "__main__":으로 이 파이썬 파일이 메인으로 실행될 때 모델을 로드하고, Flask 애플리케이션을 0.0.0.0 주소와 8000 포트에서 실행한다.
# app_flask.py
import json
from flask import Flask, request, Response


app = Flask(__name__)


@app.route("/predict", methods=["POST"])
def inference():
    data = request.get_json()
    text = data["text"]

    try:
        return Response(
            response=json.dumps(BertModel.predict(text), ensure_ascii=False),
            status=200,
            mimetype="application/json",
        )

    except Exception as e:
        return Response(
            response=json.dumps({"error": str(e)}, ensure_ascii=False),
            status=500,
            mimetype="application/json",
        )


if __name__ == "__main__":
    BertModel.load_model(weight_path="./BertForSequenceClassification.pt")
    app.run(host="0.0.0.0", port=8000)

 

 

터미널에서 플라스크 애플리케이션을 실행한다.

python app_flask.py
 * Serving Flask app 'app_flask'
 * Debug mode: off
WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.
 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:8000
 * Running on http://XXX.XXX.X.X:8000
Press CTRL+C to quit

 

 

이제 모델 추론을 요청하는 새로운 파이썬 파일을 생성해 실행한다.

  • 파일의 콘텐츠 유형은 HTTP 헤더에 Content-Type으로 지정되며, 클라이언트와 서버 간에 어떤 유형의 데이터를 주고받을지를 알려준다. MIME(MultiPurpose Internet Mail Expression) 유형의 형식을 따르며, "주요 유형/서브 유형"으로 표기된다.
  • requests.post() 함수를 사용해 post 요청을 보낸다. 반환된 응답은 response 변수에 저장된다.
import json
import requests


url = "http://127.0.0.1:8000/predict"
headers = {"content-type": "application/json"}

response = requests.post(
    url=url,
    headers=headers,
    data=json.dumps({"text": "정말 재미 있어요!"})
)

print(response.status_code)
print(response.json())
200
{'label': '긍정', 'score': 0.9859620928764343}

 

 

아래는 주로 사용하는 HTTP 메서드를 정리한 표이다. 이밖에도 PATCH, HEAD, OPTIONS 등이 있다.

HTTP 메서드 처리 작업 주요 용도
GET 읽기 서버로부터 정보를 요청하기 위함
POST 생성 서버에 데이터를 제출하기 위함
PUT 갱신 서버에 리소스를 업데이트하기 위함
DELETE 삭제 서버에서 리소스를 삭제하기 위함

 

 

1.2 Fast API

Fast API로 VGG16을 활용한 간단한 API를 구현해 본다.

 

이번에는 추론 요청 파이썬 코드부터 살펴본다. Flask 예제와 달라진 점은 이미지를 Base64로 인코딩하는 부분이 추가되었다. 

  • 먼저 io.BytesIO()를 사용하여 바이트 스트림을 메모리에 쓰는 임시 버퍼를 생성한다.
  • image를 JPEG 형식으로 버퍼에 저장한다.
  • buffer.seek(0)로 임시 버퍼의 파일 포인터를 처음 위치(0번째 바이트)로 이동시킨다. 이미지를 저장한 후에는 파일 포인터가 마지막 위치에 유지된다. 따라서 이를 생략하면 buffer.read()를 호출할 때 파일의 끝에서부터 읽기를 시도하게 되므로, 아무런 데이터도 읽을 수 없다.
  • 버퍼에서 데이터를 읽어와 bytes에 저장한다. 그런 다음 바이트 스트림을 Base64로 인코딩하고 그 결과를 문자열로 반환하기 위해 decode("utf-8") 메서드를 사용한다.
import io
import json
import base64
import requests
from PIL import Image


url = "http://127.0.0.1:8000/predict"
headers = {"content-type": "application/json"}

image = Image.open("./11 모델 배포/dog.jpg")
with io.BytesIO() as buffer:
    image.save(buffer, format="JPEG")
    buffer.seek(0)
    bytes = buffer.read()
string = base64.b64encode(bytes).decode("utf-8")

response = requests.post(
    url=url,
    headers=headers,
    data=json.dumps({"base64": string})
)

print(response.status_code)
print(response.json())

 

 

app_fastapi.py 파일을 생성하여 VGG16 모델 클래스를 정의한다.

  • preprocessing 메서드를 보면 입력으로 받은 Base64로 인코딩된 데이터를 디코딩한다. 디코딩된 이진 데이터는 io.BytesIO로 이진 스트림으로 변환한 뒤 PIL Image 객체로 읽어들인다.
# app_fastapi.py
import io
import torch
import base64
from PIL import Image
from torch.nn import functional as F
from torchvision import models, transforms


class VGG16Model:
    def __init__(self, weight_path):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.48235, 0.45882, 0.40784],
                    std=[1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0]
                )
            ]
        )
        self.model = models.vgg16(num_classes=2).to(self.device)
        self.model.load_state_dict(torch.load(weight_path, map_location=self.device))
        self.model.eval()

    def preprocessing(self, data):
        decode = base64.b64decode(data)
        bytes = io.BytesIO(decode)
        image = Image.open(bytes)
        input_data = self.transform(image).to(self.device)
        return input_data
    
    @torch.no_grad()
    def predict(self, input):
        input_data = self.preprocessing(input)
        outputs = self.model(input_data.unsqueeze(0))
        probs = F.softmax(outputs, dim=-1)
        
        index = int(probs[0].argmax(axis=-1))
        label = "개" if index == 1 else "고양이"
        score = float(probs[0][index])

        return {
            "label": label,
            "score": score
        }

 

 

 

  • uvicorn은 빠르고 강력한 ASGI(Asynchronous Server Gateway Interface) 웹 서버이다. ASGI는 Python의 비동기 웹 애플리케이션을 위한 표준 인터페이스로, asyncio를 기반으로 하는 웹 프레임워크와 서버 간의 통신을 가능하게 한다. 패스트 API와 함께 자주 사용되는 웹 서버다.
  • pydantic은 파이썬에서 데이터 유효성 검사, 구조화 및 직렬화를 위한 라이브러리다. 이를 사용해 데이터 모델을 정의하고, 입력 데이터의 유효성을 검사하고, 데이터를 직렬화하여 다른 형식으로 변환할 수 있다.
  • app, vgg를 정의하고 Item이라는 BaseModel을 정의한다. 이 클래스는 입력 데이터를 정의하는데, 여기서는 Base64로 인코딩된 이미지 데이터를 받는다.
  • get_model() 함수는 종속성으로 사용되며, 이 함수에 의존하는 항목에 대해 vgg 변수를 반환한다. 패스트 API는 종속성 주입을 통해 각 요청에 대한 새로운 인스턴스를 생성하고, 상태를 공유하지 않는 방식으로 구현해야 한다. 단, 파이토치는 모델 초기화 과정이 필요하므로 종속성 함수 get_model을 Depends로 감싸서 model 인수에 주입한다. 이를 통해 동시성 문제를 피할 수 있다.
  • 플라스크와 달리 패스트 API는 uvicorn.run으로 애플리케이션을 실행한다. app 매개변수는 '파이썬 스크립트 파일 이름:패스트 API 애플리케이션 변수 이름'으로 입력한다.
# app_fastapi.py
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI, Depends, HTTPException


app = FastAPI()
vgg = VGG16Model(weight_path='./VGG16.pt')

class Item(BaseModel):
    base64: str

def get_model():
    return vgg

@app.post("/predict")
async def inference(item: Item, model: VGG16Model = Depends(get_model)):
    try:
        return model.predict(item.base64)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__=="__main__":
    uvicorn.run(
        app="app_fastapi:app",
        host="0.0.0.0",
        port=8000,
        workers=2,
    )

 

 

터미널에 스크립트를 실행한다.

python app_fastapi.py
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO:     Started parent process [4156]
INFO:     Started server process [5912]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Started server process [10652]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     127.0.0.1:53284 - "POST /predict HTTP/1.1" 200 OK

 

 

작성해둔 추론 파일을 실행하면 추론 결과를 얻을 수 있다.

200
{'label': '개', 'score': 1.0}