import logging
import time

from langchain.chat_models.base import BaseChatModel
from langchain.vectorstores.qdrant import Qdrant
from langchain_core.messages import BaseMessage
from qdrant_client import AsyncQdrantClient, QdrantClient

from llm_qa.chains.chat import get_chat_chain
from llm_qa.embeddings.tei import TeiEmbeddings
from llm_qa.models.chat import ChatMessage, ChatResponse
from llm_qa.models.prompts import Prompts
from llm_qa.models.source import Source
from llm_qa.models.tei import TeiConfig

logger = logging.getLogger(__name__)


async def chat(
    messages: list[BaseMessage],
    collection_name: str,
    prompts: Prompts,
    tei_config: TeiConfig,
    tei_rerank_config: TeiConfig,
    qdrant_host: str,
    qdrant_grpc_port: int,
    retrieve_count: int,
    rerank_count: int | None,
    chat_model: BaseChatModel,
) -> ChatResponse:
    """Answer a conversation with retrieval-augmented generation.

    Retrieves ``retrieve_count`` candidate documents from the Qdrant
    collection, optionally reranks them down to ``rerank_count``, and
    generates a response with ``chat_model``.
    """
    # Separate TEI endpoints for embedding (retrieval) and reranking.
    tei_embeddings = TeiEmbeddings(tei_config=tei_config)
    tei_rerank_embeddings = TeiEmbeddings(tei_config=tei_rerank_config)

    # The LangChain Qdrant wrapper takes both a sync and an async client;
    # both connect over gRPC.
    qdrant_client = QdrantClient(
        location=qdrant_host, grpc_port=qdrant_grpc_port, prefer_grpc=True
    )
    async_qdrant_client = AsyncQdrantClient(
        location=qdrant_host, grpc_port=qdrant_grpc_port, prefer_grpc=True
    )
    qdrant_vectorstore = Qdrant(
        client=qdrant_client,
        async_client=async_qdrant_client,
        collection_name=collection_name,
        embeddings=tei_embeddings,
    )

    chat_chain = get_chat_chain(
        prompts=prompts,
        vectorstore_retriever=qdrant_vectorstore.as_retriever(
            search_kwargs={"k": retrieve_count}
        ),
        tei_rerank_embeddings=tei_rerank_embeddings,
        rerank_count=rerank_count,
        chat_model=chat_model,
    )

    start_time = time.time()
    chain_output = await chat_chain.ainvoke(messages)
    elapsed_time = time.time() - start_time
    logger.info("Chat chain finished in %.2f seconds", elapsed_time)

    # Map the chain output onto the API response model, exposing the
    # retrieved documents as sources.
    return ChatResponse(
        response_message=ChatMessage.from_langchain_message(chain_output["response"]),
        sources=[
            Source(content=document.page_content)
            for document in chain_output["documents"]
        ],
        retrieval_query=chain_output["retrieval_query"],
    )
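

# --- Usage sketch (illustrative) ---
# A minimal example of calling chat(), assuming a running Qdrant instance,
# two TEI services (one for embedding, one for reranking), and any LangChain
# chat model. The hosts, ports, collection name, and constructor arguments
# below are hypothetical placeholders, not values defined by this module:
#
#     import asyncio
#     from langchain_core.messages import HumanMessage
#
#     response = asyncio.run(
#         chat(
#             messages=[HumanMessage(content="What is TEI?")],
#             collection_name="docs",            # hypothetical collection
#             prompts=Prompts(...),              # project prompt templates
#             tei_config=TeiConfig(...),         # embedding endpoint config
#             tei_rerank_config=TeiConfig(...),  # reranking endpoint config
#             qdrant_host="localhost",
#             qdrant_grpc_port=6334,             # Qdrant's default gRPC port
#             retrieve_count=20,
#             rerank_count=5,                    # None presumably skips reranking
#             chat_model=some_chat_model,        # any BaseChatModel instance
#         )
#     )
#     print(response.response_message)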