본문 바로가기
🟣 AI & ML

마구잡이 질문에도 강건한 RAG 시스템 만들기: Query Transformation

by 제리강 2024. 7. 20.

TL;DR

RAG 기반 어플리케이션은 질문 텍스트를 바탕으로 정보를 검색하기 때문에, 사용자의 무작위적인 질문 형식에 매우 취약한 시스템이다.
이를 해결하기 위해 모델이 질문의 핵심 내용을 잘 인식할 수 있도록 질문을 변형하여 재구성할 수 있는데, 이를 'query transformation'라 한다.

본 포스트에서는 LangChain 환경을 바탕으로 이러한 변형 방법 중 하나인 'Rewrite'를 알아보고 실제로 구현을 해보도록 한다.

 

 

 

Query transformation은 다양한 방법이 있지만, 이 포스트에서 소개할 방법은 'Rewrite'이다. 말 그대로 사용자 쿼리에 LLM을 한번 더 적용하여, 사용자가 알고자 하는 핵심 정보로 재작성한다. 먼저, 필요한 라이브러리들을 설치해보자.

 

DuckDuckGoSearchAPIWrapper는 DuckDB에서 제공하는 검색 엔진 API다. RAG는 보통 벡터DB에 저장된 정보를 검색하여 사용하지만, 이렇게 웹에서 정보를 검색해서 사용할 수도 있다. 해당 라이브러리가 설치되어 있지 않다면, 아래 라인의 pip install 명령어를 이용해 설치하자. 그리고, OpenAI API 키도 설정해준다.

from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
# pip install -U -q duckduckgo-search
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
import os

os.environ['OPENAI_API_KEY'] = 'YOUR_OPENAI_API_KEY'

 

 

 

RAG 타입의 프롬프트를 설정해준다. RAG 프롬프트는 일반적으로 사용자 질문인 question과 검색해오는 정보를 참고할 context 파트로 구성된다. 설정한 프롬프트 텍스트는 ChatPromptTemplate 개체에 할당해준다. 그 다음엔, LLM 및 검색 엔진을 설정해준다. 최근 출시한, GPT-3.5-turbo의 대체 모델 GPT-4o-mini를 사용해보도록 하자.

template = """Answer the users question based only on the following context:
Question: {question}
Context: {context}
Answer:
"""

prompt = ChatPromptTemplate.from_template(template)
model = ChatOpenAI(temperature=0, model_name='gpt-4o-mini')
search = DuckDuckGoSearchAPIWrapper()

def retriever(query):
    return search.run(query)

 

 

 

이제 chain을 설정해준다. LangChain의 최근 버전에서는, 다음과 같이, '|' 기호로 모듈을 연결하여 chain을 구성한다. RunnablePassthrough()는 모듈 간에 데이터를 변경 없이 통과시키는 역할을 하고, StrOutputParser()는 json 등으로 표현되는 모델 출력 결과로부터 문자열을 추출한다.

chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

 

 

 

먼저 정상적인 질문을 해보자. 답변이 자세히 잘 출력된다.

simple_query = "what is LangChain?"
chain.invoke(simple_query)

 

'LangChain is a framework designed for building a variety of applications 
powered by large language models (LLMs). 
It simplifies AI application development by integrating different language models 
and tools, enabling features like document analysis, summarization, 
and chatbot creation. LangChain allows developers to create context-aware 
applications that can perform tasks such as text generation, translation, and more. 
It also supports integration with other libraries and frameworks, 
making it a versatile toolkit for crafting sophisticated AI applications.'

 

 

 

이제, 질문 정보는 포함되어 있지만 노이즈를 포함한 질문을 해보자. 모델이 답변을 하지 못하는 것을 볼 수 있다.

참고로, 노이즈를 질문 앞에만 넣으면 생각보다 정상적으로 검색되는 편이다. 하지만, 노이즈를 질문 양쪽에 넣으면 답변 성능이 급격히 하락한다.

distracted_query = "man that sam bankman fried trial was crazy! 
What is LangChain? nssiissi desk chairman!"
chain.invoke(distracted_query)

 

'The context provided does not contain any information about LangChain. 
It primarily discusses the trial and conviction of Sam Bankman-Fried.'

 

 

 

쿼리를 재작성하기 위한 rewrite 프롬프트를 설계한다. 사용자의 원래 질문을 받아서, 더 나은 쿼리를 출력하도록 지시한다. 아래와 같이 프롬프트를 LangChain hub에서 불러올 수도 있다(같은 프롬프트이다). 그리고 재작성한 쿼리를 전처리할 parse 함수를 구성한다. 그 다음엔 rewrite를 수행할 chain을 별도로 구성한다.

template = """Provide a better search query for
search engine to answer the given question, end
the queries with ’**’. 
Question: {question} 
Answer:"""
rewrite_prompt = ChatPromptTemplate.from_template(template)

# or
from langchain import hub
rewrite_prompt = hub.pull("langchain-ai/rewrite")

# Parser to remove the `**`
def _parse(text):
    return text.strip('"').strip("**")
    
rewriter = rewrite_prompt | 
           ChatOpenAI(temperature=0, model_name='gpt-4o') | 
           StrOutputParser() | 
           _parse

 

 

 

이제 재작성된 쿼리를 확인해보자. 질문의 핵심 정보를 잘 추출한 것을 볼 수 있다. 언어 모델 특성 상, 실행할 때마다 결과는 조금씩 다를 수 있다.

rewriter.invoke({"question": distracted_query})

 

'LangChain overview and features'

 

 

 

이제 rewriter chain과 기존의 chain을 결합하여, query rewrite를 자동으로 수행하는 chain을 구성하고 답변을 확인해보자. 답변이 잘 출력되는 것을 볼 수 있다.

rewrite_retrieve_read_chain = (
    {
        "context": {"x": RunnablePassthrough()} | rewriter | retriever,
        "question": RunnablePassthrough(),
    }
    | prompt
    | model
    | StrOutputParser()
)

rewrite_retrieve_read_chain.invoke(distracted_query)

 

'LangChain is a cutting-edge framework designed for developing applications 
powered by language models. It allows for the creation of complex applications 
that can perform a variety of tasks by combining several components or chains 
in a single pipeline. LangChain enhances context awareness and reasoning, 
enabling applications to remember previous interactions with users, 
which improves the quality of the output. 
It simplifies the process of organizing large volumes of data for language models, 
making it easier for developers to create advanced AI applications.'

 

 

마치며

Rewriter를 이용하여 노이즈가 포함된 질문에도 잘 작동하는 강건한(robust) LLM 어플리케이션을 구성해보았다. Query transformation만 잘 활용해도 훨씬 개선된 RAG 시스템을 구성할 수 있다. 다음 포스트에서도, RAG 시스템을 개선할 수 있는 다양한 방법들을 소개해 보겠다.

 

 

댓글