
Storing ChatGPT-like Chat History with Previous Context for Your RAG Pipeline (LlamaIndex, FastAPI)


If, like me, you find reading blogs boring and prefer to jump straight into the code, here is the GitHub repo. Disclaimer: this blog is not about building over-the-top RAG pipelines or agents. It covers one basic concept that will help you add this feature to your next project.

You will also find a few Harry Potter Easter eggs along the way. Which one is your favourite movie?


Storing Chat History

Entity-Relationship Diagram:

Only two tables are needed: conversations and messages. A conversation is created when a user starts a new chat, and every query and response in that chat is stored as a message linked to it.
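To make the one-to-many relationship concrete, here is a rough sketch of the schema using Python's built-in sqlite3 module instead of the repo's SQLAlchemy models. The table names `conversations` and `messages`, the text UUID keys and the `ON DELETE CASCADE` rule are assumptions; the column names mirror the fields used by the create functions below.

```python
import sqlite3
import uuid
from datetime import datetime, timezone

# Hypothetical schema: one conversation has many messages (1-to-many)
SCHEMA = """
CREATE TABLE IF NOT EXISTS conversations (
    id          TEXT PRIMARY KEY,   -- UUID stored as text
    name        TEXT NOT NULL,
    dateCreated TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS messages (
    id              TEXT PRIMARY KEY,
    conversation_id TEXT NOT NULL REFERENCES conversations(id) ON DELETE CASCADE,
    content         TEXT NOT NULL,
    role            TEXT NOT NULL,  -- "user" or "assistant"
    dateCreated     TEXT NOT NULL
);
"""

def connect(path: str = ":memory:") -> sqlite3.Connection:
    conn = sqlite3.connect(path)
    conn.execute("PRAGMA foreign_keys = ON")  # SQLite enforces FKs only when enabled
    conn.executescript(SCHEMA)
    return conn

def now() -> str:
    return datetime.now(timezone.utc).isoformat()

def new_conversation(conn: sqlite3.Connection, name: str) -> str:
    conv_id = str(uuid.uuid4())
    with conn:  # commits on success, rolls back on error
        conn.execute(
            "INSERT INTO conversations (id, name, dateCreated) VALUES (?, ?, ?)",
            (conv_id, name, now()),
        )
    return conv_id

def add_message(conn: sqlite3.Connection, conversation_id: str,
                content: str, role: str = "user") -> str:
    msg_id = str(uuid.uuid4())
    with conn:
        conn.execute(
            "INSERT INTO messages (id, conversation_id, content, role, dateCreated) "
            "VALUES (?, ?, ?, ?, ?)",
            (msg_id, conversation_id, content, role, now()),
        )
    return msg_id
```

Usage: `cid = new_conversation(conn, "Half-Blood Prince")`, then one `add_message` call per query and per response. Because of the foreign key, a message can never point at a conversation that does not exist.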

Below are the functions that create conversations and messages. Both are called from the query function.

import logging
from datetime import datetime

from fastapi import HTTPException
from sqlalchemy.orm import Session

# Conversation and Message are assumed to be the SQLAlchemy models defined in the repo
logger = logging.getLogger(__name__)

def create_conversation(db: Session, name: str) -> str:
    """Insert a new conversation and return its ID as a string."""
    logger.debug(f"Creating conversation with name: {name}")
    try:
        conversation = Conversation(name=name, dateCreated=datetime.utcnow())
        db.add(conversation)
        db.commit()
        db.refresh(conversation)  # reload so the generated ID is populated
        logger.info(f"Created conversation with ID: {conversation.id}")
        return str(conversation.id)  # create_message expects a str ID
    except Exception as e:
        logger.error(f"Error creating conversation: {str(e)}")
        db.rollback()
        raise HTTPException(status_code=500, detail=str(e))

def create_message(db: Session, conversation_id: str, content: str, role: str = "user") -> Message:
    """Store one query or response in a conversation and return the Message row."""
    logger.debug(f"Creating message for conversation {conversation_id}")
    try:
        message = Message(
            conversation_id=conversation_id,
            content=content,
            role=role,  # "user" or "assistant"
            dateCreated=datetime.utcnow(),
        )
        db.add(message)
        db.commit()
        db.refresh(message)
        logger.info(f"Created message with ID: {message.id}")
        return message
    except Exception as e:
        logger.error(f"Error creating message: {str(e)}")
        db.rollback()
        raise HTTPException(status_code=500, detail=str(e))

Two more APIs round out the user experience: one fetches all conversations sorted by dateCreated in descending order (newest first), and the other fetches all the messages of a given conversation. The fetch_conversations and fetch_messages helpers they call are in the repo.

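import uuid
from fastapi import Depends
from sqlalchemy.orm import Session
# conversation_router (an APIRouter), get_db and the fetch_* helpers come from the repo
# Plain def: FastAPI runs sync endpoints in a threadpool, so blocking DB calls don't stall the event loop
@conversation_router.get("/fetch")
def get_conversations(db: Session = Depends(get_db)):
    return fetch_conversations(db)

@conversation_router.get("/messages/")
def get_messages(conversation_id: uuid.UUID, db: Session = Depends(get_db)):
    return fetch_messages(db, str(conversation_id))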

Querying with the previous context

Now comes the interesting part. Since every message is stored as it happens, we can feed earlier messages back to the LLM as context for each new query in a conversation.

While going through the LlamaIndex docs, I found that, alongside query_engine, there is a chat_engine. Almost every chat_engine example in the docs uses ChatMemoryBuffer, but that doesn't work for us: the buffer is held in memory for that one session and doesn't scale to multiple conversations.

# From LlamaIndex docs  
from llama_index.core.memory import ChatMemoryBuffer  
  
memory = ChatMemoryBuffer.from_defaults(token_limit=1500)  
  
chat_engine = index.as_chat_engine(  
    chat_mode="context",  
    memory=memory,  
    system_prompt=(  
        "You are a chatbot, able to have normal interactions, as well as talk"  
        " about an essay discussing Paul Grahams life."  
    ),  
)

However, the chat method takes an optional chat_history parameter: a list of ChatMessage objects. The function below converts the stored messages into that form. It keeps only the last 4 messages (typically the two most recent question-and-answer exchanges), because LLMs have limited context windows and fewer tokens also means lower cost.

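from typing import List

from llama_index.core.llms import ChatMessage

def get_chat_history(chat_messages: List[dict]) -> List[ChatMessage]:
    # Assumes messages arrive oldest first; keep only the 4 most recent
    # to limit prompt size (and token cost)
    return [
        ChatMessage(content=message["content"], role=message["role"])
        for message in chat_messages[-4:]
    ]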
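The same sliding window can be sketched without LlamaIndex. Here a small dataclass stands in for `ChatMessage` so the snippet runs on its own, and the optional `max_chars` budget is a hypothetical extra (characters as a rough proxy for tokens), not something the repo does. Note the guard for `last_n == 0`: in Python, `messages[-0:]` returns the whole list, not an empty one.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class Msg:
    # Stand-in for llama_index's ChatMessage (role + content only)
    role: str
    content: str

def get_chat_history(chat_messages: List[Dict], last_n: int = 4,
                     max_chars: Optional[int] = None) -> List[Msg]:
    # Input is oldest first, so the tail holds the most recent messages
    window = chat_messages[-last_n:] if last_n > 0 else []
    history = [Msg(role=m["role"], content=m["content"]) for m in window]
    if max_chars is not None:
        # Hypothetical extra: drop the oldest messages until the total fits
        while history and sum(len(m.content) for m in history) > max_chars:
            history.pop(0)
    return history
```

A character budget is only an approximation; a real token count depends on the model's tokenizer.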

Finally, we combine all of this in the query function:

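from typing import List, Optional

from llama_index.core.llms import ChatMessage, MessageRole
from sqlalchemy.orm import Session

# chat_engine is the index's chat engine, built elsewhere in the repo
def engine(query: str, conversation_id: Optional[str], db: Session) -> dict:
    logger.debug(f"Received query: {query} and conversation_id: {conversation_id}")
    chat_history: List[ChatMessage] = []
    if not conversation_id:
        # New chat: name the conversation after the first 20 characters of the query
        conversation_id = create_conversation(db=db, name=query[:20])
    else:
        # Existing chat: load earlier messages BEFORE saving the new query,
        # so the query doesn't appear twice in the prompt
        messages = fetch_messages(db=db, conversation_id=conversation_id)
        chat_history = get_chat_history(chat_messages=messages)

    logger.debug(f"Chat history: {chat_history}")
    # .value stores the plain string ("user" / "assistant") in the DB
    create_message(db=db, conversation_id=conversation_id, content=query, role=MessageRole.USER.value)
    response = chat_engine.chat(query, chat_history=chat_history)
    create_message(db=db, conversation_id=conversation_id, content=response.response, role=MessageRole.ASSISTANT.value)
    logger.info(f"Query response: {response.response} and conversation_id: {conversation_id}")
    return {
        "response": response.response,
        "conversation_id": conversation_id,
    }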
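To see the flow without a database or an LLM, here is a self-contained simulation of `engine`. An in-memory dict replaces the two tables, and `fake_chat` is a hypothetical stand-in for `chat_engine.chat` that simply reports how many history messages it received. What it illustrates is the ordering: history is read before the new query is saved, and the window caps the context at 4 messages however long the conversation grows.

```python
import uuid
from typing import Dict, List, Optional

# Hypothetical in-memory stand-in for the conversations/messages tables
STORE: Dict[str, List[Dict[str, str]]] = {}

def fake_chat(query: str, chat_history: List[Dict[str, str]]) -> str:
    # Stand-in for chat_engine.chat(): reports how much context it was given
    return f"answer to {query!r} using {len(chat_history)} previous messages"

def engine(query: str, conversation_id: Optional[str] = None) -> Dict[str, str]:
    if not conversation_id:
        conversation_id = str(uuid.uuid4())  # new chat -> new conversation
    messages = STORE.setdefault(conversation_id, [])
    # Slice (a copy) of the last 4 messages, taken BEFORE saving the new query
    chat_history = messages[-4:]
    messages.append({"role": "user", "content": query})
    answer = fake_chat(query, chat_history)
    messages.append({"role": "assistant", "content": answer})
    return {"response": answer, "conversation_id": conversation_id}
```

Calling `engine("Who is the Half-Blood Prince?")` starts a conversation with no context; passing the returned `conversation_id` into the next call gives that call the earlier exchange as history.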

Wrapping up

If you have come this far, thank you for your time. I hope you learned something new today. I will be posting more blogs about my learnings soon.