# Source code for autorag.viz

"""Topic embedding visualization: Chroma-backed embeddings + UMAP + FastAPI endpoints."""

from __future__ import annotations

import json
import logging
import pathlib
from typing import Any

import numpy as np
import numpy.typing as npt
import umap
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import HTMLResponse
from pydantic import BaseModel

from autorag.chroma_store import ChromaStore, default_chroma_dir
from autorag.config import get_settings
from autorag.db import Database
from autorag.embed import Embedder
from autorag.topic_cluster import build_edges, cluster_embeddings

router = APIRouter()

_VIZ_DIR = pathlib.Path(__file__).parent / "static" / "viz"
_HTML_PATH = _VIZ_DIR / "index.html"
viz_assets_dir = _VIZ_DIR


# ---------------------------------------------------------------------------
# Response schemas
# ---------------------------------------------------------------------------


class TopicPoint(BaseModel):
    """One topic node placed in 3-D space by the UMAP projection."""

    # --- identity: which topic this is and which clip it belongs to ---
    topic_title: str
    clip_id: str
    clip_title: str

    # --- per-topic metadata taken from the clip's topic record ---
    level: int
    start_s: float
    duration_s: float
    number: str
    summary: str = ""

    # --- coordinates of the point in the projected 3-D space ---
    x: float
    y: float
    z: float

    # Cluster label of the point; 0 when none was supplied.
    cluster_id: int = 0
class Edge(BaseModel):
    """A similarity edge between two :class:`TopicPoint` indices."""

    # Endpoints of the edge, as positions in the list of points.
    a: int
    b: int

    # Similarity score attached to the pair.
    similarity: float
class VizData(BaseModel):
    """Payload returned by ``GET /viz/data`` — points, edges, clip metadata."""

    # One entry per topic, in the order the topics were collected.
    points: list[TopicPoint]

    # Clips that contributed at least one point, and their titles by clip id.
    clip_ids: list[str]
    clip_titles: dict[str, str]

    # Summary counters for the page header.
    total_topics: int
    total_clips: int

    # Optional graph overlay: similarity edges and the number of clusters.
    edges: list[Edge] = []
    total_clusters: int = 0
class SearchResult(BaseModel):
    """One hit returned by ``GET /viz/search``."""

    # Position of the matching point within the viz dataset's point list.
    point_index: int

    # Descriptive fields copied from the search hit.
    topic_title: str
    clip_title: str
    clip_id: str

    # Similarity between the query and this topic, as reported by the store.
    similarity: float

    summary: str = ""
# ---------------------------------------------------------------------------
# UMAP dimensionality reduction
# ---------------------------------------------------------------------------
def umap_3d(embeddings: list[list[float]]) -> npt.NDArray[np.float64]:
    """Project N-D embeddings down to 3 columns with cosine UMAP.

    Handles the small-N degenerate cases so the page can render the very
    first clip:

    * ``n == 0`` → an empty ``(0, 3)`` array (UMAP is never invoked),
    * ``n == 1`` → a single all-zero row (UMAP is never invoked),
    * ``n < 4``  → UMAP is asked for ``n - 1`` components and the result is
      zero-padded on the right to three columns.

    Args:
        embeddings: One embedding vector per point; all vectors must have
            the same length so they stack into a 2-D array.

    Returns:
        A float64 array with one row per input vector and exactly three
        columns.
    """
    emb = np.array(embeddings, dtype=np.float64)
    n = len(emb)
    if n <= 1:
        # Nothing to project: with no points, or a single point, there is no
        # neighbourhood structure. Without this guard n == 0 would request a
        # negative number of components/neighbours from UMAP.
        return np.zeros((n, 3))

    # UMAP cannot produce more components / use more neighbours than n - 1.
    n_components = min(3, n - 1)
    # NOTE(review): for n == 2 this yields n_neighbors == 1, which UMAP's
    # parameter validation may reject — confirm against the installed
    # umap-learn version.
    n_neighbors = min(15, n - 1)

    reducer = umap.UMAP(
        n_components=n_components,
        metric="cosine",
        n_neighbors=n_neighbors,
        random_state=42,  # fixed seed so repeated requests give the same layout
    )
    coords: npt.NDArray[np.float64] = np.asarray(
        reducer.fit_transform(emb), dtype=np.float64
    )

    missing_cols = 3 - coords.shape[1]
    if missing_cols > 0:
        # Fewer than three components were requested; fill the rest with 0.
        coords = np.pad(coords, ((0, 0), (0, missing_cols)))
    return coords  # (N, 3)
# ---------------------------------------------------------------------------
# Shared row/embedding collection
# ---------------------------------------------------------------------------

_logger = logging.getLogger(__name__)

Row = tuple[str, str, dict[str, Any], int]
"""A single point in the viz: (clip_id, clip_title, topic, topic_index)."""


def _none_to(value: Any, default: Any) -> Any:
    """Return *default* when *value* is ``None``, otherwise *value* unchanged."""
    return default if value is None else value


def _parse_titled_topics(raw: Any) -> list[dict[str, Any]]:
    """Decode a clip's JSON ``topics`` field and keep the title-bearing entries.

    Returns an empty list when the field is empty, is not valid JSON, does
    not decode to a list, or contains no dict with a truthy ``"title"``.
    """
    if not raw:
        return []
    try:
        decoded = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return []
    if not isinstance(decoded, list):
        return []
    return [t for t in decoded if isinstance(t, dict) and t.get("title")]


def _embedding_text(topic: dict[str, Any]) -> str:
    """Build the text that is embedded for *topic*: title plus optional summary."""
    if topic.get("summary"):
        return f"{topic['title']}. {topic['summary']}"
    return topic["title"]


def _collect_rows_embeddings(
    clips: list[dict[str, Any]],
    chroma: ChromaStore,
) -> tuple[list[Row], list[list[float]]]:
    """Build (rows, embeddings) from clip records, filling missing vecs via Ollama.

    ``topic_index`` is the position of the topic within the clip's filtered
    (title-bearing) topic list — same convention used by ``cli._transcribe``
    when writing into Chroma.

    Vectors already stored in Chroma are reused. Topics without a stored
    vector are embedded in one batch through :class:`Embedder`, and each
    affected clip is then rewritten in Chroma (delete + add) on a best-effort
    basis: a failing Chroma read or write is logged and otherwise ignored.

    Args:
        clips: Clip records; each needs ``"id"`` and ``"title"`` and may carry
            a JSON-encoded ``"topics"`` string.
        chroma: Store used to read cached vectors and write new ones.

    Returns:
        ``(rows, embeddings)`` with ``embeddings[i]`` belonging to ``rows[i]``.
        Both are empty when no clip has a usable topic.

    Raises:
        RuntimeError: If a topic still has no vector after embedding.
        Exception: Whatever :class:`Embedder` raises is propagated.
    """
    rows: list[Row] = []
    for clip in clips:
        for topic_index, topic in enumerate(_parse_titled_topics(clip.get("topics"))):
            rows.append((clip["id"], clip["title"], topic, topic_index))
    if not rows:
        return [], []

    # Only clips that produced rows can contribute stored vectors.
    stored: dict[str, dict[int, list[float]]] = {}
    for clip_id in dict.fromkeys(row[0] for row in rows):
        try:
            stored[clip_id] = chroma.get_clip_embeddings(clip_id)
        except Exception:
            # Best effort: treat an unreadable cache as empty and re-embed.
            _logger.warning(
                "Could not read stored embeddings for clip %s", clip_id, exc_info=True
            )
            stored[clip_id] = {}

    embeddings: list[list[float] | None] = []
    missing_by_clip: dict[str, list[tuple[int, dict[str, Any], int]]] = {}
    # topic_index -> topic for every clip, built once instead of rescanning
    # ``rows`` for each clip that needs a write-back.
    topics_by_clip: dict[str, dict[int, dict[str, Any]]] = {}
    for row_idx, (clip_id, _clip_title, topic, topic_index) in enumerate(rows):
        topics_by_clip.setdefault(clip_id, {})[topic_index] = topic
        vec = stored.get(clip_id, {}).get(topic_index)
        embeddings.append(vec)
        if vec is None:
            missing_by_clip.setdefault(clip_id, []).append((row_idx, topic, topic_index))

    if missing_by_clip:
        pending = [item for items in missing_by_clip.values() for item in items]
        texts = [_embedding_text(topic) for _row_idx, topic, _topic_index in pending]
        computed = Embedder().embed_texts(texts)
        # strict=True: a short or long reply from the embedder is an error.
        for (row_idx, _topic, _topic_index), vec in zip(pending, computed, strict=True):
            embeddings[row_idx] = vec

        clip_titles = {clip["id"]: clip.get("title", "") for clip in clips}
        for clip_id, items in missing_by_clip.items():
            # Merge the freshly computed vectors into what was already stored.
            merged = dict(stored.get(clip_id, {}))
            for row_idx, _topic, topic_index in items:
                vec = embeddings[row_idx]
                if vec is not None:
                    merged[topic_index] = vec

            topic_lookup = topics_by_clip[clip_id]
            # Drop stored indices that no longer correspond to a current topic.
            ordered_indices = [idx for idx in sorted(merged) if idx in topic_lookup]
            try:
                chroma.delete_clip(clip_id)
                chroma.add_topic_embeddings(
                    clip_id,
                    str(clip_titles.get(clip_id, "")),
                    [topic_lookup[idx] for idx in ordered_indices],
                    [merged[idx] for idx in ordered_indices],
                )
            except Exception:
                # Caching is an optimisation; the computed vectors are still
                # returned to the caller even if they could not be persisted.
                _logger.warning(
                    "Could not cache embeddings for clip %s", clip_id, exc_info=True
                )

    vectors = [vec for vec in embeddings if vec is not None]
    if len(vectors) != len(rows):
        # Returning a shorter list would silently misalign rows and vectors.
        raise RuntimeError("embedding backend returned no vector for some topics")
    return rows, vectors


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@router.get("/viz", response_class=HTMLResponse, include_in_schema=False)
def viz_page() -> HTMLResponse:
    """Serve the React/r3f single-page app at ``GET /viz``."""
    return HTMLResponse(_HTML_PATH.read_text(encoding="utf-8"))


@router.get("/viz/data", response_model=VizData)
def viz_data(
    distance_threshold: float = Query(default=0.35, ge=0.0, le=1.0),
) -> VizData:
    """Return the full :class:`VizData` payload for the ``/viz`` page.

    Pulls every clip + topic from SQLite, fills missing embeddings via Ollama
    (and caches them in Chroma), runs the 3-D UMAP projection, and assembles
    cluster labels and similarity edges using ``distance_threshold`` as the
    cluster cut.
    """
    settings = get_settings()
    db_path = settings.db_path.expanduser()
    db = Database(db_path)
    clips = db.list_clips()
    chroma = ChromaStore(default_chroma_dir(db_path))

    try:
        rows, embeddings = _collect_rows_embeddings(clips, chroma)
    except Exception as exc:
        # Embedding/collection failures are reported as "service unavailable".
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    if not rows:
        return VizData(
            points=[],
            clip_ids=[],
            clip_titles={},
            total_topics=0,
            total_clips=len(clips),
        )

    coords = umap_3d(embeddings)
    emb_matrix = np.array(embeddings, dtype=np.float64)
    cluster_labels = cluster_embeddings(emb_matrix, distance_threshold=distance_threshold)
    raw_edges = build_edges(emb_matrix)

    # First-seen order of clips that actually contributed a point.
    seen: dict[str, str] = {}
    for clip_id, clip_title, _, _ in rows:
        seen.setdefault(clip_id, clip_title)
    clip_ids = list(seen.keys())

    # Topic fields come from JSON, so a key may be present with a null value;
    # ``_none_to`` applies the default in that case as well as when it is absent.
    points = [
        TopicPoint(
            topic_title=t["title"],
            clip_id=clip_id,
            clip_title=clip_title,
            level=int(_none_to(t.get("level"), 1)),
            start_s=float(_none_to(t.get("start_s"), 0.0)),
            duration_s=float(_none_to(t.get("duration_s"), 0.0)),
            number=str(_none_to(t.get("number"), "")),
            summary=str(_none_to(t.get("summary"), "")),
            x=float(coords[i, 0]),
            y=float(coords[i, 1]),
            z=float(coords[i, 2]),
            cluster_id=int(cluster_labels[i]) if i < len(cluster_labels) else 0,
        )
        for i, (clip_id, clip_title, t, _topic_index) in enumerate(rows)
    ]

    edges = [Edge(a=a, b=b, similarity=s) for a, b, s in raw_edges]
    total_clusters = int(cluster_labels.max()) + 1 if len(cluster_labels) > 0 else 0

    return VizData(
        points=points,
        clip_ids=clip_ids,
        clip_titles=seen,
        total_topics=len(points),
        total_clips=len(clips),
        edges=edges,
        total_clusters=total_clusters,
    )


@router.get("/viz/search", response_model=list[SearchResult])
def viz_search(
    q: str = Query(..., min_length=1),
    top_k: int = Query(default=10, ge=1, le=100),
) -> list[SearchResult]:
    """Return the ``top_k`` topics whose embedding is closest to ``q``.

    Embeds the query with the same Ollama model used at ingest time and runs
    the search inside Chroma. Hits that don't have a corresponding point in
    the current viz dataset are skipped silently.
    """
    q = q.strip()
    if not q:
        # ``min_length=1`` still admits whitespace-only queries.
        return []

    settings = get_settings()
    db_path = settings.db_path.expanduser()
    db = Database(db_path)
    clips = db.list_clips()
    chroma = ChromaStore(default_chroma_dir(db_path))

    try:
        query_vec = Embedder().embed_texts([q])[0]
    except Exception as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    try:
        # Collect rows first: this also back-fills Chroma with any missing
        # topic vectors before the query runs.
        rows, _embeddings = _collect_rows_embeddings(clips, chroma)
        results = chroma.query(query_vec, top_k=top_k)
    except Exception as exc:
        raise HTTPException(status_code=503, detail=str(exc)) from exc

    if not rows:
        return []

    # (clip_id, topic_index) -> position of that point in the viz dataset.
    row_lookup = {(r[0], r[3]): i for i, r in enumerate(rows)}

    out: list[SearchResult] = []
    for r in results:
        key = (r["clip_id"], r["topic_index"])
        if key not in row_lookup:
            continue
        out.append(
            SearchResult(
                point_index=row_lookup[key],
                topic_title=str(r["title"]),
                clip_title=str(r["clip_title"]),
                clip_id=str(r["clip_id"]),
                similarity=float(r["similarity"]),
                summary=str(r["summary"]),
            )
        )
    return out