# llm-qa/llm-qa/llm_qa/services/upsert.py
# (original listing metadata: 44 lines, 1.2 KiB, Python)

import logging
from langchain.schema.document import Document
from langchain_core.embeddings import Embeddings
from llm_qa.chains.text_splitters.markdown_header_text_splitter import (
markdown_3_headers_text_splitter_chain,
)
from llm_qa.chains.text_splitters.text_splitter import (
recursive_character_text_splitter_chain,
)
from llm_qa.models.upsert import TextType
from llm_qa.vectorstores.qdrant import upsert_documents
# Logger named after this module (``__name__``), obtained from the standard
# logging hierarchy.
logger = logging.getLogger(name=__name__)
async def upsert_text(
text: str,
text_type: TextType,
collection: str,
embeddings: Embeddings,
qdrant_url: str,
) -> int:
match text_type:
case TextType.PLAIN_TEXT:
text_splitter_chain = recursive_character_text_splitter_chain
case TextType.MARKDOWN:
text_splitter_chain = markdown_3_headers_text_splitter_chain
case _:
raise ValueError(f"Unknown text type: `{text_type}`") # noqa: TRY003
text_chunks = await text_splitter_chain.ainvoke(text)
documents = [Document(page_content=chunk) for chunk in text_chunks]
await upsert_documents(
documents=documents,
embeddings=embeddings,
qdrant_url=qdrant_url,
collection=collection,
)
return len(documents)