import logging

from langchain.schema.document import Document
from langchain_core.embeddings import Embeddings

from llm_qa.chains.text_splitters.markdown_header_text_splitter import (
    markdown_3_headers_text_splitter_chain,
)
from llm_qa.chains.text_splitters.text_splitter import (
    recursive_character_text_splitter_chain,
)
from llm_qa.models.upsert import TextType
from llm_qa.vectorstores.qdrant import upsert_documents

logger = logging.getLogger(__name__)


async def upsert_text(
    text: str,
    text_type: TextType,
    collection: str,
    embeddings: Embeddings,
    qdrant_url: str,
) -> int:
    """Split `text` into chunks, embed them, and upsert them into a Qdrant collection.

    Returns the number of chunks (documents) that were upserted.
    """
    # Pick a splitter chain based on the declared text type.
    match text_type:
        case TextType.PLAIN_TEXT:
            text_splitter_chain = recursive_character_text_splitter_chain
        case TextType.MARKDOWN:
            text_splitter_chain = markdown_3_headers_text_splitter_chain
        case _:
            raise ValueError(f"Unknown text type: `{text_type}`")  # noqa: TRY003

    text_chunks = await text_splitter_chain.ainvoke(text)

    # Wrap each chunk in a LangChain Document so it can be embedded and stored.
    documents = [Document(page_content=chunk) for chunk in text_chunks]

    await upsert_documents(
        documents=documents,
        embeddings=embeddings,
        qdrant_url=qdrant_url,
        collection=collection,
    )
    return len(documents)
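
# Usage sketch: `upsert_text` is a coroutine, so it must be awaited (or driven
# with `asyncio.run` outside an event loop). The embeddings class and Qdrant URL
# below are assumptions; any LangChain `Embeddings` implementation and any
# reachable Qdrant instance will do.
#
#     import asyncio
#
#     from langchain_openai import OpenAIEmbeddings
#
#     n_chunks = asyncio.run(
#         upsert_text(
#             text="# Title\n\nSome markdown body.",
#             text_type=TextType.MARKDOWN,
#             collection="docs",
#             embeddings=OpenAIEmbeddings(),
#             qdrant_url="http://localhost:6333",
#         )
#     )
#     logger.info("Upserted %d chunks", n_chunks)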