Source code for autorag.chroma_store

"""Persistent Chroma collection of per-clip topic embeddings.

Backs the ``/viz`` page's search box and acts as a cache so the page
load doesn't have to re-embed every topic on every request.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, cast

import chromadb
import numpy as np
from chromadb import Documents, EmbeddingFunction, Embeddings

from autorag.embed import Embedder

if TYPE_CHECKING:
    from pathlib import Path


class EmbedderEmbeddingFunction(EmbeddingFunction[Documents]):
    """Expose an :class:`Embedder` through Chroma's ``EmbeddingFunction`` protocol."""

    def __init__(self, embedder: Embedder | None = None) -> None:
        """Wrap ``embedder``, building a default :class:`Embedder` when none is given."""
        self._embedder = embedder if embedder else Embedder()

    # NOTE(review): the parameter is called ``input`` (shadowing the builtin),
    # presumably because Chroma's protocol expects that name -- confirm before
    # renaming.
    def __call__(self, input: Documents) -> Embeddings:
        """Embed every document in ``input`` and return float32 vectors."""
        texts = list(input)
        raw_vectors = self._embedder.embed_texts(texts)
        float32_vectors = [
            np.asarray(raw, dtype=np.float32) for raw in raw_vectors
        ]
        return cast("Embeddings", float32_vectors)

    @staticmethod
    def name() -> str:
        """Return the fixed identifier string of this embedding function."""
        return "autorag-ollama-embedder"
def default_chroma_dir(db_path: Path) -> Path:
    """Derive the Chroma persistence directory from a SQLite db path.

    The result is a ``chroma`` directory that sits next to the database
    file, after ``~`` in ``db_path`` has been expanded.
    """
    database_file = db_path.expanduser()
    return database_file.parent.joinpath("chroma")
[docs] class ChromaStore: """Persistent Chroma collection of per-clip topic embeddings.""" COLLECTION = "audio_clip_topics" def __init__( self, persist_dir: Path, embedding_function: EmbeddingFunction[Documents] | None = None, ) -> None: persist_dir.mkdir(parents=True, exist_ok=True) self._client = chromadb.PersistentClient(path=str(persist_dir)) self._ef = embedding_function or EmbedderEmbeddingFunction() self._collection = self._client.get_or_create_collection( name=self.COLLECTION, embedding_function=cast("Any", self._ef), metadata={"hnsw:space": "cosine"}, )
[docs] def add_topic_embeddings( self, clip_id: str, clip_title: str, topics: list[dict[str, Any]], embeddings: list[list[float]], ) -> None: """Upsert one document + embedding per topic for ``clip_id``. Ids use the ``"{clip_id}:{topic_index}"`` shape so the position within a clip's filtered (title-bearing) topic list is the stable key — matches what :func:`autorag.viz._collect_rows_embeddings` reads back. """ if not topics: return if len(topics) != len(embeddings): raise ValueError( f"topics ({len(topics)}) and embeddings ({len(embeddings)}) length mismatch" ) ids = [f"{clip_id}:{i}" for i in range(len(topics))] documents = [ f"{t['title']}. {t['summary']}" if t.get("summary") else str(t.get("title", "")) for t in topics ] metadatas: list[dict[str, str | int | float]] = [ { "clip_id": clip_id, "clip_title": clip_title, "topic_index": i, "title": str(t.get("title", "")), "summary": str(t.get("summary", "")), "level": int(t.get("level", 1)), "start_s": float(t.get("start_s", 0.0)), "duration_s": float(t.get("duration_s", 0.0)), "number": str(t.get("number", "")), } for i, t in enumerate(topics) ] self._collection.upsert( ids=ids, embeddings=cast("Any", embeddings), documents=documents, metadatas=cast("Any", metadatas), )
[docs] def get_clip_embeddings(self, clip_id: str) -> dict[int, list[float]]: """Return a ``topic_index -> embedding`` map for every cached topic.""" result = self._collection.get( where={"clip_id": clip_id}, include=["embeddings", "metadatas"], ) embeddings = result.get("embeddings") metadatas = result.get("metadatas") if embeddings is None or metadatas is None: return {} out: dict[int, list[float]] = {} for emb, meta in zip(embeddings, metadatas, strict=True): idx = meta.get("topic_index") if isinstance(idx, int): out[idx] = [float(x) for x in emb] return out
[docs] def query(self, query_embedding: list[float], top_k: int) -> list[dict[str, Any]]: """Return the ``top_k`` topics nearest ``query_embedding`` in cosine space. Each returned dict carries the topic's clip/title/summary metadata and a ``similarity`` field computed as ``1 - distance``. """ result = self._collection.query( query_embeddings=cast("Any", [query_embedding]), n_results=top_k, include=["metadatas", "distances", "documents"], ) ids = (result.get("ids") or [[]])[0] metadatas = (result.get("metadatas") or [[]])[0] distances = (result.get("distances") or [[]])[0] out: list[dict[str, Any]] = [] for _id, meta, dist in zip(ids, metadatas, distances, strict=True): topic_index = meta.get("topic_index", 0) out.append( { "clip_id": str(meta.get("clip_id", "")), "clip_title": str(meta.get("clip_title", "")), "topic_index": int(topic_index) if isinstance(topic_index, int) else 0, "title": str(meta.get("title", "")), "summary": str(meta.get("summary", "")), "similarity": 1.0 - float(dist), } ) return out
[docs] def delete_clip(self, clip_id: str) -> None: """Drop every topic row for ``clip_id`` from the collection.""" self._collection.delete(where={"clip_id": clip_id})
[docs] def count(self) -> int: """Return the total number of topic rows in the collection.""" return int(self._collection.count())