모델 서빙(Model Serving)은 훈련된 머신러닝 모델을 실제 운영 환경에서 사용할 수 있도록 프로세스 또는 시스템을 구축하는 것을 의미한다. 이는 클라이언트의 요청에 따라 서버에서 모델을 호출하고, 예측 결과를 반환하는 인터페이스를 구축하는 과정을 포함한다.
- 클라이언트 : 서버에 서비스를 요청하는 주체. 서비스를 요청하는 PC, 스마트폰, 웹브라우저 등
- 서버 : 클라이언트에게 요청을 받고 요청에 따라 서비스를 제공하는 컴퓨터 시스템
- 인터페이스 : 사용자와 서비스 사이의 상호작용을 가능하게 하는 방법을 의미. UI/UX, API, 데이터 형식 등 서비스에 접근하고 상호작용할 수 있는 모든 요소를 포함
1. 모델 서빙 웹 프레임워크
모델 서빙 웹 프레임워크는 머신러닝 모델을 웹 서비스로 제공하고 관리하기 위한 도구를 의미한다. 이 프레임워크는 실시간으로 요청에 따라 모델을 호출하고 결과를 반환하는 서버를 구성할 수 있으며, 모델의 배포, 스케일링, 모니터링, 버전 관리, 요청 처리 등과 같은 기능도 처리할 수 있다.
파이썬에서는 Flask, Django, Fast API가 있다. 세 프레임워크 모두 RESTful API를 개발할 수 있는 표준화된 방법을 제공한다. RESTful API는 서버와 클라이언트 간의 역할을 명확하게 분리하여, 서버와 클라이언트를 독립적으로 개발하고 유지할 수 있다.
1.1 Flask
플라스크를 이용하여 BERT를 활용한 간단한 API를 구현해 본다.
먼저 BERT 모델 클래스 선언한다. 메서드 선언시 @classmethod 데코레이터를 사용하여 정의한다. 클래스 메서드로 정의하면 모든 인스턴스에서 모델의 상태가 공유된다.
# app_flask.py
import torch
from torch.nn import functional as F
from transformers import BertTokenizer, BertForSequenceClassification
class BertModel:
device = "cuda" if torch.cuda.is_available() else "cpu"
@classmethod
def load_model(cls, weight_path):
cls.tokenizer = BertTokenizer.from_pretrained(
pretrained_model_name_or_path="bert-base-multilingual-cased",
do_lower_case=False,
)
cls.model = BertForSequenceClassification.from_pretrained(
pretrained_model_name_or_path="bert-base-multilingual-cased",
num_labels=2
).to(cls.device)
cls.model.load_state_dict(torch.load(weight_path, map_location=cls.device))
cls.model.eval()
@classmethod
def preprocessing(cls, data):
input_data = cls.tokenizer(
text=data,
padding="longest",
truncation=True,
return_tensors="pt"
).to(cls.device)
return input_data
@classmethod
@torch.no_grad()
def predict(cls, input):
input_data = cls.preprocessing(input)
outputs = cls.model(**input_data).logits
probs = F.softmax(outputs, dim=-1)
index = int(probs[0].argmax(axis=-1))
label = "긍정" if index == 1 else "부정"
score = float(probs[0][index])
return {
"label": label,
"score": score
}
Flask 클래스로 애플리케이션을 생성한다. __name__은 현재 모듈의 이름을 나타내며, 이는 Flask에게 현재 모듈이나 패키지의 위치를 알려주는 역할을 한다.
- @app.route("/predict", methods=["POST"])은 데코레이터로 특정 URL에 대한 요청을 어떻게 처리할지를 지정한다. "/predict"는 엔드포인트의 URL을 나타내며, 클라이언트가 서버에 데이터를 보낼 때 사용된다. 그리고 클라이언트는 이 URL로 POST 요청을 보낼 수 있다.
- inference() 함수는 POST 요청을 처리하고 요청 받은 JSON 데이터에서 텍스트를 추출한다. 그런 다음 예측을 수행하고, 그 결과를 JSON 형식으로 반환한다.
- if __name__ == "__main__":으로 이 파이썬 파일이 메인으로 실행될 때 모델을 로드하고, Flask 애플리케이션을 0.0.0.0 주소와 8000 포트에서 실행한다.
# app_flask.py
import json
from flask import Flask, request, Response
app = Flask(__name__)
@app.route("/predict", methods=["POST"])
def inference():
data = request.get_json()
text = data["text"]
try:
return Response(
response=json.dumps(BertModel.predict(text), ensure_ascii=False),
status=200,
mimetype="application/json",
)
except Exception as e:
return Response(
response=json.dumps({"error": str(e)}, ensure_ascii=False),
status=500,
mimetype="application/json",
)
if __name__ == "__main__":
BertModel.load_model(weight_path="./BertForSequenceClassification.pt")
app.run(host="0.0.0.0", port=8000)
터미널에서 플라스크 애플리케이션을 실행한다.
python app_flask.py
* Serving Flask app 'app_flask'
* Debug mode: off
WARNING: This is a development server. Do not use it in a production deployment. Use a production WSGI server instead.
* Running on all addresses (0.0.0.0)
* Running on http://127.0.0.1:8000
* Running on http://XXX.XXX.X.X:8000
Press CTRL+C to quit
이제 모델 추론을 요청하는 새로운 파이썬 파일을 생성해 실행한다.
- 파일의 콘텐츠 유형은 HTTP 헤더에 Content-Type으로 지정되며, 클라이언트와 서버 간에 어떤 유형의 데이터를 주고받을지를 알려준다. MIME(MultiPurpose Internet Mail Expression) 유형의 형식을 따르며, "주요 유형/서브 유형"으로 표기된다.
- requests.post() 함수를 사용해 post 요청을 보낸다. 반환된 응답은 response 변수에 저장된다.
import json
import requests
url = "http://127.0.0.1:8000/predict"
headers = {"content-type": "application/json"}
response = requests.post(
url=url,
headers=headers,
data=json.dumps({"text": "정말 재미 있어요!"})
)
print(response.status_code)
print(response.json())
200
{'label': '긍정', 'score': 0.9859620928764343}
아래는 주로 사용하는 HTTP 메서드를 정리한 표이다. 이밖에도 PATCH, HEAD, OPTIONS 등이 있다.
HTTP 메서드 | 처리 작업 | 주요 용도 |
GET | 읽기 | 서버로부터 정보를 요청하기 위함 |
POST | 생성 | 서버에 데이터를 제출하기 위함 |
PUT | 갱신 | 서버에 리소스를 업데이트하기 위함 |
DELETE | 삭제 | 서버에서 리소스를 삭제하기 위함 |
1.2 Fast API
Fast API로 VGG16을 활용한 간단한 API를 구현해 본다.
이번에는 추론 요청 파이썬 코드부터 살펴본다. Flask 예제와 달라진 점은 이미지를 Base64로 인코딩하는 부분이 추가되었다.
- 먼저 io.BytesIO()를 사용하여 바이트 스트림을 메모리에 쓰는 임시 버퍼를 생성한다.
- image를 JPEG 형식으로 버퍼에 저장한다.
- buffer.seek(0)로 임시 버퍼의 파일 포인터를 처음 위치(0번째 바이트)로 이동시킨다. 이미지를 저장한 후에는 파일 포인터가 마지막 위치에 유지된다. 따라서 이를 생략하면 buffer.read()를 호출할 때 파일의 끝에서부터 읽기를 시도하게 되므로, 아무런 데이터도 읽을 수 없다.
- 버퍼에서 데이터를 읽어와 bytes에 저장한다. 그런 다음 바이트 스트림을 Base64로 인코딩하고 그 결과를 문자열로 반환하기 위해 decode("utf-8") 메서드를 사용한다.
import io
import json
import base64
import requests
from PIL import Image
url = "http://127.0.0.1:8000/predict"
headers = {"content-type": "application/json"}
image = Image.open("./11 모델 배포/dog.jpg")
with io.BytesIO() as buffer:
image.save(buffer, format="JPEG")
buffer.seek(0)
bytes = buffer.read()
string = base64.b64encode(bytes).decode("utf-8")
response = requests.post(
url=url,
headers=headers,
data=json.dumps({"base64": string})
)
print(response.status_code)
print(response.json())
app_fastapi.py 파일을 생성하여 VGG16 모델 클래스를 정의한다.
- preprocessing 메서드를 보면 입력으로 받은 Base64로 인코딩된 데이터를 디코딩한다. 디코딩된 이진 데이터는 io.BytesIO로 이진 스트림으로 변환한 뒤 PIL Image 객체로 읽어들인다.
# app_fastapi.py
import io
import torch
import base64
from PIL import Image
from torch.nn import functional as F
from torchvision import models, transforms
class VGG16Model:
def __init__(self, weight_path):
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.transform = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.48235, 0.45882, 0.40784],
std=[1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0]
)
]
)
self.model = models.vgg16(num_classes=2).to(self.device)
self.model.load_state_dict(torch.load(weight_path, map_location=self.device))
self.model.eval()
def preprocessing(self, data):
decode = base64.b64decode(data)
bytes = io.BytesIO(decode)
image = Image.open(bytes)
input_data = self.transform(image).to(self.device)
return input_data
@torch.no_grad()
def predict(self, input):
input_data = self.preprocessing(input)
outputs = self.model(input_data.unsqueeze(0))
probs = F.softmax(outputs, dim=-1)
index = int(probs[0].argmax(axis=-1))
label = "개" if index == 1 else "고양이"
score = float(probs[0][index])
return {
"label": label,
"score": score
}
- uvicorn은 빠르고 강력한 ASGI(Asynchronous Server Gateway Interface) 웹 서버이다. ASGI는 Python의 비동기 웹 애플리케이션을 위한 표준 인터페이스로, asyncio를 기반으로 하는 웹 프레임워크와 서버 간의 통신을 가능하게 한다. 패스트 API와 함께 자주 사용되는 웹 서버다.
- pydantic은 파이썬에서 데이터 유효성 검사, 구조화 및 직렬화를 위한 라이브러리다. 이를 사용해 데이터 모델을 정의하고, 입력 데이터의 유효성을 검사하고, 데이터를 직렬화하여 다른 형식으로 변환할 수 있다.
- app, vgg를 정의하고 Item이라는 BaseModel을 정의한다. 이 클래스는 입력 데이터를 정의하는데, 여기서는 Base64로 인코딩된 이미지 데이터를 받는다.
- get_model() 함수는 종속성으로 사용되며, 이 함수에 의존하는 항목에 대해 vgg 변수를 반환한다. 패스트 API는 종속성 주입을 통해 각 요청에 대한 새로운 인스턴스를 생성하고, 상태를 공유하지 않는 방식으로 구현해야 한다. 단, 파이토치는 모델 초기화 과정이 필요하므로 종속성 함수 get_model을 Depends로 감싸서 model 인수에 주입한다. 이를 통해 동시성 문제를 피할 수 있다.
- 플라스크와 달리 패스트 API는 uvicorn.run으로 애플리케이션을 실행한다. app 매개변수는 '파이썬 스크립트 파일 이름:패스트 API 애플리케이션 변수 이름'으로 입력한다.
# app_fastapi.py
import uvicorn
from pydantic import BaseModel
from fastapi import FastAPI, Depends, HTTPException
app = FastAPI()
vgg = VGG16Model(weight_path='./VGG16.pt')
class Item(BaseModel):
base64: str
def get_model():
return vgg
@app.post("/predict")
async def inference(item: Item, model: VGG16Model = Depends(get_model)):
try:
return model.predict(item.base64)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
if __name__=="__main__":
uvicorn.run(
app="app_fastapi:app",
host="0.0.0.0",
port=8000,
workers=2,
)
터미널에 스크립트를 실행한다.
python app_fastapi.py
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO: Started parent process [4156]
INFO: Started server process [5912]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Started server process [10652]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: 127.0.0.1:53284 - "POST /predict HTTP/1.1" 200 OK
작성해둔 추론 파일을 실행하면 추론 결과를 얻을 수 있다.
200
{'label': '개', 'score': 1.0}
'책 > 파이토치 트랜스포머를 활용한 자연어 처리와 컴퓨터비전 심층학습' 카테고리의 다른 글
11 모델 배포 (5) 데모 애플리케이션 (0) | 2024.03.17 |
---|---|
11 모델 배포 (4) 도커(Docker), VSCode 실습 (0) | 2024.03.17 |
11 모델 배포 (2) 모델 경량화 (0) | 2024.03.15 |
11 모델 배포 (1) 모델 경량화 (0) | 2024.03.15 |
10 비전 트랜스포머 (3) CvT pytorch 실습 (0) | 2024.03.12 |