Source code for autorag.agent

"""Audio → hierarchical topic tree. The single agent for AutoRAG.

Multi-pass L0 / L1 / L2 extractor — each LLM stage has one focused job::

    1. Whisper                              -> list[WordSpan]               1 call
    2. L1 boundaries  (single LLM call)     -> list[{s,e}]                  1 LLM
    3a Decide subdivide  (per long L1)      -> list[bool]                   N LLM
    3b L2 boundaries  (per yes-L1, batched) -> list[list[{s,e}]]            M LLM (M<=N)
    4. Summarize nodes  (per L1+L2, batched)-> {title,summary} per node     K LLM
    5. L0 aggregate                         -> {title, summary}             1 LLM

Final shape: ``{"topics": [L0]}`` with ``L0.children = [L1...]``, each
``L1.children = [L2...]`` or ``[]``. The L0 root is the explicit "what is
this audio about" node.

Boundary calls receive a time-bucketed (``format_blocks``,
``boundary_block_seconds``, default 30s) transcript and
emit ``{s, e}`` as ``MM:SS`` strings, which we parse back to float seconds here
(never the LLM — no model-side arithmetic). Per-node summary calls operate on
the slice's plain text (no timestamps) and emit ``{title, summary}``. The
K=N1+N2 summary calls share an identical prompt prefix for cache reuse.
"""

from __future__ import annotations

import logging
import os
import time
from itertools import pairwise
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_ollama import ChatOllama
from pydantic import BaseModel

from autorag import diarize, whisper_runner
from autorag.blocks import format_blocks, mmss
from autorag.blocks import group_by_speaker as _group_by_speaker
from autorag.otel import bind_current_context, get_tracer
from autorag.otel_callbacks import OTelSpanCallbackHandler
from autorag.types import TopicDict, TopicTree, TranscriptionResult, WordSpan

if TYPE_CHECKING:
    from collections.abc import Callable

logger = logging.getLogger(__name__)

# Default window size for the block-formatted transcript fed to the L1/L2
# boundary prompts (overridable via the `boundary_block_seconds` kwarg). Smaller
# = more frequent MM:SS anchors (finer possible boundaries) at the cost of more
# lines; 30s gives a fresh anchor at least twice a minute even inside a long
# single-speaker monologue.
_BOUNDARY_BLOCK_SECONDS = 30


class _Boundary(BaseModel):
    # `str`, not `float`, on purpose: the structured-output JSON schema is
    # derived from these annotations, so a `float` here would instruct the
    # model to emit numbers. We want it to copy the transcript's `MM:SS`
    # markers verbatim; `_parse_ts` converts them to seconds before tiling.
    s: str
    e: str


class _BoundaryList(BaseModel):
    topics: list[_Boundary]


class _SubdivideDecision(BaseModel):
    # `reason` is placed BEFORE `subdivide` on purpose: small structured-output
    # models produce more accurate booleans when they emit a short rationale
    # first. The reason is parsed but discarded by the orchestrator.
    reason: str
    subdivide: bool


class _NodeSummary(BaseModel):
    title: str
    summary: str


class _L0Summary(BaseModel):
    title: str
    summary: str


_L1_SYS = (
    "You are a topic boundary detector. You receive a recording's "
    "transcript as time-bucketed blocks: blocks are separated by blank "
    "lines and every line is `MM:SS-MM:SS Speaker K: <words>` (the two "
    "MM:SS values are that turn's start and end). You must split the "
    "recording into ordered, non-overlapping top-level (L1) topics that "
    "TILE the audio from start to end.\n\n"
    "Speaker changes are useful evidence for topic boundaries, but a "
    "single topic may span multiple speakers.\n\n"
    "Rules:\n"
    "1. Return ONLY intervals -- no titles, no summaries. Each item is "
    '{{"s": "MM:SS", "e": "MM:SS"}}.\n'
    "2. The first topic's `s` equals the first MM:SS in the transcript; "
    "the last topic's `e` equals the last MM:SS in the transcript.\n"
    "3. Adjacent topics tile end-to-start: for siblings A then B, set "
    "B.s = A.e (no gaps, no overlap). Order siblings by time.\n"
    "4. Aim for roughly the suggested topic count -- it is calibrated to "
    "duration. Do NOT over-split into 15+ tiny topics; do NOT collapse "
    "into a single topic unless the audio is very short.\n"
    "5. Topics typically span tens to hundreds of seconds, not single "
    "lines.\n"
    "6. Copy MM:SS values directly from the transcript's range markers. "
    "Do not invent or reformat timestamps."
)
_L1_HUMAN = (
    "Audio runs from 00:00 to {audio_e} (~{duration_min:.1f} min). "
    "Suggested topic count: ~{target_count}. Spread topics across the "
    "FULL duration; do NOT cluster them near the start.\n\n"
    "Time-bucketed transcript (blocks separated by blank lines; each "
    "line is `MM:SS-MM:SS Speaker K: <words>`):\n{transcript}"
)

_DECIDE_SYS = (
    "You decide whether a passage of speech is substantial enough to be "
    "broken into 2 or more distinct subtopics, or whether it covers a "
    "single coherent point that should NOT be subdivided.\n\n"
    "Text may include `Speaker N:` prefixes when multiple speakers are "
    "present. Consider all speakers together when deciding; speaker "
    "turns alone are not subtopics.\n\n"
    "Rules:\n"
    "1. Set subdivide=true ONLY if you can identify at least 2 distinct, "
    "well-bounded subtopics inside the passage. Each subtopic must cover "
    "a meaningful span of speech (tens of seconds, not a few words).\n"
    "2. Set subdivide=false when the passage is on a single subject, when "
    "it is short, or when any split would be artificial.\n"
    "3. The `reason` field should be one short sentence describing why."
)
_DECIDE_HUMAN = (
    "Passage runs ~{duration_min:.1f} minutes.\n\nTranscript (plain text):\n{transcript}"
)

_L2_SYS = (
    "You are a topic boundary detector. You receive a SLICE of a longer "
    "recording's transcript as time-bucketed blocks: blocks are "
    "separated by blank lines and every line is "
    "`MM:SS-MM:SS Speaker K: <words>` (the two MM:SS values are that "
    "turn's start and end). You must split the slice into ordered, "
    "non-overlapping subtopics that TILE the slice from start to end.\n\n"
    "Speaker changes are useful evidence for subtopic boundaries, but a "
    "single subtopic may span multiple speakers.\n\n"
    "Rules:\n"
    "1. Return ONLY intervals -- no titles, no summaries. Each item is "
    '{{"s": "MM:SS", "e": "MM:SS"}}.\n'
    "2. The first subtopic's `s` equals the slice start; the last "
    "subtopic's `e` equals the slice end.\n"
    "3. Adjacent subtopics tile end-to-start: for siblings A then B, set "
    "B.s = A.e (no gaps, no overlap). Order by time.\n"
    "4. Copy MM:SS values directly from the transcript's range markers. "
    "Do not invent or reformat timestamps."
)
_L2_HUMAN = (
    "Slice spans [{slice_s} to {slice_e}] (~{duration_min:.1f} min). "
    "Suggested subtopic count: ~{target_count}. Produce at least 2 "
    "subtopics that together tile the slice; if you genuinely cannot "
    "find 2 distinct subjects, return exactly 2 anyway by splitting on "
    "the clearest natural break.\n\n"
    "Time-bucketed slice transcript (blocks separated by blank lines; "
    "each line is `MM:SS-MM:SS Speaker K: <words>`):\n{transcript}"
)

_NODE_SUM_SYS = (
    "You summarize a passage of transcribed speech. Given the passage "
    "text, return a short title and a 1-2 sentence summary describing "
    "what was said.\n\n"
    "Text may include `Speaker N:` prefixes when multiple speakers are "
    "present. Consider all speakers together; the summary should "
    "describe the passage's content, mentioning who said what only "
    "when it materially aids understanding.\n\n"
    "Rules:\n"
    "1. `title` is a noun phrase, at most 120 characters. No trailing "
    "punctuation. Not a full sentence.\n"
    "2. `summary` is 1-2 sentences describing the passage's content.\n"
    "3. Do not invent content beyond what the passage says. Do not "
    "speculate about surrounding context."
)
_NODE_SUM_HUMAN = "Passage:\n{text}"

_AGG_SYS = (
    "You are summarizing a whole audio from its top-level topics. Given "
    "the topics' titles and summaries, produce a single overall title "
    "(<=120 chars) and a 2-4 sentence summary capturing the unifying "
    "theme. Do not invent content beyond what the topics describe."
)
_AGG_HUMAN = "Top-level topics:\n{children}"


def _ollama_base_url() -> str:
    """Resolve the Ollama base URL from env, falling back to localhost."""
    raw = os.environ.get("AUTORAG_OLLAMA_BASE_URL", "").strip()
    return raw or "http://localhost:11434"


def _warm_pyannote() -> None:
    """Restore the pyannote pipeline onto CUDA, populating the module
    cache that :func:`diarize.diarize_file` reads on its hot path.

    A best-effort companion to the wav2vec2 align warm-up — exists as a
    separate helper purely so it can be ``ex.submit``-ed alongside
    :func:`whisper_runner._get_align_model` on the warm-up pool.
    """
    pipeline = diarize.get_pipeline()
    if pipeline is not None:
        diarize._ensure_pipeline_on_cuda(pipeline)


def _run_whisper(file: Path, *, model_size: str, language: str | None) -> list[WordSpan]:
    if not file.exists():
        raise FileNotFoundError(f"audio file not found: {file}")
    from concurrent.futures import ThreadPoolExecutor

    tracer = get_tracer("autorag.agent")

    # ``cache.hit`` lets a reader confirm at a glance that the CTranslate2
    # rebuild is gone on jobs >= 2 — true on hit, false on miss + load.
    cache_key = (model_size, "cuda")
    with whisper_runner._MODEL_LOCK:
        cache_hit = cache_key in whisper_runner._MODEL_CACHE
    with tracer.start_as_current_span(
        "autorag.whisper.get_model",
        attributes={
            "model.size": model_size,
            "model.device_hint": "cuda",
            "cache.hit": cache_hit,
        },
    ):
        model = whisper_runner.get_model(model_size, device_hint="cuda")

    # Overlap the wav2vec2-align + pyannote CPU→CUDA restores with the
    # CT2 transcribe (the longest leg). Both populate module-level
    # caches that ``transcribe_segment``'s internal align call and
    # ``diarize_file`` read on the hot path, so a successful warm-up
    # turns those inline restore calls into cache hits. Failures are
    # warnings — the inline restore is the fallback.
    lang = language or "en"
    with ThreadPoolExecutor(max_workers=2, thread_name_prefix="warmup") as ex:
        align_fut = ex.submit(bind_current_context(whisper_runner._get_align_model), lang, "cuda")
        pyannote_fut = ex.submit(bind_current_context(_warm_pyannote))

        with tracer.start_as_current_span(
            "autorag.whisper.transcribe_segment",
            attributes={"audio.path": file.name},
        ):
            raw_words = whisper_runner.transcribe_segment(model, str(file), language)

        for fut, label in ((align_fut, "align"), (pyannote_fut, "pyannote")):
            exc = fut.exception()
            if exc is not None:
                logger.warning(
                    "whisper warmup %s failed (%s); inline restore will run.", label, exc
                )

    with tracer.start_as_current_span(
        "autorag.whisper.diarize_file",
        attributes={"audio.path": file.name},
    ):
        turns = diarize.diarize_file(str(file))

    with tracer.start_as_current_span(
        "autorag.whisper.assign_speakers",
        attributes={"words.count": len(raw_words), "turns.count": len(turns)},
    ):
        labels = diarize.assign_speakers(raw_words, turns)

    spans: list[WordSpan] = []
    for w, label in zip(raw_words, labels, strict=True):
        s = float(w["s"])
        spans.append(
            {
                "w": str(w["w"]),
                "s": s,
                "e": float(w["e"]),
                "segment_id": "single",
                "speaker": label,
            }
        )
    return spans


def _parse_ts(value: str) -> float:
    """Parse an ``MM:SS`` / ``H:MM:SS`` (or bare-number) timestamp to seconds.

    The boundary LLM copies ``MM:SS`` markers straight from the block-formatted
    transcript; we do the arithmetic here rather than trusting the model to.
    Each ``:``-separated field is a base-60 digit, so minutes may exceed 59 for
    long audio (``"120:00"`` -> 7200.0). A bare number passes through. Anything
    unparseable returns ``0.0`` — ``_snap_tile`` / ``_drop_zero`` then repair
    the degenerate node.
    """
    raw = str(value).strip()
    if not raw:
        return 0.0
    try:
        if ":" in raw:
            total = 0.0
            for part in raw.split(":"):
                total = total * 60.0 + float(part)
            return total
        return float(raw)
    except (TypeError, ValueError):
        return 0.0


def _format_words_only(spans: list[WordSpan]) -> str:
    lines: list[str] = []
    for speaker, group in _group_by_speaker(spans):
        tokens = [t for ws in group if (t := str(ws.get("w", "")).strip())]
        if tokens:
            lines.append(f"Speaker {speaker}: {' '.join(tokens)}")
    return "\n".join(lines)


def _format_children(children: list[TopicDict]) -> str:
    lines: list[str] = []
    for c in children:
        lines.append(f"- title: {c.get('title', '') or ''}")
        lines.append(f"  summary: {c.get('summary', '') or ''}")
    return "\n".join(lines)


def _slice_spans(spans: list[WordSpan], s: float, e: float) -> list[WordSpan]:
    return [w for w in spans if s <= float(w.get("s", 0.0)) <= e]


def _audio_end(spans: list[WordSpan]) -> float:
    return max((float(w.get("e", 0.0)) for w in spans), default=0.0)


def _target_count(slice_s: float, slice_e: float) -> int:
    duration_s = max(0.0, slice_e - slice_s)
    target = round(duration_s / 60.0)
    return max(2, min(7, target))


def _snap_tile(siblings: list[TopicDict], slice_s: float, slice_e: float) -> None:
    """Element-wise clamp into [slice_s, slice_e], sort by `s`, anchor first/last,
    then force `cur.s = prev.e` for every adjacent pair so any remaining gaps
    OR overlaps collapse in one pass. Mutates `siblings` in place.
    """
    if not siblings:
        return
    for c in siblings:
        cs = max(slice_s, min(slice_e, float(c.get("s", slice_s))))
        ce = max(slice_s, min(slice_e, float(c.get("e", slice_s))))
        if ce < cs:
            ce = cs
        c["s"] = cs
        c["e"] = ce
    siblings.sort(key=lambda c: float(c.get("s", 0.0)))
    siblings[0]["s"] = slice_s
    siblings[-1]["e"] = slice_e
    for prev, cur in pairwise(siblings):
        cur["s"] = float(prev.get("e", 0.0))
        if float(cur.get("e", 0.0)) < float(cur["s"]):
            cur["e"] = float(cur["s"])


def _drop_zero(siblings: list[TopicDict]) -> list[TopicDict]:
    kept: list[TopicDict] = []
    for c in siblings:
        c["children"] = _drop_zero(c.get("children") or [])
        if float(c.get("e", 0.0)) - float(c.get("s", 0.0)) > 1e-6:
            kept.append(c)
    return kept


def _new_node(s: float, e: float, *, title: str = "", summary: str = "") -> TopicDict:
    return {"title": title, "summary": summary, "s": s, "e": e, "children": []}


class _StageClosures(NamedTuple):
    """The five warm per-stage closures plus the model-eviction callable.

    Built once by :func:`_build_stage_closures` and shared by both
    :func:`build_topic_runnable` (the sequential ``_build_tree``) and
    :func:`build_stage_handlers` (the distributed/queued pipeline in
    :mod:`autorag.services`) so the two entry points are byte-for-byte
    identical in how the Ollama chat clients are constructed, kept warm,
    and evicted.
    """

    extract_l1: Callable[[list[WordSpan], float], list[TopicDict]]
    decide_subdivide: Callable[[list[TopicDict], list[WordSpan]], list[bool]]
    extract_l2: Callable[[list[TopicDict], list[WordSpan]], list[list[TopicDict]]]
    summarize_nodes: Callable[[list[TopicDict], list[WordSpan]], None]
    aggregate_l0: Callable[[list[TopicDict], float], TopicDict]
    evict: Callable[[], None]


def _build_stage_closures(
    *,
    llm_model: str = "gemma4:latest",
    ollama_base_url: str | None = None,
    num_ctx_l1: int = 8192,
    num_ctx_fanout: int = 8192,
    max_concurrency: int = 8,
    min_subdivide_duration_s: float = 120.0,
    reasoning: bool = False,
    boundary_block_seconds: int = _BOUNDARY_BLOCK_SECONDS,
) -> _StageClosures:
    """Construct the warm Ollama chains and the five per-stage closures.

    See :class:`_StageClosures`. The full Ollama-settings rationale lives
    on the public :func:`build_topic_runnable` docstring.
    """

    base_kwargs: dict[str, Any] = {
        "model": llm_model,
        "temperature": 0.0,
        "reasoning": reasoning,
        "keep_alive": "5m",
        "base_url": ollama_base_url or _ollama_base_url(),
    }

    boundary_llm_l1 = ChatOllama(num_ctx=num_ctx_l1, **base_kwargs).with_structured_output(
        _BoundaryList, method="json_schema"
    )
    boundary_llm_fanout = ChatOllama(num_ctx=num_ctx_fanout, **base_kwargs).with_structured_output(
        _BoundaryList, method="json_schema"
    )
    decide_llm = ChatOllama(num_ctx=num_ctx_fanout, **base_kwargs).with_structured_output(
        _SubdivideDecision, method="json_schema"
    )
    node_sum_llm = ChatOllama(num_ctx=num_ctx_fanout, **base_kwargs).with_structured_output(
        _NodeSummary, method="json_schema"
    )
    agg_llm = ChatOllama(num_ctx=num_ctx_fanout, **base_kwargs).with_structured_output(
        _L0Summary, method="json_schema"
    )
    # Bare (no structured output) one-token call used only to evict the model
    # from VRAM after the run. `num_ctx` matches the resident model so this
    # hits the warm model and tells Ollama to unload it, rather than
    # triggering a reload-just-to-unload.
    unload_llm = ChatOllama(
        num_ctx=num_ctx_fanout,
        model=llm_model,
        base_url=base_kwargs["base_url"],
        temperature=0.0,
        reasoning=reasoning,
        keep_alive=0,
        num_predict=1,
    )

    l1_chain = (
        ChatPromptTemplate.from_messages([("system", _L1_SYS), ("human", _L1_HUMAN)])
        | boundary_llm_l1
    )
    decide_chain = (
        ChatPromptTemplate.from_messages([("system", _DECIDE_SYS), ("human", _DECIDE_HUMAN)])
        | decide_llm
    )
    l2_chain = (
        ChatPromptTemplate.from_messages([("system", _L2_SYS), ("human", _L2_HUMAN)])
        | boundary_llm_fanout
    )
    node_sum_chain = (
        ChatPromptTemplate.from_messages([("system", _NODE_SUM_SYS), ("human", _NODE_SUM_HUMAN)])
        | node_sum_llm
    )
    agg_chain = (
        ChatPromptTemplate.from_messages([("system", _AGG_SYS), ("human", _AGG_HUMAN)]) | agg_llm
    )

    # One callback handler per build, shared across every stage. LangChain
    # assigns a fresh ``run_id`` per call, so the handler can key spans by
    # run id without stages stomping each other. ``ThreadingInstrumentor``
    # (otel.py:153) carries parent context into the ``Runnable.batch``
    # thread pool so per-item spans nest under the active stage span.
    handler = OTelSpanCallbackHandler()

    def _cfg(stage_label: str) -> RunnableConfig:
        return {
            "max_concurrency": max_concurrency,
            "callbacks": [handler],
            "metadata": {"autorag.stage": stage_label},
        }

    l1_cfg = _cfg("l1")
    decide_cfg = _cfg("decide")
    l2_cfg = _cfg("l2")
    summarize_cfg = _cfg("summarize")
    l0_cfg = _cfg("l0")

    def _extract_l1_boundaries(spans: list[WordSpan], audio_e: float) -> list[TopicDict]:
        result = cast(
            "_BoundaryList",
            l1_chain.invoke(
                {
                    "transcript": format_blocks(spans, boundary_block_seconds),
                    "audio_e": mmss(audio_e),
                    "duration_min": audio_e / 60.0,
                    "target_count": _target_count(0.0, audio_e),
                },
                config=l1_cfg,
            ),
        )
        nodes: list[TopicDict] = [_new_node(_parse_ts(b.s), _parse_ts(b.e)) for b in result.topics]
        _snap_tile(nodes, 0.0, audio_e)
        nodes = _drop_zero(nodes)
        if not nodes:
            return [_new_node(0.0, audio_e)]
        return nodes

    def _decide_subdivide_batch(l1_nodes: list[TopicDict], spans: list[WordSpan]) -> list[bool]:
        # Build inputs only for L1s long enough to plausibly subdivide;
        # short ones force False without an LLM call.
        decisions: list[bool] = [False] * len(l1_nodes)
        long_indices: list[int] = []
        long_inputs: list[dict[str, Any]] = []
        for i, l1 in enumerate(l1_nodes):
            ls = float(l1.get("s", 0.0))
            le = float(l1.get("e", 0.0))
            if (le - ls) < min_subdivide_duration_s:
                continue
            sliced = _slice_spans(spans, ls, le)
            long_indices.append(i)
            long_inputs.append(
                {
                    "transcript": _format_words_only(sliced),
                    "duration_min": (le - ls) / 60.0,
                }
            )
        if not long_inputs:
            return decisions
        results = cast(
            "list[_SubdivideDecision]",
            decide_chain.batch(long_inputs, config=decide_cfg),
        )
        for idx, dec in zip(long_indices, results, strict=True):
            decisions[idx] = bool(dec.subdivide)
        return decisions

    def _extract_l2_boundaries_batch(
        yes_l1_nodes: list[TopicDict], spans: list[WordSpan]
    ) -> list[list[TopicDict]]:
        if not yes_l1_nodes:
            return []
        inputs: list[dict[str, Any]] = []
        for l1 in yes_l1_nodes:
            ls = float(l1.get("s", 0.0))
            le = float(l1.get("e", 0.0))
            sliced = _slice_spans(spans, ls, le)
            inputs.append(
                {
                    "transcript": format_blocks(sliced, boundary_block_seconds),
                    "slice_s": mmss(ls),
                    "slice_e": mmss(le),
                    "duration_min": (le - ls) / 60.0,
                    "target_count": _target_count(ls, le),
                }
            )
        results = cast("list[_BoundaryList]", l2_chain.batch(inputs, config=l2_cfg))
        out: list[list[TopicDict]] = []
        for r, l1 in zip(results, yes_l1_nodes, strict=True):
            ls = float(l1.get("s", 0.0))
            le = float(l1.get("e", 0.0))
            kids: list[TopicDict] = [_new_node(_parse_ts(b.s), _parse_ts(b.e)) for b in r.topics]
            _snap_tile(kids, ls, le)
            kids = _drop_zero(kids)
            out.append(kids)
        return out

    def _summarize_nodes_batch(nodes: list[TopicDict], spans: list[WordSpan]) -> None:
        if not nodes:
            return
        inputs: list[dict[str, Any]] = []
        keep_idx: list[int] = []
        for i, n in enumerate(nodes):
            ns = float(n.get("s", 0.0))
            ne = float(n.get("e", 0.0))
            text = _format_words_only(_slice_spans(spans, ns, ne))
            if not text.strip():
                continue
            keep_idx.append(i)
            inputs.append({"text": text})
        if not inputs:
            return
        results = cast(
            "list[_NodeSummary]",
            node_sum_chain.batch(inputs, config=summarize_cfg),
        )
        for idx, summ in zip(keep_idx, results, strict=True):
            nodes[idx]["title"] = summ.title
            nodes[idx]["summary"] = summ.summary

    def _aggregate_l0(l1_nodes: list[TopicDict], audio_e: float) -> TopicDict:
        if not l1_nodes:
            return _new_node(0.0, audio_e, title="(empty)", summary="")
        result = cast(
            "_L0Summary",
            agg_chain.invoke({"children": _format_children(l1_nodes)}, config=l0_cfg),
        )
        return _new_node(0.0, audio_e, title=result.title, summary=result.summary)

    def _evict() -> None:
        # Evict the model from VRAM now that the run is done (or has
        # errored) so it doesn't squat memory during the downstream
        # embed/viz step. Swallow + debug-log on failure — never mask a
        # pipeline error, matching the whisper/pyannote offload idiom.
        try:
            unload_llm.invoke(".")
            logger.debug("Ollama topic model evicted (keep_alive=0).")
        except Exception as exc:
            logger.debug("Ollama model eviction failed (%s); continuing.", exc)

    return _StageClosures(
        extract_l1=_extract_l1_boundaries,
        decide_subdivide=_decide_subdivide_batch,
        extract_l2=_extract_l2_boundaries_batch,
        summarize_nodes=_summarize_nodes_batch,
        aggregate_l0=_aggregate_l0,
        evict=_evict,
    )


[docs] def build_topic_runnable( *, llm_model: str = "gemma4:latest", ollama_base_url: str | None = None, num_ctx_l1: int = 8192, num_ctx_fanout: int = 8192, max_concurrency: int = 8, min_subdivide_duration_s: float = 120.0, reasoning: bool = False, boundary_block_seconds: int = _BOUNDARY_BLOCK_SECONDS, ) -> Runnable[list[WordSpan], TopicTree]: """Build a Runnable mapping list[WordSpan] -> TopicTree (L0/L1/L2 hierarchy). Notes on Ollama settings (server-side, controlled outside this module): - Every stage uses the same `num_ctx` (`num_ctx_fanout`, default 8192) and `keep_alive="5m"`, so the model stays resident across the sub-second inter-stage gaps. Ollama reloads a model whenever `num_ctx` changes between requests, so a uniform context size is what actually keeps it warm — there are zero mid-run reloads. After Stage 5 (and on any stage error) `_build_tree` issues one throwaway `keep_alive=0` call that evicts the model so it doesn't squat VRAM during the downstream embed/viz step. The finite 5-minute `keep_alive` is a crash-safety fallback: if the run dies before the explicit eviction, Ollama still unloads the model on its own. - `temperature=0.0` plus identical system prompts per chain give per-slot prefix-cache hits across all calls inside a single chain. (This works with Ollama's default per-slot cache — it does *not* require `OLLAMA_MULTIUSER_CACHE`, which must stay unset alongside the devcontainer's `FLASH_ATTENTION=1` + concurrent slots; see `CLAUDE.md` "Ollama tuning".) - `reasoning=False` (default) disables thinking on thinking-capable models. The default `gemma4:latest` is a `thinking` model; all five stages do mechanical JSON extraction (boundaries / yes-no / `{title, summary}`) where a chain-of-thought preamble is pure latency and a structured-output parse hazard — the same rationale as `temperature=0.0`. Pass `reasoning=True` to benchmark the quality/latency trade-off (the agent-lab `gemma4-thinking` design) or with a non-thinking model where it is a no-op. - `num_ctx_l1` is still overridable. The Stage 2 (L1) call sees the whole time-bucketed transcript; on very long audio (≈1 hr+) 8192 tokens can truncate it and degrade L1 boundaries. Raising `num_ctx_l1` back to e.g. 16384 fixes that at the cost of exactly one model reload at the Stage 2→3a boundary (the L1 call then differs in `num_ctx`). - `boundary_block_seconds` (default 30) sizes the time-bucketed transcript fed to the L1/L2 boundary prompts. Smaller windows give more frequent `MM:SS` anchors (finer possible boundaries) but more lines (more boundary-prompt tokens); larger windows are terser but coarser. It does not affect the per-node summary input (plain text, no timestamps). - With `OLLAMA_NUM_PARALLEL=1` the server serializes batched requests, so Stage 3a/3b wall-clock is `N x per-call`, not `N/4 x per-call`. Raising `NUM_PARALLEL` requires more VRAM (the server reserves all slots' KV-cache up front at the request's `num_ctx`). See `CLAUDE.md` "Ollama tuning notes". """ sc = _build_stage_closures( llm_model=llm_model, ollama_base_url=ollama_base_url, num_ctx_l1=num_ctx_l1, num_ctx_fanout=num_ctx_fanout, max_concurrency=max_concurrency, min_subdivide_duration_s=min_subdivide_duration_s, reasoning=reasoning, boundary_block_seconds=boundary_block_seconds, ) def _build_tree(spans: list[WordSpan]) -> TopicTree: if not spans: return {"topics": []} audio_e = _audio_end(spans) try: logger.info("Stage 2 (L1 boundaries): 1 call, num_ctx=%d", num_ctx_l1) t0 = time.time() l1_nodes = sc.extract_l1(spans, audio_e) logger.info("Stage 2 done in %.1fs (%d L1 topics)", time.time() - t0, len(l1_nodes)) logger.info( "Stage 3a (decide subdivide): up to %d batched calls (min slice=%.0fs)", len(l1_nodes), min_subdivide_duration_s, ) t0 = time.time() decisions = sc.decide_subdivide(l1_nodes, spans) yes_count = sum(1 for d in decisions if d) logger.info( "Stage 3a done in %.1fs (%d/%d subdivide=true)", time.time() - t0, yes_count, len(l1_nodes), ) yes_l1_nodes = [l1 for l1, d in zip(l1_nodes, decisions, strict=True) if d] logger.info("Stage 3b (L2 boundaries): %d batched calls", len(yes_l1_nodes)) t0 = time.time() l2_lists = sc.extract_l2(yes_l1_nodes, spans) for l1, kids in zip(yes_l1_nodes, l2_lists, strict=True): l1["children"] = kids l2_total = sum(len(k) for k in l2_lists) logger.info("Stage 3b done in %.1fs (%d L2 topics)", time.time() - t0, l2_total) nodes_to_summarize: list[TopicDict] = [] for l1 in l1_nodes: nodes_to_summarize.append(l1) for l2 in l1.get("children") or []: nodes_to_summarize.append(l2) logger.info("Stage 4 (summarize nodes): %d batched calls", len(nodes_to_summarize)) t0 = time.time() sc.summarize_nodes(nodes_to_summarize, spans) logger.info("Stage 4 done in %.1fs", time.time() - t0) logger.info("Stage 5 (L0 aggregate): 1 call") t0 = time.time() l0 = sc.aggregate_l0(l1_nodes, audio_e) l0["children"] = l1_nodes logger.info("Stage 5 done in %.1fs", time.time() - t0) return {"topics": [l0]} finally: sc.evict() return RunnableLambda(_build_tree)
[docs] def build_stage_handlers( *, llm_model: str = "gemma4:latest", ollama_base_url: str | None = None, num_ctx_l1: int = 8192, num_ctx_fanout: int = 8192, max_concurrency: int = 8, min_subdivide_duration_s: float = 120.0, reasoning: bool = False, boundary_block_seconds: int = _BOUNDARY_BLOCK_SECONDS, ) -> dict[str, Callable[..., Any]]: """Return the per-stage closures keyed by canonical stage name. The distributed/queued pipeline (:mod:`autorag.services`) runs one stage at a time, batched across many concurrent requests, so it needs the individual stage functions rather than the sequential ``_build_tree`` that :func:`build_topic_runnable` composes. Both share :func:`_build_stage_closures`, so the warm-chain construction and the ``keep_alive=0`` eviction are identical to the in-process path. Keys: ``"l1"``, ``"decide"``, ``"l2"``, ``"summarize"``, ``"l0"`` (the boundary/summary LLM stages) and ``"evict"`` (the zero-arg ``keep_alive=0`` model-eviction call the GPU arbiter owns once a distributed run's L0 stage completes). Stage 1 (Whisper) and the persist stage are not LLM stages and live in :mod:`autorag.whisper_runner` / :class:`autorag.core.AutoRAG`. """ sc = _build_stage_closures( llm_model=llm_model, ollama_base_url=ollama_base_url, num_ctx_l1=num_ctx_l1, num_ctx_fanout=num_ctx_fanout, max_concurrency=max_concurrency, min_subdivide_duration_s=min_subdivide_duration_s, reasoning=reasoning, boundary_block_seconds=boundary_block_seconds, ) return { "l1": sc.extract_l1, "decide": sc.decide_subdivide, "l2": sc.extract_l2, "summarize": sc.summarize_nodes, "l0": sc.aggregate_l0, "evict": sc.evict, }
[docs] def build_agent( *, whisper_model: str = "base", language: str | None = "en", llm_model: str = "gemma4:latest", ollama_base_url: str | None = None, num_ctx_l1: int = 8192, num_ctx_fanout: int = 8192, max_concurrency: int = 8, min_subdivide_duration_s: float = 120.0, reasoning: bool = False, boundary_block_seconds: int = _BOUNDARY_BLOCK_SECONDS, ) -> Runnable[Path | str, TranscriptionResult]: """Build a Runnable mapping audio file -> {transcription, topics:{topics:[L0]}}.""" topic_runnable = build_topic_runnable( llm_model=llm_model, ollama_base_url=ollama_base_url, num_ctx_l1=num_ctx_l1, num_ctx_fanout=num_ctx_fanout, max_concurrency=max_concurrency, min_subdivide_duration_s=min_subdivide_duration_s, reasoning=reasoning, boundary_block_seconds=boundary_block_seconds, ) def _whisper_step(file: Path | str) -> list[WordSpan]: f = Path(file) logger.info("Stage 1 (Whisper): transcribing %s", f.name) t0 = time.time() spans = _run_whisper(f, model_size=whisper_model, language=language) logger.info("Stage 1 done in %.1fs (%d words)", time.time() - t0, len(spans)) return spans def _assemble(spans: list[WordSpan]) -> TranscriptionResult: topics: TopicTree = topic_runnable.invoke(spans) return {"transcription": spans, "topics": topics} return RunnableLambda(_whisper_step) | RunnableLambda(_assemble)
[docs] def transcribe_audio( file: Path | str, *, whisper_model: str = "base", language: str | None = "en", ) -> list[WordSpan]: """Run Whisper + diarization on a local audio file, returning word spans.""" f = Path(file) logger.info("Stage 1 (Whisper): transcribing %s", f.name) t0 = time.time() spans = _run_whisper(f, model_size=whisper_model, language=language) logger.info("Stage 1 done in %.1fs (%d words)", time.time() - t0, len(spans)) return spans
[docs] def generate_topics(words: list[WordSpan], **kwargs: Any) -> TopicTree: """Build the topic runnable and invoke it once.""" return build_topic_runnable(**kwargs).invoke(words)
[docs] def transcribe(file: Path | str, **kwargs: Any) -> TranscriptionResult: """Build the agent and invoke it once.""" return build_agent(**kwargs).invoke(file)
__all__ = [ "TopicDict", "TopicTree", "TranscriptionResult", "WordSpan", "build_agent", "build_stage_handlers", "build_topic_runnable", "generate_topics", "transcribe", "transcribe_audio", ]