# Source code for autorag.persistence

"""Topic-tree → SQLite/Chroma persistence helpers.

Pure functions extracted from the CLI so the SDK's
:meth:`autorag.core.AutoRAG.persist_transcription` can reuse them.
"""

from __future__ import annotations

import json
import logging
import math
import uuid
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import Generator
    from datetime import datetime

    from autorag.db import Database
    from autorag.types import TopicDict, TopicTree, WordSpan

logger = logging.getLogger(__name__)


def derive_session_id(file_or_url: str | Path) -> str:
    """Return the ``session_id`` :meth:`AutoRAG.persist_transcription` would assign.

    Kept in lockstep with the inline derivation inside
    :meth:`AutoRAG.persist_transcription`:

    - a string recognised as a YouTube URL hashes its canonical form:
      ``uuid5(NAMESPACE_URL, _canonical_youtube_url(url))``
    - anything else is treated as a filesystem path and hashes its resolved
      absolute form: ``uuid5(NAMESPACE_URL, str(Path(x).resolve()))``

    The only import is ``autorag.audio_source`` (base-safe; ``yt_dlp`` stays
    behind its own lazy import), so this is callable without the
    ``[audio]``/``[rag]`` extras.

    Args:
        file_or_url: A YouTube URL string, or a local path (``str`` or ``Path``).

    Returns:
        The UUID5 as its canonical string form.
    """
    from autorag.audio_source import _canonical_youtube_url, is_youtube_url

    # Only plain strings can be URLs; Path objects always take the path branch.
    if isinstance(file_or_url, str) and is_youtube_url(file_or_url):
        identity = _canonical_youtube_url(file_or_url)
    else:
        identity = str(Path(file_or_url).resolve())
    return str(uuid.uuid5(uuid.NAMESPACE_URL, identity))
def load_transcription(db: Database, session_id: str) -> list[WordSpan] | None:
    """Return the stored word list for ``session_id``.

    Parses the JSON string written by :meth:`Database.store_transcription`.
    Reads via raw ``sqlite_utils`` (matching :meth:`Database.list_clips`) so a
    freshly-opened :class:`Database` instance can read rows it didn't write —
    ``pydantic_sqlite``'s model registry is in-memory only.

    Args:
        db: Database wrapper; its ``sqlite_utils`` handle is reached through
            ``db.db._db``.
        session_id: Value of the ``id`` column in ``audio_clips``.

    Returns:
        The decoded word list, or ``None`` when the ``audio_clips`` table or
        the row is missing, the lookup raises, or the ``transcription`` column
        holds no data (``NULL``, a blank string, or JSON ``null``).

    Raises:
        json.JSONDecodeError: if the column holds non-blank text that is not
            valid JSON.
    """
    inner = db.db._db  # pyright: ignore[reportPrivateUsage]
    try:
        if "audio_clips" not in inner.table_names():
            return None
        rows = list(inner["audio_clips"].rows_where("id = ?", [session_id]))
    except Exception as exc:
        # Stay best-effort, but don't let a locked/corrupt DB look like a
        # silently missing row.
        logger.warning(
            "load_transcription: reading audio_clips row %r failed: %s", session_id, exc
        )
        return None
    if not rows:
        return None
    raw = rows[0].get("transcription")
    # A blank string would make json.loads raise; treat it as "no transcription".
    if raw is None or (isinstance(raw, str) and not raw.strip()):
        return None
    decoded = json.loads(raw)
    # JSON "null" decodes to None; list(None) would raise TypeError.
    if decoded is None:
        return None
    return list(decoded)
def load_clip(db: Database, session_id: str) -> dict[str, Any] | None:
    """Return the ``audio_clips`` row for ``session_id`` as a plain dict.

    The cross-process analogue of :meth:`Database.get_clip`: that method goes
    through ``pydantic_sqlite.model_from_table``, whose model registry is
    in-memory only, so a process that did not *write* the row (e.g. the API
    reading a row a worker wrote) cannot see it. This reads via raw
    ``sqlite_utils`` exactly like :func:`load_transcription` /
    :meth:`Database.list_clips`.

    Args:
        db: Database wrapper; its ``sqlite_utils`` handle is reached through
            ``db.db._db``.
        session_id: Value of the ``id`` column in ``audio_clips``.

    Returns:
        A fresh ``dict`` copy of the first matching row, or ``None`` when the
        table or row is missing or the lookup raises (the failure is logged).
    """
    inner = db.db._db  # pyright: ignore[reportPrivateUsage]
    try:
        if "audio_clips" not in inner.table_names():
            return None
        rows = list(inner["audio_clips"].rows_where("id = ?", [session_id]))
    except Exception as exc:
        # Stay best-effort, but surface DB errors instead of reporting them
        # as an absent clip.
        logger.warning("load_clip: reading audio_clips row %r failed: %s", session_id, exc)
        return None
    return dict(rows[0]) if rows else None
def collapse_lone_children(tree: TopicTree) -> TopicTree:
    """Drop single-child chains so a subtopic level only exists with >=2 siblings.

    For each node, while it has exactly one child, that child is discarded and
    its own children are taken instead. The surviving child list is then
    collapsed recursively. A pure chain (one child per level) therefore leaves
    the top node with an empty ``children`` list.

    The input tree is not modified. Each node in the result is a shallow copy
    of the corresponding input node with a fresh ``children`` list, which is
    always present and possibly empty.

    Args:
        tree: Topic tree whose ``topics`` list holds the root nodes;
            ``children`` may be missing or ``None`` on any node.

    Returns:
        A new ``{"topics": [...]}`` tree.
    """

    def walk(nodes: list[TopicDict]) -> list[TopicDict]:
        out: list[TopicDict] = []
        for node in nodes:
            children = list(node.get("children") or [])
            # Skip lone intermediates; their own children move up a level.
            while len(children) == 1:
                children = list(children[0].get("children") or [])
            # Copy instead of mutating so the caller's tree stays untouched.
            collapsed = node.copy()
            collapsed["children"] = walk(children)
            out.append(collapsed)
        return out

    return {"topics": walk(tree.get("topics") or [])}
def iter_topics_flat(tree: TopicTree) -> Generator[tuple[int, TopicDict, str], None, None]:
    """Depth-first walk yielding ``(level, node, number_label)`` for titled nodes.

    Labels are dotted outline numbers such as ``'1'``, ``'1.2'``, ``'1.2.3'``.
    A node whose title is empty or whitespace-only is skipped, does not
    consume a number, and its subtree is not visited. Descent stops at
    level 3, so deeper nodes are never yielded.
    """

    def descend(
        siblings: list[TopicDict], depth: int, prefix: str
    ) -> Generator[tuple[int, TopicDict, str], None, None]:
        titled = (item for item in siblings if str(item.get("title", "") or "").strip())
        for position, item in enumerate(titled, start=1):
            label = f"{prefix}.{position}" if prefix else str(position)
            yield depth, item, label
            kids = item.get("children") or []
            if kids and depth < 3:
                yield from descend(kids, depth + 1, label)

    yield from descend(tree.get("topics") or [], 1, "")
def _finite_seconds(raw: Any) -> float | None:
    """Coerce a topic timestamp to a finite ``float``, or ``None`` if impossible.

    Returns ``None`` when *raw* is ``None``, cannot be converted by
    :func:`float`, or converts to NaN / ±infinity.
    """
    if raw is None:
        return None
    try:
        value = float(raw)
    except (TypeError, ValueError):
        return None
    return value if math.isfinite(value) else None


def topics_to_events(
    db: Database,
    session_id: str,
    tree: TopicTree,
    *,
    audio_start: datetime,
    provider: str,
    llm_model: str,
    topic_category_ids: tuple[str, str, str],
) -> list[dict[str, Any]]:
    """Walk the topic tree and produce analytics events for each titled node.

    Reads the hierarchical-agent's ``s`` / ``e`` keys (not ``start_s`` / ``end_s``).

    Args:
        db: Database wrapper used for :meth:`Database.add_analytics_event`.
        session_id: Session the events are attached to.
        tree: Topic tree walked with :func:`iter_topics_flat` (levels 1-3).
        audio_start: Timestamp of second 0 of the audio; each event is marked
            at ``audio_start + max(0, s)`` seconds.
        provider: Stored under ``metadata["transcription"]["provider"]``.
        llm_model: Stored under ``metadata["transcription"]["model"]``.
        topic_category_ids: Category ids for level 1, 2 and 3 topics.

    Returns:
        The values returned by ``db.add_analytics_event``, in tree order.
        Topics whose start time cannot be represented as a datetime, or whose
        insert raises, are logged and skipped.
    """
    cat_by_level = {1: topic_category_ids[0], 2: topic_category_ids[1], 3: topic_category_ids[2]}
    events: list[dict[str, Any]] = []
    for level, node, number_label in iter_topics_flat(tree):
        title = str(node.get("title", "") or "").strip()
        if not title:
            continue
        category_id = cat_by_level.get(level)
        if not category_id:
            continue

        # Missing, non-numeric, NaN or infinite starts fall back to 0.0. Non-finite
        # values would otherwise overflow timedelta (inf) or leak into the
        # metadata as non-standard JSON (NaN / Infinity).
        start = _finite_seconds(node.get("s"))
        start_s = 0.0 if start is None else start
        word_end_s = _finite_seconds(node.get("e"))
        summary = str(node.get("summary", "") or "").strip()

        try:
            marked_at = audio_start + timedelta(seconds=max(0.0, start_s))
        except OverflowError:
            # A huge finite start still can't be placed on the calendar; skip
            # this topic rather than aborting the whole walk.
            logger.warning(
                "topic %r (level=%d) start %.3fs is out of datetime range; skipping",
                title,
                level,
                start_s,
            )
            continue

        metadata: dict[str, Any] = {
            "transcription": {
                "level": level,
                "provider": provider,
                "model": llm_model,
                "number_label": number_label,
                "word_start_s": start_s,
                "word_end_s": word_end_s,
                "summary": summary,
            }
        }
        try:
            event = db.add_analytics_event(
                session_id,
                category=category_id,
                message=title,
                metadata=metadata,
                marked_at_utc=marked_at,
            )
        except Exception as exc:
            # Best-effort per topic: one failed insert must not drop the rest.
            logger.warning(
                "add_analytics_event failed for topic %r (level=%d): %s", title, level, exc
            )
            continue
        events.append(event)
    return events