import logging
import time

from langchain.chat_models.base import BaseChatModel
from langchain.vectorstores.qdrant import Qdrant
from langchain_core.messages import BaseMessage
from qdrant_client import AsyncQdrantClient, QdrantClient

from llm_qa.chains.chat import get_chat_chain
from llm_qa.embeddings.tei import TeiEmbeddings
from llm_qa.models.chat import ChatMessage, ChatResponse
from llm_qa.models.prompts import Prompts
from llm_qa.models.source import Source
from llm_qa.models.tei import TeiConfig

logger = logging.getLogger(__name__)

async def chat(
    messages: list[BaseMessage],
    collection_name: str,
    prompts: Prompts,
    tei_config: TeiConfig,
    tei_rerank_config: TeiConfig,
    qdrant_host: str,
    qdrant_grpc_port: int,
    retrieve_count: int,
    rerank_count: int | None,
    chat_model: BaseChatModel,
) -> ChatResponse:
    """Run one retrieval-augmented chat turn against a Qdrant collection.

    Builds TEI embedding/rerank clients, connects to Qdrant over gRPC,
    assembles the chat chain, invokes it with the conversation history,
    and packages the chain output into a ``ChatResponse``.

    Args:
        messages: Conversation history in LangChain message form.
        collection_name: Qdrant collection to retrieve documents from.
        prompts: Prompt templates consumed by the chat chain.
        tei_config: TEI endpoint configuration for the embedding model.
        tei_rerank_config: TEI endpoint configuration for the rerank model.
        qdrant_host: Qdrant server location/host.
        qdrant_grpc_port: Qdrant gRPC port (gRPC transport is preferred).
        retrieve_count: Number of documents the retriever fetches (``k``).
        rerank_count: Document count kept after reranking, or ``None``
            (exact semantics defined by ``get_chat_chain`` — not visible here).
        chat_model: Chat model used to generate the final answer.

    Returns:
        ChatResponse carrying the generated message, one ``Source`` per
        retrieved document, and the retrieval query the chain produced.
    """
    tei_embeddings = TeiEmbeddings(tei_config=tei_config)
    tei_rerank_embeddings = TeiEmbeddings(tei_config=tei_rerank_config)
    # Both clients are required: the LangChain Qdrant vectorstore takes a
    # sync client plus an async client for ainvoke-driven retrieval.
    qdrant_client = QdrantClient(
        location=qdrant_host, grpc_port=qdrant_grpc_port, prefer_grpc=True
    )
    async_qdrant_client = AsyncQdrantClient(
        location=qdrant_host, grpc_port=qdrant_grpc_port, prefer_grpc=True
    )
    try:
        qdrant_vectorstore = Qdrant(
            client=qdrant_client,
            async_client=async_qdrant_client,
            collection_name=collection_name,
            embeddings=tei_embeddings,
        )
        chat_chain = get_chat_chain(
            prompts=prompts,
            vectorstore_retriever=qdrant_vectorstore.as_retriever(
                search_kwargs={"k": retrieve_count}
            ),
            tei_rerank_embeddings=tei_rerank_embeddings,
            rerank_count=rerank_count,
            chat_model=chat_model,
        )
        # perf_counter is monotonic; time.time() can jump backwards/forwards
        # on wall-clock adjustments and must not be used for durations.
        start_time = time.perf_counter()
        chain_output = await chat_chain.ainvoke(messages)
        elapsed_time = time.perf_counter() - start_time
        logger.info("Chat chain finished in %.2f seconds", elapsed_time)
        return ChatResponse(
            response_message=ChatMessage.from_langchain_message(
                chain_output["response"]
            ),
            sources=[
                Source(content=document.page_content)
                for document in chain_output["documents"]
            ],
            retrieval_query=chain_output["retrieval_query"],
        )
    finally:
        # The original leaked a sync and an async gRPC connection on every
        # call; close both whether the chain succeeded or raised.
        qdrant_client.close()
        await async_qdrant_client.close()