Add upsert endpoint

This commit is contained in:
Martin Popovski 2024-02-13 09:13:13 +01:00
parent faa2ecbfda
commit 4f58e0f0b9
Signed by: martinkozle
GPG Key ID: 0A5F2984DB008108
24 changed files with 617 additions and 2 deletions

View File

@@ -5,4 +5,38 @@ services:
    build: .
    volumes:
      - ..:/workspace:cached
    ports:
      - "8000:8000"
    environment:
      - QDRANT_URL=qdrant:6334
      - TEI_BASE_URL=http://text-embeddings-inference
      - TEI_RERANK_BASE_URL=http://text-embeddings-inference-rerank
    command: sleep infinity
  qdrant:
    image: qdrant/qdrant:v1.7.4
    volumes:
      - ../qdrant_storage:/qdrant/storage:z
    ports:
      - "6333:6333"
      - "6334:6334"
  text-embeddings-inference:
    image: ghcr.io/huggingface/text-embeddings-inference:cpu-0.6
    volumes:
      - "../tei_data:/data"
    ports:
      - "8001:80"
    environment:
      - MODEL_ID=${TEI_MODEL_ID:-BAAI/bge-large-en-v1.5}
      - REVISION=${TEI_MODEL_REVISION}
  text-embeddings-inference-rerank:
    image: ghcr.io/huggingface/text-embeddings-inference:cpu-0.6
    volumes:
      - "../tei_data:/data"
    ports:
      - "8002:80"
    environment:
      - MODEL_ID=${TEI_RERANK_MODEL_ID:-BAAI/bge-reranker-large}
      - REVISION=${TEI_RERANK_MODEL_REVISION}
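A quick smoke test for the embedding container once the stack is up (a minimal sketch; the host port follows the mapping above and the payload mirrors the EmbedRequest model added later in this commit):

import httpx

# 8001 is the host port mapped to the embedding service in the compose file.
response = httpx.post(
    "http://localhost:8001/embed",
    json={"inputs": "passage: hello world", "normalize": True, "truncate": False},
)
response.raise_for_status()
print(len(response.json()[0]))  # embedding dimension; 1024 for BAAI/bge-large-en-v1.5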

View File

@@ -30,7 +30,8 @@
        "eamodio.gitlens",
        "yzhang.markdown-all-in-one",
        "DavidAnson.vscode-markdownlint",
-       "ninoseki.vscode-pylens"
+       "ninoseki.vscode-pylens",
+       "ms-azuretools.vscode-docker"
      ]
    }
  },

View File

View File

@@ -0,0 +1,36 @@
from langchain.text_splitter import MarkdownHeaderTextSplitter
from langchain_core.runnables import Runnable, chain


def get_markdown_header_text_splitter_chain(
    markdown_header_text_splitter: MarkdownHeaderTextSplitter,
) -> Runnable[str, list[str]]:
    if not markdown_header_text_splitter.strip_headers:
        raise ValueError("`strip_headers` must be True")  # noqa: TRY003

    @chain
    def markdown_header_text_splitter_chain(text: str) -> list[str]:
        documents = markdown_header_text_splitter.split_text(text)
        # Prepend all parent headers to the page content; a chunk only carries
        # metadata for the header levels actually above it, so guard the lookup
        return [
            "\n...\n".join(
                f"{header_key} {document.metadata[header_key]}"
                for _, header_key in markdown_header_text_splitter.headers_to_split_on
                if header_key in document.metadata
            )
            + f"\n{document.page_content}"
            for document in documents
        ]

    return markdown_header_text_splitter_chain


markdown_3_headers_text_splitter_chain = get_markdown_header_text_splitter_chain(
    MarkdownHeaderTextSplitter(
        headers_to_split_on=[
            ("#", "#"),
            ("##", "##"),
            ("###", "###"),
        ],
        strip_headers=True,
    )
)
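A usage sketch for the chain above (the input is illustrative; exact chunk boundaries depend on the LangChain splitter version):

chunks = markdown_3_headers_text_splitter_chain.invoke(
    "# Guide\n\n## Setup\n\nInstall the package.\n\n## Usage\n\nRun the server."
)
# Each chunk starts with its parent headers, e.g. the Setup chunk becomes:
# "# Guide\n...\n## Setup\nInstall the package."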

View File

@@ -0,0 +1,17 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain_core.runnables import Runnable, chain


def get_text_splitter_chain(
    text_splitter: TextSplitter,
) -> Runnable[str, list[str]]:
    @chain
    def text_splitter_chain(text: str) -> list[str]:
        return text_splitter.split_text(text)

    return text_splitter_chain


recursive_character_text_splitter_chain = get_text_splitter_chain(
    RecursiveCharacterTextSplitter()
)

View File

@@ -0,0 +1,24 @@
from typing import Annotated

from fastapi import Depends

from llm_qa.embeddings.tei import TeiEmbeddings
from llm_qa.settings import Settings


def settings() -> Settings:
    return Settings()


def tei_embeddings(settings: Annotated[Settings, Depends(settings)]) -> TeiEmbeddings:
    return TeiEmbeddings(
        base_url=settings.tei_base_url,
    )


def tei_rerank_embeddings(
    settings: Annotated[Settings, Depends(settings)],
) -> TeiEmbeddings:
    return TeiEmbeddings(
        base_url=settings.tei_rerank_base_url,
    )
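Because routes receive these as FastAPI dependencies, tests can swap them without patching; a hedged sketch using dependency_overrides (the override values are placeholders):

from llm_qa.dependencies import settings
from llm_qa.settings import Settings
from llm_qa.web import app

# Hypothetical test setup: return fixed settings instead of reading .env.
app.dependency_overrides[settings] = lambda: Settings(
    qdrant_url="localhost:6334",
    tei_base_url="http://localhost:8001",
    tei_rerank_base_url="http://localhost:8002",
)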

View File

View File

@@ -0,0 +1,8 @@
from abc import ABC

from langchain.embeddings.base import Embeddings
from pydantic import BaseModel


class PydanticEmbeddings(Embeddings, BaseModel, ABC):
    pass

View File

@@ -0,0 +1,105 @@
from typing import override
from urllib.parse import urljoin

import httpx

from llm_qa.embeddings.base import PydanticEmbeddings
from llm_qa.models.tei import (
    EmbedRequest,
    EmbedResponseAdapter,
    ErrorResponse,
    RerankRequest,
    RerankResponse,
    RerankResponseAdapter,
)


class TeiEmbeddings(PydanticEmbeddings):
    base_url: str
    embed_endpoint: str = "/embed"
    rerank_endpoint: str = "/rerank"
    document_prefix: str = "passage: "
    query_prefix: str = "query: "
    _client: httpx.Client = httpx.Client()
    _async_client: httpx.AsyncClient = httpx.AsyncClient()

    @property
    def embed_url(self) -> str:
        return urljoin(self.base_url, self.embed_endpoint)

    @property
    def rerank_url(self) -> str:
        return urljoin(self.base_url, self.rerank_endpoint)

    @staticmethod
    def _handle_status(response: httpx.Response) -> None:
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            match e.response.status_code:
                case 413 | 422 | 424 | 429:
                    try:
                        error_response = ErrorResponse.model_validate_json(
                            e.response.content
                        )
                        note = f"Error: {error_response.error}, Error Type: {error_response.error_type}"
                    except ValueError:
                        note = e.response.text
                    e.add_note(note)
            raise

    def _embed(self, text: str | list[str]) -> list[list[float]]:
        """Embed text."""
        response = self._client.post(
            url=self.embed_url,
            json=EmbedRequest(inputs=text).model_dump(),
        )
        self._handle_status(response)
        return EmbedResponseAdapter.validate_json(response.content)

    async def _aembed(self, text: str | list[str]) -> list[list[float]]:
        """Asynchronously embed text."""
        response = await self._async_client.post(
            url=self.embed_url,
            json=EmbedRequest(inputs=text).model_dump(),
        )
        self._handle_status(response)
        return EmbedResponseAdapter.validate_json(response.content)

    @override
    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search docs."""
        return self._embed([self.document_prefix + text for text in texts])

    @override
    def embed_query(self, text: str) -> list[float]:
        """Embed query text."""
        return self._embed(self.query_prefix + text)[0]

    @override
    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronously embed search docs."""
        return await self._aembed([self.document_prefix + text for text in texts])

    @override
    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronously embed query text."""
        return (await self._aembed(self.query_prefix + text))[0]

    def rerank(self, query: str, texts: list[str]) -> list[RerankResponse]:
        """Rerank texts."""
        response = self._client.post(
            url=self.rerank_url,
            json=RerankRequest(query=query, texts=texts).model_dump(),
        )
        self._handle_status(response)
        return RerankResponseAdapter.validate_json(response.content)

    async def arerank(self, query: str, texts: list[str]) -> list[RerankResponse]:
        """Asynchronously rerank texts."""
        response = await self._async_client.post(
            url=self.rerank_url,
            json=RerankRequest(query=query, texts=texts).model_dump(),
        )
        self._handle_status(response)
        return RerankResponseAdapter.validate_json(response.content)
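A minimal usage sketch, assuming the compose services above are reachable on their mapped host ports:

from llm_qa.embeddings.tei import TeiEmbeddings

embeddings = TeiEmbeddings(base_url="http://localhost:8001")
vector = embeddings.embed_query("what does the upsert endpoint do?")  # "query: " prefix added
reranker = TeiEmbeddings(base_url="http://localhost:8002")
ranked = reranker.rerank(
    query="upsert endpoint",
    texts=["Adds documents to Qdrant.", "Unrelated text."],
)  # list[RerankResponse] with a relevance score per input text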

View File

View File

@@ -0,0 +1,42 @@
from enum import StrEnum

from pydantic import BaseModel, Field, TypeAdapter


class EmbedRequest(BaseModel):
    inputs: str | list[str]
    normalize: bool = True
    truncate: bool = False


EmbedResponseAdapter = TypeAdapter(list[list[float]])


class RerankRequest(BaseModel):
    query: str
    texts: list[str]
    raw_scores: bool = False
    return_text: bool = False
    truncate: bool = False


class RerankResponse(BaseModel):
    index: int = Field(..., ge=0)
    score: float
    text: str | None = None


RerankResponseAdapter = TypeAdapter(list[RerankResponse])


class ErrorType(StrEnum):
    Unhealthy = "Unhealthy"
    Backend = "Backend"
    Overloaded = "Overloaded"
    Validation = "Validation"
    Tokenizer = "Tokenizer"


class ErrorResponse(BaseModel):
    error: str
    error_type: ErrorType

View File

@@ -0,0 +1,18 @@
from enum import StrEnum

from pydantic import BaseModel


class TextType(StrEnum):
    PLAIN_TEXT = "PLAIN_TEXT"
    MARKDOWN = "MARKDOWN"


class UpsertTextRequest(BaseModel):
    text: str
    type: TextType
    collection: str


class UpsertTextResponse(BaseModel):
    num_documents: int

View File

View File

@@ -0,0 +1,6 @@
from fastapi import APIRouter

from llm_qa.routers import upsert

router = APIRouter(prefix="/api/v1", tags=["v1"])
router.include_router(upsert.router)

View File

@@ -0,0 +1,32 @@
from typing import Annotated

from fastapi import APIRouter, Depends

from llm_qa.dependencies import settings, tei_embeddings
from llm_qa.embeddings.base import PydanticEmbeddings
from llm_qa.models.upsert import UpsertTextRequest, UpsertTextResponse
from llm_qa.services.upsert import upsert_text as upsert_text_service
from llm_qa.settings import Settings

router = APIRouter()


@router.post("/upsert-text")
async def upsert_text(
    upsert_request: UpsertTextRequest,
    settings: Annotated[Settings, Depends(settings)],
    embeddings: Annotated[PydanticEmbeddings, Depends(tei_embeddings)],
) -> UpsertTextResponse:
    num_documents = await upsert_text_service(
        text=upsert_request.text,
        text_type=upsert_request.type,
        collection=upsert_request.collection,
        embeddings=embeddings,
        qdrant_url=settings.qdrant_url,
    )
    return UpsertTextResponse(num_documents=num_documents)


@router.post("/upsert-file")
async def upsert_file() -> None:
    raise NotImplementedError
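A client-side sketch of the new endpoint (host and port follow the compose/uvicorn setup; the payload mirrors UpsertTextRequest):

import httpx

response = httpx.post(
    "http://localhost:8000/api/v1/upsert-text",
    json={"text": "# Title\n\nSome content.", "type": "MARKDOWN", "collection": "demo"},
)
print(response.json())  # {"num_documents": ...}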

View File

View File

@@ -0,0 +1,43 @@
import logging

from langchain.schema.document import Document
from langchain_core.embeddings import Embeddings

from llm_qa.chains.text_splitters.markdown_header_text_splitter import (
    markdown_3_headers_text_splitter_chain,
)
from llm_qa.chains.text_splitters.text_splitter import (
    recursive_character_text_splitter_chain,
)
from llm_qa.models.upsert import TextType
from llm_qa.vectorstores.qdrant import upsert_documents

logger = logging.getLogger(__name__)


async def upsert_text(
    text: str,
    text_type: TextType,
    collection: str,
    embeddings: Embeddings,
    qdrant_url: str,
) -> int:
    match text_type:
        case TextType.PLAIN_TEXT:
            text_splitter_chain = recursive_character_text_splitter_chain
        case TextType.MARKDOWN:
            text_splitter_chain = markdown_3_headers_text_splitter_chain
        case _:
            raise ValueError(f"Unknown text type: `{text_type}`")  # noqa: TRY003
    text_chunks = await text_splitter_chain.ainvoke(text)
    documents = [Document(page_content=chunk) for chunk in text_chunks]
    await upsert_documents(
        documents=documents,
        embeddings=embeddings,
        qdrant_url=qdrant_url,
        collection=collection,
    )
    return len(documents)

View File

@@ -0,0 +1,9 @@
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    qdrant_url: str
    tei_base_url: str
    tei_rerank_base_url: str
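A matching .env sketch for running outside the dev container (values are illustrative placeholders; inside the container the compose file injects these variables directly):

QDRANT_URL=localhost:6334
TEI_BASE_URL=http://localhost:8001
TEI_RERANK_BASE_URL=http://localhost:8002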

View File

View File

@@ -0,0 +1,22 @@
import logging

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.qdrant import Qdrant

logger = logging.getLogger(__name__)


async def upsert_documents(
    documents: list[Document], embeddings: Embeddings, qdrant_url: str, collection: str
) -> None:
    logger.info(
        "Upserting %d documents to Qdrant collection `%s`", len(documents), collection
    )
    await Qdrant.afrom_documents(
        documents=documents,
        embedding=embeddings,
        url=qdrant_url,
        prefer_grpc=True,
        collection_name=collection,
    )
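An illustrative one-off call against a local stack (collection name and URLs are placeholders):

import asyncio

from langchain.schema.document import Document

from llm_qa.embeddings.tei import TeiEmbeddings
from llm_qa.vectorstores.qdrant import upsert_documents

asyncio.run(
    upsert_documents(
        documents=[Document(page_content="hello")],
        embeddings=TeiEmbeddings(base_url="http://localhost:8001"),
        qdrant_url="localhost:6334",  # gRPC port, since prefer_grpc=True
        collection="demo",
    )
)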

llm-qa/llm_qa/web.py Normal file (12 lines)
View File

@@ -0,0 +1,12 @@
from fastapi import FastAPI

from llm_qa.routers import api_v1

app = FastAPI(title="LLM QA")
app.include_router(api_v1.router)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run("llm_qa.web:app", host="0.0.0.0", port=8000, reload=True)  # noqa: S104
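A smoke-test sketch using FastAPI's TestClient (the /docs route is served by FastAPI itself):

from fastapi.testclient import TestClient

from llm_qa.web import app

client = TestClient(app)
assert client.get("/docs").status_code == 200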

llm-qa/poetry.lock generated (155 lines)
View File

@@ -287,6 +287,20 @@ files = [
{file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"},
]
[[package]]
name = "click"
version = "8.1.7"
description = "Composable command line interface toolkit"
optional = false
python-versions = ">=3.7"
files = [
{file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
{file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "colorama"
version = "0.4.6"
@@ -338,6 +352,25 @@ files = [
[package.extras]
tests = ["asttokens (>=2.1.0)", "coverage", "coverage-enable-subprocess", "ipython", "littleutils", "pytest", "rich"]
[[package]]
name = "fastapi"
version = "0.109.2"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
optional = false
python-versions = ">=3.8"
files = [
{file = "fastapi-0.109.2-py3-none-any.whl", hash = "sha256:2c9bab24667293b501cad8dd388c05240c850b58ec5876ee3283c47d6e1e3a4d"},
{file = "fastapi-0.109.2.tar.gz", hash = "sha256:f3817eac96fe4f65a2ebb4baa000f394e55f5fccdaf7f75250804bc58f354f73"},
]
[package.dependencies]
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
starlette = ">=0.36.3,<0.37.0"
typing-extensions = ">=4.8.0"
[package.extras]
all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
[[package]]
name = "frozenlist"
version = "1.4.1"
@@ -495,6 +528,62 @@ files = [
docs = ["Sphinx", "furo"]
test = ["objgraph", "psutil"]
[[package]]
name = "h11"
version = "0.14.0"
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
optional = false
python-versions = ">=3.7"
files = [
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
]
[[package]]
name = "httpcore"
version = "1.0.2"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
{file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
]
[package.dependencies]
certifi = "*"
h11 = ">=0.13,<0.15"
[package.extras]
asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<0.23.0)"]
[[package]]
name = "httpx"
version = "0.26.0"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpx-0.26.0-py3-none-any.whl", hash = "sha256:8915f5a3627c4d47b73e8202457cb28f1266982d1159bd5779d86a80c0eab1cd"},
{file = "httpx-0.26.0.tar.gz", hash = "sha256:451b55c30d5185ea6b23c2c793abf9bb237d2a7dfb901ced6ff69ad37ec1dfaf"},
]
[package.dependencies]
anyio = "*"
certifi = "*"
httpcore = "==1.*"
idna = "*"
sniffio = "*"
[package.extras]
brotli = ["brotli", "brotlicffi"]
cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
[[package]]
name = "idna"
version = "3.6"
@@ -1113,6 +1202,21 @@ files = [
[package.dependencies]
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
[[package]]
name = "pydantic-settings"
version = "2.1.0"
description = "Settings management using Pydantic"
optional = false
python-versions = ">=3.8"
files = [
{file = "pydantic_settings-2.1.0-py3-none-any.whl", hash = "sha256:7621c0cb5d90d1140d2f0ef557bdf03573aac7035948109adf2574770b77605a"},
{file = "pydantic_settings-2.1.0.tar.gz", hash = "sha256:26b1492e0a24755626ac5e6d715e9077ab7ad4fb5f19a8b7ed7011d52f36141c"},
]
[package.dependencies]
pydantic = ">=2.3.0"
python-dotenv = ">=0.21.0"
[[package]]
name = "pygments"
version = "2.17.2"
@@ -1128,6 +1232,20 @@ files = [
plugins = ["importlib-metadata"]
windows-terminal = ["colorama (>=0.4.6)"]
[[package]]
name = "python-dotenv"
version = "1.0.1"
description = "Read key-value pairs from a .env file and set them as environment variables"
optional = false
python-versions = ">=3.8"
files = [
{file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
{file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
]
[package.extras]
cli = ["click (>=5.0)"]
[[package]]
name = "pyyaml"
version = "6.0.1"
@@ -1363,6 +1481,23 @@ pure-eval = "*"
[package.extras]
tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
[[package]]
name = "starlette"
version = "0.36.3"
description = "The little ASGI library that shines."
optional = false
python-versions = ">=3.8"
files = [
{file = "starlette-0.36.3-py3-none-any.whl", hash = "sha256:13d429aa93a61dc40bf503e8c801db1f1bca3dc706b10ef2434a36123568f044"},
{file = "starlette-0.36.3.tar.gz", hash = "sha256:90a671733cfb35771d8cc605e0b679d23b992f8dcfad48cc60b38cb29aeb7080"},
]
[package.dependencies]
anyio = ">=3.4.0,<5"
[package.extras]
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
[[package]]
name = "tenacity"
version = "8.2.3"
@@ -1435,6 +1570,24 @@ h2 = ["h2 (>=4,<5)"]
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"]
zstd = ["zstandard (>=0.18.0)"]
[[package]]
name = "uvicorn"
version = "0.27.1"
description = "The lightning-fast ASGI server."
optional = false
python-versions = ">=3.8"
files = [
{file = "uvicorn-0.27.1-py3-none-any.whl", hash = "sha256:5c89da2f3895767472a35556e539fd59f7edbe9b1e9c0e1c99eebeadc61838e4"},
{file = "uvicorn-0.27.1.tar.gz", hash = "sha256:3d9a267296243532db80c83a959a3400502165ade2c1338dea4e67915fd4745a"},
]
[package.dependencies]
click = ">=7.0"
h11 = ">=0.8"
[package.extras]
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
[[package]]
name = "wcwidth"
version = "0.2.13"
@@ -1552,4 +1705,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
content-hash = "97e8b5f3e3e62c576dc50012212e077c00fd9fe18b6273a8929bf95d24ba168f"
content-hash = "7ffd6635973f31fbcbc5e962102b8cf12920f83e383dc0bb73732522db897a6f"

View File

@@ -9,6 +9,10 @@ readme = "README.md"
python = "^3.12"
langchain-community = "^0.0.19"
langchain = "^0.1.6"
fastapi = "^0.109.2"
uvicorn = "^0.27.1"
httpx = "^0.26.0"
pydantic-settings = "^2.1.0"
[tool.poetry.group.dev.dependencies]
ruff = "0.2.1"
@@ -18,3 +22,52 @@ ipython = "^8.21.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.mypy]
python_version = "3.12"
plugins = ["pydantic.mypy"]
modules = ["llm_qa"]
strict = true
[tool.ruff]
target-version = "py312"
preview = true
include = ["llm_qa/**/*.py", "pyproject.toml"]
lint.select = [
    "F",
    "I",
    "N",
    "UP",
    "YTT",
    "ANN",
    "ASYNC",
    "S",
    "B",
    "C4",
    "DTZ",
    "FA",
    "ISC",
    "ICN",
    "G",
    "INP",
    "PIE",
    "PT",
    "RSE",
    "RET",
    "SLF",
    "SIM",
    "TID",
    "TCH",
    "PTH",
    "ERA",
    "PGH",
    "PL",
    "TRY",
    "FLY",
    "PERF",
    "FURB",
    "LOG",
    "RUF",
]
lint.ignore = ["E501", "ANN101", "PLR0913", "PLR0917", "ISC001"]