본문 바로가기

책/파이토치 트랜스포머를 활용한 자연어 처리와 컴퓨터비전 심층학습

11 모델 배포 (5) 데모 애플리케이션

데모 애플리케이션이란 기술이나 제품의 동작 및 기능을 보여주는 작은 규모의 응용 프로그램을 의미한다. 주로 웹 기반 인터페이스의 형태로 제공되는데, 예를 들어 자연어 처리 모델의 데모 애플리케이션은 사용자가 입력 문장을 제공하면 모델이 문장을 이해하고 응답을 생성하는 등의 기능을 제공할 수 있다.

 

데모 애플리케이션은 모델의 성능 평가 및 피드백 수집에도 중요한 역할을 한다. 사용자는 모델을 직접 사용하면서 발생하는 문제점이나 개선 사항을 제공할 수 있으며, 이는 모델의 품질 향상과 개선에 도움을 준다.

 

 

스트림릿(Streamlit)

스트림릿을 사용해 데모 애플리케이션을 구현해 본다. 스트림릿은 넘파이, 판다스, 이미지 라이브러리 등 다른 파이썬 패키지들과 호환되며 인터랙티브한 웹 애플리케이션을 만들 수 있도록 차트, 그리드, 맵 등 다양한 시각화 도구를 제공하는 강력한 라이브러리다.

 

스트림릿의 주요 구성요소

스크립트

  • 스트림릿 애플리케이션은 파이썬 스크립트로 작성되며, 실시간으로 애플리케이션을 업데이트하는 기능을 제공해 코드를 수정하면 자동으로 애플리케이션도 업데이트되어 최신 결과를 바로 확인할 수 있다. 또한 대화형 웹 애플리케이션으로 사용자의 입력을 받아 처리해 응답을 생성하여 사용자와 실시간으로 상호작용한다.

 

컴포넌트

  • 스트림릿은 다양한 컴포넌트 함수를 제공한다. 컴포넌트 함수는 애플리케이션에 요소를 추가하고 상호작용을 구현하는 데 사용된다.

 

캐싱

  • 스트림릿은 사용자와 상호작용이 있을 때마다 전체 프로세스를 다시 수행한다. 딥러닝 모델 초기화와 같이 계산이 복잡한 작업을 최적화하기 위해 캐시 기능을 제공한다.

 

배포

  • 간단한 명령을 사용하여 웹에 배포할 수 있다. 로컬 서버에 애플리케이션이 실행되며, 이를 통해 빠르게 애플리케이션을 공유하고 협업할 수 있다.

 

 

애플리케이션 배포

다음은 스트림릿을 사용하여 데이터 시각화 및 상호 작용 가능한 대시보드를 만드는 코드이다.

 

데이터프레임의 경로는 필자의 환경에 맞게 설정되었다.

import pandas as pd
import streamlit as st


st.set_page_config(
    page_title="데모 애플리케이션",
    page_icon=":shark:",
    layout="wide",
)

df = pd.read_csv("../../datasets/non_linear.csv")

st.header(body="Demo Application")
st.subheader(body="non_linear.csv")

x = st.sidebar.selectbox(label="X 축", options=df.columns, index=0)
y = st.sidebar.selectbox(label="Y 축", options=df.columns, index=1)

col1, col2 = st.columns(2)
with col1:
    st.dataframe(data=df, height=500, use_container_width=True)
with col2:
    st.line_chart(data=df, x=x, y=y, height=500)

 

 

터미널에서 다음 명령어를 통해 실행할 수 있다.

streamlit run demo.py [-- script args]

 

 

기본 포트는 8501로 설정되어 http://localhost:8501로 접속할 수 있다.

 

가령 streamlit run demo.py --server.port=8080을 입력하면 8080 포트에 데모 애플리케이션을 배포할 수 있다.

 

스트림릿 라이브러리의 위젯, 컴포넌트 및 API는 API 레퍼런스치트시트에서 자세히 확인할 수 있다.

 

 

파이토치 모델 연동

데모 애플리케이션의 셀렉트 박스와 상호작용을 하면 전체 코드가 다시 실행된다. 만약 상호작용할 때마다 딥러닝 모델이나 대규모 데이터셋을 불러온다면 사용하기 어려울 것이다.

 

다음은 데이터 캐싱 방법을 보여준다.

import pandas as pd
import streamlit as st


st.set_page_config(
    page_title="데모 애플리케이션",
    page_icon=":shark:",
    layout="wide",
)

@st.cache_data
def load_data(path):
    return pd.read_csv(path)

df = load_data("../../datasets/non_linear.csv")

st.header(body="Demo Application")
st.subheader(body="non_linear.csv")

x = st.sidebar.selectbox(label="X 축", options=df.columns, index=0)
y = st.sidebar.selectbox(label="Y 축", options=df.columns, index=1)

col1, col2 = st.columns(2)
with col1:
    st.dataframe(data=df, height=500, use_container_width=True)
with col2:
    st.line_chart(data=df, x=x, y=y, height=500)

 

@st.cache_data 데코레이터를 활용하면, 동일한 계산을 반복해서 실행하지 않는다.

 

 

이번엔 딥러닝 모델을 캐싱하는 방법을 보여준다.

import streamlit as st
from transformers import pipeline


@st.cache_resource
def load_model():
    return pipeline(task="text-generation", model="gpt2")

model = load_model()

text = st.text_input("텍스트 입력", value="Barack Hussein Obama is")
if text:
    result = model(
        text_inputs=text,
        max_length=30,
        num_return_sequence=3,
        pad_token_id=model.tokenizer.eos_token_id,
    )
st.write(result)

 

@st.cache_resource는 파일 또는 리소스에 대한 캐싱을 처리하는 데 사용된다. 또한 @st.cache_data는 개별 사용자마다 리소스를 캐싱한다면 @st.cache_resource는 모든 사용자를 대상으로 리소스를 캐싱하는 전역 리소스 캐싱이다.

 

 

다음과 같이 텍스트를 변경해 입력하는 등 상호작용이 가능하다.

 

 

이번엔 YOLO를 사용해 이미지를 업로드하고 그 예측 결과를 확인해보자.

import cv2
import torch
import numpy as np
import streamlit as st
from PIL import Image
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator


@st.cache_resource
def load_model():
    return YOLO("yolov8m-pose.pt")


def predict(frame, iou=0.7, conf=0.25):
    results = model(
        source=frame,
        device="cuda" if torch.cuda.is_available() else "cpu",
        iou=0.7,
        conf=0.25,
        verbose=False,
    )
    result = results[0]
    return result


def draw_boxes(result, frame):
    for boxes in result.boxes:
        x1, y1, x2, y2, score, classes = boxes.data.squeeze().cpu().numpy()
        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 1)
    return frame


def draw_keypoints(result, frame):
    annotator = Annotator(frame, line_width=1)
    for kps in result.keypoints:
        annotator.kpts(kps)

        for idx, kp in enumerate(kps):
            x, y, score = kp.data.squeeze().cpu().numpy()
            
            if score > 0.5:
                cv2.circle(frame, (int(x), int(y)), 3, (0, 0, 255), cv2.FILLED)
                cv2.putText(frame, str(idx), (int(x), int(y)), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 0, 255), 1)
    
    return frame

model = load_model()

uploaded_file=st.file_uploader("파일 선택", type=["PNG", "JPG", "JPEG"])
if uploaded_file is not None:
    print(uploaded_file.type)
    if "image" in uploaded_file.type:
        with st.spinner(text="포즈 정보 추출중..."):
            pil_image = Image.open(uploaded_file).convert("RGB")
            np_image = np.asarray(pil_image)
            cv_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)

            result = predict(cv_image)
            image = draw_boxes(result, cv_image)
            image = draw_keypoints(result, image)
            st.image(image, channels="BGR")

 

이를 실행해서 이미지를 업로드해보니 403 에러가 발생했다. 403은 서버가 클라이언트의 요청을 거부한 것이다. 이를 해결하기 위해 구글링을 해봤더니 14시간 전에 나랑 같은 오류가 생긴 외국인 친구도 있었다. 뭔가 주기적으로 이런 문제가 발생하는 것 같으니 이부분은 나중에 해봐야겠다.