llm-qa/llm-qa/llm_qa/services/chat.py

import logging
import time

from langchain.chat_models.base import BaseChatModel
from langchain.vectorstores.qdrant import Qdrant
from langchain_core.messages import BaseMessage
from qdrant_client import AsyncQdrantClient, QdrantClient

from llm_qa.chains.chat import get_chat_chain
from llm_qa.embeddings.tei import TeiEmbeddings
from llm_qa.models.chat import ChatMessage, ChatResponse
from llm_qa.models.prompts import Prompts
from llm_qa.models.source import Source
from llm_qa.models.tei import TeiConfig

logger = logging.getLogger(__name__)


async def chat(
    messages: list[BaseMessage],
    collection_name: str,
    prompts: Prompts,
    tei_config: TeiConfig,
    tei_rerank_config: TeiConfig,
    qdrant_host: str,
    qdrant_grpc_port: int,
    retrieve_count: int,
    rerank_count: int | None,
    chat_model: BaseChatModel,
) -> ChatResponse:
    """Answer a chat conversation with retrieval-augmented generation.

    Retrieves candidate documents from the Qdrant collection, optionally
    reranks them, and generates a response with the chat model.
    """
    # Embedding clients for retrieval and reranking, each backed by its own
    # TEI (text-embeddings-inference) endpoint.
    tei_embeddings = TeiEmbeddings(tei_config=tei_config)
    tei_rerank_embeddings = TeiEmbeddings(tei_config=tei_rerank_config)

    # Sync and async Qdrant clients over gRPC; the LangChain Qdrant vectorstore
    # uses the sync client for its sync methods and the async client for the
    # async ones.
    qdrant_client = QdrantClient(
        location=qdrant_host, grpc_port=qdrant_grpc_port, prefer_grpc=True
    )
    async_qdrant_client = AsyncQdrantClient(
        location=qdrant_host, grpc_port=qdrant_grpc_port, prefer_grpc=True
    )
    qdrant_vectorstore = Qdrant(
        client=qdrant_client,
        async_client=async_qdrant_client,
        collection_name=collection_name,
        embeddings=tei_embeddings,
    )

    # Build the chat chain: retrieve `retrieve_count` candidates, optionally
    # rerank them down to `rerank_count`, then generate the response.
    chat_chain = get_chat_chain(
        prompts=prompts,
        vectorstore_retriever=qdrant_vectorstore.as_retriever(
            search_kwargs={"k": retrieve_count}
        ),
        tei_rerank_embeddings=tei_rerank_embeddings,
        rerank_count=rerank_count,
        chat_model=chat_model,
    )

    start_time = time.time()
    chain_output = await chat_chain.ainvoke(messages)
    elapsed_time = time.time() - start_time
    logger.info("Chat chain finished in %.2f seconds", elapsed_time)

    return ChatResponse(
        response_message=ChatMessage.from_langchain_message(chain_output["response"]),
        sources=[
            Source(content=document.page_content)
            for document in chain_output["documents"]
        ],
        retrieval_query=chain_output["retrieval_query"],
    )
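

# --- Usage sketch (editorial addition, not part of the original module) ------
# A minimal illustration of how a caller might drive this service from async
# application code. The Prompts, TeiConfig, and chat-model instances are taken
# as already-constructed inputs because their construction depends on settings
# not shown in this file; the question, collection name, host, and counts
# below are hypothetical placeholder values.
async def example_chat_call(
    prompts: Prompts,
    tei_config: TeiConfig,
    tei_rerank_config: TeiConfig,
    chat_model: BaseChatModel,
) -> ChatResponse:
    from langchain_core.messages import HumanMessage

    return await chat(
        messages=[HumanMessage(content="What does the handbook say about PTO?")],
        collection_name="documents",  # assumed collection name
        prompts=prompts,
        tei_config=tei_config,
        tei_rerank_config=tei_rerank_config,
        qdrant_host="localhost",  # assumed Qdrant host
        qdrant_grpc_port=6334,  # Qdrant's default gRPC port
        retrieve_count=10,  # retrieve 10 candidates ...
        rerank_count=4,  # ... and rerank down to the top 4
        chat_model=chat_model,
    )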