from __future__ import annotations
import json
import logging
import time
import uuid
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
from autorag.blocks import format_blocks
from autorag.config import Settings, get_settings
from autorag.embed import Embedder
from autorag.errors import MissingExtraError, _missing_extra
from autorag.generate import Generator
from autorag.ingest import chunk_document, load_documents
from autorag.retrieve import Retriever
from autorag.schemas import IngestResponse, QueryResponse
from autorag.store import InMemoryStore, VectorStore
if TYPE_CHECKING:
from langchain_core.runnables import Runnable
from autorag.schemas import Chunk
from autorag.types import TopicTree, TranscriptionResult, WordSpan
logger = logging.getLogger(__name__)
__all__ = ["AutoRAG", "MissingExtraError"]
[docs]
class AutoRAG:
"""Unified facade for the audio→topics agent and the document-RAG pipeline.
Heavy dependencies (whisper, torch, pyannote, chromadb, ...) are loaded
lazily on first use, so a base install can import :class:`AutoRAG` without
pulling them. Methods raise :class:`MissingExtraError` with the specific
extras hint when an extra is missing.
"""
def __init__(
    self,
    settings: Settings | None = None,
    store: VectorStore | None = None,
    embedder: Embedder | None = None,
    generator: Generator | None = None,
) -> None:
    """Wire up the RAG collaborators, building a default for each one omitted.

    Args:
        settings: Configuration; defaults to :func:`get_settings`.
        store: Vector store; defaults to a fresh :class:`InMemoryStore`.
        embedder: Embedder; defaults to ``Embedder()``.
        generator: Generator; defaults to ``Generator(model=settings.model)``.

    A :class:`Retriever` is always built from the resolved store and embedder.
    """
    # Compare against None rather than relying on truthiness: a caller-supplied
    # object that happens to be falsy (for example an empty store that defines
    # ``__len__``) must be kept, not silently swapped for a default instance.
    self.settings = settings if settings is not None else get_settings()
    self.store = store if store is not None else InMemoryStore()
    self.embedder = embedder if embedder is not None else Embedder()
    self.generator = (
        generator if generator is not None else Generator(model=self.settings.model)
    )
    self.retriever = Retriever(self.store, self.embedder)
# ── Document RAG ──────────────────────────────────────────────────────
[docs]
def ingest(self, paths: list[str | Path]) -> IngestResponse:
    """Load, chunk, embed and store the documents found at ``paths``.

    Every loaded document is split with the configured ``chunk_size`` and
    ``chunk_overlap``; the resulting chunks are embedded and added to the
    store in a single batch.

    Returns an :class:`IngestResponse` with the number of documents loaded
    and the total number of chunks produced.
    """
    documents = load_documents(paths)
    chunks: list[Chunk] = [
        piece
        for document in documents
        for piece in chunk_document(
            document,
            chunk_size=self.settings.chunk_size,
            chunk_overlap=self.settings.chunk_overlap,
        )
    ]
    self.embedder.embed_chunks(chunks)
    self.store.add(chunks)
    return IngestResponse(ingested=len(documents), chunks=len(chunks))
[docs]
def query(self, question: str, top_k: int | None = None) -> QueryResponse:
    """Answer ``question`` from the most relevant stored chunks.

    ``top_k`` limits how many chunks are retrieved; a falsy value (``None``
    or ``0``) falls back to ``settings.top_k``. The retrieved chunks are
    returned as ``sources`` alongside the generated answer.
    """
    limit = top_k if top_k else self.settings.top_k
    sources = self.retriever.retrieve(question, top_k=limit)
    return QueryResponse(
        answer=self.generator.generate(question, sources),
        sources=sources,
    )
# ── Audio → topics ────────────────────────────────────────────────────
def _resolve_clip_identity(
    self,
    file: Path | str,
    source_url: str | None,
    upload_date: str | None,
) -> tuple[str, datetime, str]:
    """Derive ``(session_id, audio_start, stored_file_path)`` for a clip.

    * ``session_id`` is a UUIDv5 (URL namespace) of the canonical YouTube
      URL, of ``source_url`` as given for other URLs, or of the resolved
      local path when no ``source_url`` is supplied.
    * ``audio_start`` is ``upload_date`` (``"YYYYMMDD"``) interpreted as
      UTC when provided; otherwise the file's mtime, or the current time
      if the file cannot be stat'ed.
    * ``stored_file_path`` is the canonical URL when there is a non-empty
      one, else the resolved local path.
    """
    from autorag.audio_source import _canonical_youtube_url, is_youtube_url

    local_path = Path(file)

    # Pick the string that seeds the deterministic session id.
    canonical_url: str | None
    if source_url is None:
        canonical_url = None
        seed = str(local_path.resolve())
    else:
        canonical_url = (
            _canonical_youtube_url(source_url)
            if is_youtube_url(source_url)
            else source_url
        )
        seed = canonical_url
    session_id = str(uuid.uuid5(uuid.NAMESPACE_URL, seed))

    # Anchor the clip in time: publish date first, file mtime second, now last.
    audio_start: datetime
    if upload_date:
        audio_start = datetime.strptime(upload_date, "%Y%m%d").replace(tzinfo=UTC)
    else:
        try:
            audio_start = datetime.fromtimestamp(local_path.stat().st_mtime, tz=UTC)
        except OSError:
            audio_start = datetime.now(tz=UTC)

    # An empty canonical URL is treated like a missing one.
    stored_file_path = canonical_url or str(local_path.resolve())
    return session_id, audio_start, stored_file_path
[docs]
def transcribe(
    self,
    file: Path | str,
    *,
    whisper_model: str = "base",
    language: str | None = "en",
) -> list[WordSpan]:
    """Transcribe and diarize an audio file or YouTube URL into word spans.

    ``file`` may be a local audio path or a YouTube URL (``youtube.com``,
    ``youtu.be``, ``m.youtube.com``, ``music.youtube.com``). A URL is
    downloaded to a temporary ``.webm`` that lives only for this call.

    Only the raw word spans are returned. Feed them to
    :meth:`generate_topics` for the LLM topic tree, and to
    :meth:`persist_transcription` / :meth:`persist_topics` to store results
    (those need the separate ``[rag]`` extra).

    ``language`` is English (``"en"``) by default; ``language=None`` lets
    Whisper auto-detect.

    Requires ``pip install 'autorag[audio,diarize]'``, plus ``[youtube]``
    when a URL is passed.

    Raises:
        MissingExtraError: when the audio/diarization modules cannot be
            imported.
    """
    # Imported lazily so a base install can still import the facade.
    try:
        from autorag.agent import transcribe_audio
        from autorag.audio_source import resolve_audio_input
    except ModuleNotFoundError as err:
        raise _missing_extra("audio,diarize", err) from err

    with resolve_audio_input(file) as source:
        return transcribe_audio(
            source.path,
            whisper_model=whisper_model,
            language=language,
        )
[docs]
def generate_topics(
    self,
    words: list[WordSpan],
    *,
    llm_model: str = "gemma4:latest",
    ollama_base_url: str | None = None,
    num_ctx_l1: int = 8192,
    num_ctx_fanout: int = 8192,
    max_concurrency: int = 8,
    min_subdivide_duration_s: float = 120.0,
    reasoning: bool = False,
    boundary_block_seconds: int = 30,
) -> TopicTree:
    """Extract a topic tree from already-transcribed word spans via the LLM.

    All keyword options are forwarded unchanged to
    :func:`autorag.agent.generate_topics`; the tree it returns is then
    passed through ``collapse_lone_children`` before being handed back.

    Requires ``pip install 'autorag[audio,diarize]'`` (LangChain + Ollama).

    Raises:
        MissingExtraError: when the agent/persistence modules cannot be
            imported.
    """
    try:
        from autorag.agent import generate_topics as extract_topics
        from autorag.persistence import collapse_lone_children
    except ModuleNotFoundError as err:
        raise _missing_extra("audio,diarize", err) from err

    options: dict[str, Any] = {
        "llm_model": llm_model,
        "ollama_base_url": ollama_base_url,
        "num_ctx_l1": num_ctx_l1,
        "num_ctx_fanout": num_ctx_fanout,
        "max_concurrency": max_concurrency,
        "min_subdivide_duration_s": min_subdivide_duration_s,
        "reasoning": reasoning,
        "boundary_block_seconds": boundary_block_seconds,
    }
    return collapse_lone_children(extract_topics(words, **options))
[docs]
def build_agent(self, **kwargs: Any) -> Runnable[Path | str, TranscriptionResult]:
    """Build the LangChain :class:`Runnable` for batched / streaming use.

    Needs the same extras as :meth:`transcribe`. Every keyword argument is
    passed straight through to :func:`autorag.agent.build_agent`.

    Raises:
        MissingExtraError: when ``autorag.agent`` cannot be imported.
    """
    try:
        from autorag.agent import build_agent as make_agent
    except ModuleNotFoundError as err:
        raise _missing_extra("audio,diarize", err) from err

    return make_agent(**kwargs)
[docs]
def transcribe_blocks(
    self,
    file: Path | str,
    seconds: int = 10,
    *,
    force_retranscribe: bool = False,
    db_path: Path | None = None,
    whisper_model: str = "base",
    language: str | None = "en",
    title: str | None = None,
) -> str:
    """Format a clip's transcription as ``seconds``-long time blocks.

    How the transcription is obtained:

    1. ``session_id = derive_session_id(file)``.
    2. Unless ``force_retranscribe`` is set, SQLite is consulted first; a
       stored transcription for ``session_id`` is formatted and returned
       right away (no ``[audio]`` extra needed).
    3. Otherwise :meth:`transcribe` runs, the words are saved with
       :meth:`persist_transcription`, and the fresh words are formatted.
       Topics are not generated here; call :meth:`generate_topics` and
       :meth:`persist_topics` separately.

    Each non-empty bucket yields one line per speaker turn,
    ``MM:SS-MM:SS Speaker K: <words>``; see
    :func:`autorag.blocks.format_blocks` for the full algorithm.

    Requires ``pip install 'autorag[rag]'`` for the cached path;
    ``[audio,diarize]`` (+ ``[youtube]`` for URLs) on a cache miss.

    Raises:
        ValueError: if ``seconds`` is not positive.
        MissingExtraError: when a required extra is not installed.
    """
    if seconds <= 0:
        raise ValueError("seconds must be a positive integer")

    try:
        from autorag.audio_source import default_title_from
        from autorag.db import Database
        from autorag.persistence import derive_session_id, load_transcription
    except ModuleNotFoundError as err:
        raise _missing_extra("rag", err) from err

    session_id = derive_session_id(file)
    database = Database((db_path or self.settings.db_path).expanduser())

    # Cache hit: serve the stored transcription without touching audio code.
    stored_words = (
        None if force_retranscribe else load_transcription(database, session_id)
    )
    if stored_words is not None:
        return format_blocks(stored_words, seconds)

    try:
        from autorag.audio_source import resolve_audio_input
    except ModuleNotFoundError as err:
        raise _missing_extra("audio,diarize", err) from err

    source_label = str(file) if not isinstance(file, str) else file
    with resolve_audio_input(file) as src:
        fresh_words = self.transcribe(
            src.path,
            whisper_model=whisper_model,
            language=language,
        )
        # Explicit title wins, then the source's own title, then a derived one.
        clip_title = title or src.title or default_title_from(source_label)
        self.persist_transcription(
            src.path,
            fresh_words,
            title=clip_title,
            db_path=db_path,
            source_url=src.source_url,
            upload_date=src.upload_date,
            duration_s=src.duration_s,
        )
    return format_blocks(fresh_words, seconds)
[docs]
def persist_transcription(
    self,
    file: Path | str,
    words: list[WordSpan],
    *,
    title: str | None = None,
    db_path: Path | None = None,
    source_url: str | None = None,
    upload_date: str | None = None,
    duration_s: float | None = None,
) -> dict[str, Any]:
    """Store word spans in SQLite (clip row + words).

    Requires ``pip install 'autorag[rag]'`` (pydantic_sqlite).

    Args:
        file: Local audio file; must exist.
        words: Word spans to store.
        title: Clip title; defaults to the file's stem.
        db_path: Database location; defaults to ``settings.db_path``.
        source_url: Optional; seeds ``session_id`` from the canonical URL so
            re-fetching the same URL overwrites the existing row.
        upload_date: Optional ``"YYYYMMDD"`` (from yt-dlp); anchors
            ``created_at`` to the video's publish date.
        duration_s: Informational only; not persisted.

    Returns:
        A dict with ``"clip"`` (the stored clip row), ``"session_id"`` and
        ``"timings"`` (``{"store_words": seconds}``).

    Raises:
        MissingExtraError: when the ``[rag]`` extra is not installed.
        FileNotFoundError: when ``file`` is not an existing file.

    Use :meth:`persist_topics` to store the topic tree and embed titles.
    """
    del duration_s  # informational; no schema column for it yet

    try:
        from autorag.db import Database
    except ModuleNotFoundError as err:
        raise _missing_extra("rag", err) from err

    audio_path = Path(file)
    if not audio_path.is_file():
        raise FileNotFoundError(f"{audio_path} is not a file.")

    database = Database((db_path or self.settings.db_path).expanduser())
    session_id, audio_start, stored_file_path = self._resolve_clip_identity(
        file, source_url, upload_date
    )

    database.create_clip(
        session_id,
        title=title or audio_path.stem,
        file_path=stored_file_path,
        # ISO-8601 with a trailing "Z" instead of the "+00:00" offset.
        created_at=audio_start.isoformat().replace("+00:00", "Z"),
    )

    started = time.perf_counter()
    database.store_transcription(session_id, words)  # type: ignore[arg-type]
    elapsed = time.perf_counter() - started

    return {
        "clip": database.get_clip(session_id),
        "session_id": session_id,
        "timings": {"store_words": elapsed},
    }
[docs]
def persist_topics(
    self,
    file: Path | str,
    topics: TopicTree,
    *,
    words: list[WordSpan] | None = None,
    transcript_end_s: float | None = None,
    title: str | None = None,
    provider: str = "ollama",
    llm_model: str = "gemma4:latest",
    whisper_model: str = "base",
    db_path: Path | None = None,
    source_url: str | None = None,
    upload_date: str | None = None,
    duration_s: float | None = None,
) -> dict[str, Any]:
    """Store topic tree to SQLite and embed topic titles into Chroma.

    Requires ``pip install 'autorag[rag]'`` (chromadb + pydantic_sqlite).
    Call :meth:`persist_transcription` first to create the clip row;
    this method will create it idempotently if needed.

    Args:
        file: Local audio file, or a YouTube URL.
        topics: Topic tree to convert into events and store.
        words: Optional word spans, used only to derive the end time.
        transcript_end_s: Audio end time in seconds used to anchor events.
            When omitted it is read from ``words[-1]["e"]`` if ``words`` is
            supplied, else ``0.0``.
        title: Clip title; defaults to the stem of ``file``.
        provider, llm_model, whisper_model: Recorded with the stored topics.
        db_path: Database location; defaults to ``settings.db_path``.
        source_url: Canonical source of the clip. When omitted and ``file``
            is itself a YouTube URL, ``file`` is used, so the session id is
            derived from the URL rather than from a filesystem resolution
            of the URL string.
        upload_date: Optional ``"YYYYMMDD"`` publish date.
        duration_s: Informational only; not persisted.

    Returns:
        A dict with ``"clip"``, ``"session_id"`` and ``"timings"``
        (``{"finalize": seconds, "embed": seconds}``).

    Raises:
        MissingExtraError: when the ``[rag]`` extra is not installed.
        FileNotFoundError: when ``file`` is neither a YouTube URL nor an
            existing file.
    """
    del duration_s  # informational; no schema column for it yet
    try:
        from autorag.audio_source import is_youtube_url
        from autorag.chroma_store import ChromaStore, default_chroma_dir
        from autorag.db import Database
        from autorag.persistence import topics_to_events
    except ModuleNotFoundError as exc:
        raise _missing_extra("rag", exc) from exc

    path = Path(file)
    file_is_url = is_youtube_url(str(file))
    if not file_is_url and not path.is_file():
        raise FileNotFoundError(f"{path} is not a file.")

    # A URL passed as ``file`` without ``source_url`` would otherwise be
    # resolved as a relative filesystem path, giving a cwd-dependent
    # session id that may not match the row written for the same URL.
    if source_url is None and file_is_url:
        source_url = str(file)

    resolved_db = (db_path or self.settings.db_path).expanduser()
    db = Database(resolved_db)
    session_id, audio_start, stored_file_path = self._resolve_clip_identity(
        file, source_url, upload_date
    )
    clip_title = title or path.stem
    created_at = audio_start.isoformat().replace("+00:00", "Z")
    db.create_clip(
        session_id,
        title=clip_title,
        file_path=stored_file_path,
        created_at=created_at,
    )

    # Explicit end time wins; otherwise take it from the last word span.
    if transcript_end_s is not None:
        end_s = transcript_end_s
    elif words:
        end_s = words[-1].get("e", 0.0)
    else:
        end_s = 0.0

    t = time.perf_counter()
    pending_events = topics_to_events(
        db,
        session_id,
        topics,
        audio_start=audio_start,
        provider=provider,
        llm_model=llm_model,
        topic_category_ids=("l1", "l2", "l3"),
    )
    db.finalize_topics(
        session_id,
        end_s,
        events=pending_events,
        provider=provider,
        llm_model=llm_model,
        whisper_model=whisper_model,
    )
    finalize_s = time.perf_counter() - t

    t = time.perf_counter()
    clip_data = db.get_clip(session_id)
    if clip_data and clip_data.get("topics"):
        # Only topics that carry a title are embedded.
        topic_list = [tp for tp in json.loads(clip_data["topics"]) if tp.get("title")]
        texts = [
            f"{tp['title']}. {tp['summary']}" if tp.get("summary") else tp["title"]
            for tp in topic_list
        ]
        if texts:
            # Indexing is best-effort: the topics are already in SQLite, so a
            # failure here is logged (with traceback) instead of raised.
            try:
                embeddings = self.embedder.embed_texts(texts)
                chroma = ChromaStore(default_chroma_dir(resolved_db))
                # Drop any previous vectors for this clip before re-adding.
                chroma.delete_clip(session_id)
                chroma.add_topic_embeddings(
                    session_id,
                    str(clip_data.get("title", "")),
                    topic_list,
                    embeddings,
                )
            except Exception as exc:
                logger.warning("embedding/index failed: %s", exc, exc_info=True)
    embed_s = time.perf_counter() - t

    clip = db.get_clip(session_id)
    return {
        "clip": clip,
        "session_id": session_id,
        "timings": {"finalize": finalize_s, "embed": embed_s},
    }