# Source code for autorag.diarize

"""Lazy-loaded pyannote speaker-diarization pipeline with CUDA->CPU fallback.

Mirrors `whisper_runner.py` in structure: a module-level cache, a `threading.Lock`,
and a single CUDA failure flips the process to CPU for the rest of its life.

Public surface:

- `get_pipeline()` -> Pipeline | None  (None means "no token / load failed";
  callers should fall back to single-speaker behavior)
- `diarize_file(path)` -> list[(start_s, end_s, speaker_label)]
- `assign_speakers(words, turns)` -> list[str]  (parallel labels, '0' fallback)
"""

from __future__ import annotations

import logging
import os
import shutil
import subprocess
import tempfile
import threading
import warnings
from pathlib import Path
from typing import Any

from autorag.otel import get_tracer

logger = logging.getLogger(__name__)

# Module-level state -----------------------------------------------------------

# Serializes the one-time pipeline load performed by get_pipeline().
_PIPELINE_LOCK = threading.Lock()
# Cached pipeline; stays None until loaded, or for good if loading was skipped/failed.
_PIPELINE: Any | None = None
# Flipped to True *before* the first load attempt, so a failed load is never retried.
_PIPELINE_LOAD_ATTEMPTED: bool = False
_cpu_pinned: bool = False  # set True after any CUDA failure for the process lifetime
_pipeline_device: str = "cpu"  # tracks current device of _PIPELINE after load

# Hugging Face model id handed to Pipeline.from_pretrained (a gated repo: the
# token must have been granted access, see the load-failure warning below).
_MODEL_NAME: str = "pyannote/speaker-diarization-3.1"

# pyannote/torchaudio decodes these reliably; everything else gets transcoded.
_NATIVE_AUDIO_EXTS: frozenset[str] = frozenset({".wav", ".flac"})


def _torch_cuda_available() -> bool:
    try:
        import torch
    except Exception:
        return False
    try:
        return bool(torch.cuda.is_available())
    except Exception:
        return False


def _device_preference() -> str:
    """Choose the device pyannote should use next: ``"cuda"`` or ``"cpu"``.

    CPU is forced when an earlier CUDA failure pinned the process or when the
    shared ``AUTORAG_WHISPER_DEVICE`` setting is ``cpu``; otherwise CUDA is
    picked whenever torch reports it as available.
    """
    if _cpu_pinned:
        return "cpu"
    requested = os.environ.get("AUTORAG_WHISPER_DEVICE", "auto").strip().lower()
    if requested != "cpu" and _torch_cuda_available():
        return "cuda"
    return "cpu"


def _ffmpeg_exe() -> str | None:
    """Locate ffmpeg, preferring system PATH then imageio_ffmpeg's bundled binary."""
    found = shutil.which("ffmpeg")
    if found:
        return found
    try:
        import imageio_ffmpeg

        return str(imageio_ffmpeg.get_ffmpeg_exe())
    except Exception:
        return None


def _ensure_ffmpeg_on_path() -> None:
    """Prepend the located ffmpeg's directory to ``PATH`` if it is missing.

    Same trick as whisper_runner: pyannote loads audio via torchaudio, which
    shells out to ffmpeg on some backends, so ffmpeg must be discoverable.
    """
    exe = _ffmpeg_exe()
    if not exe:
        return
    exe_dir = str(Path(exe).parent)
    if not exe_dir:
        return
    path_value = os.environ.get("PATH", "")
    entries = path_value.split(os.pathsep) if path_value else []
    if exe_dir in entries:
        return
    os.environ["PATH"] = os.pathsep.join([exe_dir, path_value]) if path_value else exe_dir


def _transcode_to_wav(src: str, dst: str) -> bool:
    """Decode `src` → 16 kHz mono PCM wav at `dst`. Returns True on success.

    Every failure is logged and reported as False instead of raised, so the
    caller can simply skip diarization:

    - no ffmpeg could be located;
    - ffmpeg could not be launched (missing, not executable, ...);
    - ffmpeg exited with a non-zero status.
    """
    ff = _ffmpeg_exe()
    if not ff:
        logger.warning("ffmpeg unavailable; cannot transcode %s for diarization.", src)
        return False
    cmd = [
        ff, "-nostdin", "-y", "-loglevel", "error",
        "-i", src, "-ac", "1", "-ar", "16000", "-vn", dst,
    ]
    try:
        # ffmpeg otherwise reads the parent's stdin for interactive commands,
        # which can stall or swallow input when run from a server or pipeline.
        subprocess.run(cmd, check=True, stdin=subprocess.DEVNULL)
    except (subprocess.CalledProcessError, OSError) as exc:
        # OSError (superset of FileNotFoundError) also covers PermissionError
        # when the located binary exists but cannot be executed.
        logger.warning("ffmpeg transcode failed for %s (%s); diarization skipped.", src, exc)
        return False
    return True


def _hf_token() -> str | None:
    raw = os.environ.get("HF_TOKEN", "").strip()
    return raw or None


def _is_cuda_error(exc: BaseException) -> bool:
    cls = type(exc).__name__.lower()
    msg = (str(exc) or "").lower()
    if "cuda" in cls or "cuda" in msg:
        return True
    if "out of memory" in msg:
        return True
    return "nvml" in msg or ("driver" in msg and "cuda" in msg)


def _pin_cpu(reason: str) -> None:
    """Force CPU for all later pyannote work in this process.

    Only the first pin logs a warning (carrying `reason`); repeat calls just
    re-assert the flag.
    """
    global _cpu_pinned
    first_pin = not _cpu_pinned
    if first_pin:
        logger.warning("Pinning pyannote to CPU for remainder of process: %s", reason)
    _cpu_pinned = True


def _ensure_pipeline_on_cuda(pipeline: Any) -> None:
    """Re-move an offloaded pipeline back to CUDA before inference.

    Does nothing when any of these holds:

    - the pipeline is already tracked as being on CUDA;
    - the process is pinned to CPU;
    - `_device_preference()` does not pick CUDA.

    If the move fails, the process is pinned to CPU and the pipeline is moved
    back to CPU (best effort). A `.to()` that raised part-way may have left
    some weights on the GPU. Because `_pipeline_device` stays "cpu",
    `_offload_pipeline` would never clean that up, and later CPU inference
    would run on a mixed-device pipeline.
    """
    global _pipeline_device
    if _pipeline_device == "cuda" or _cpu_pinned or _device_preference() != "cuda":
        return
    tracer = get_tracer("autorag.pyannote")
    with tracer.start_as_current_span(
        "autorag.pyannote.ensure_on_cuda",
        attributes={"pipeline.previous_device": _pipeline_device},
    ):
        try:
            import torch

            pipeline.to(torch.device("cuda"))
        except Exception as exc:
            _pin_cpu(str(exc))
            try:
                import torch

                pipeline.to(torch.device("cpu"))
            except Exception as restore_exc:
                logger.debug(
                    "pyannote rollback to CPU after failed CUDA move failed (%s).",
                    restore_exc,
                )
            return
        _pipeline_device = "cuda"
        logger.debug("pyannote pipeline moved back to CUDA for inference.")


def _offload_pipeline(pipeline: Any) -> None:
    """Move pipeline to CPU and free VRAM after inference completes.

    No-op unless the pipeline is tracked as being on CUDA. Failures are
    logged at debug level and otherwise ignored (best effort).
    """
    global _pipeline_device
    if _pipeline_device != "cuda":
        return
    tracer = get_tracer("autorag.pyannote")
    with tracer.start_as_current_span("autorag.pyannote.offload"):
        try:
            import torch

            pipeline.to(torch.device("cpu"))
        except Exception as exc:
            logger.debug("pyannote VRAM offload failed (%s); continuing.", exc)
            return
        # Record the move before touching the allocator. A failing
        # empty_cache() must not leave the tracker claiming "cuda" for a
        # CPU-resident pipeline: _ensure_pipeline_on_cuda would then skip the
        # move and inference would silently run on CPU.
        _pipeline_device = "cpu"
        try:
            torch.cuda.empty_cache()
        except Exception as exc:
            logger.debug("pyannote VRAM offload failed (%s); continuing.", exc)
            return
        logger.debug("pyannote pipeline offloaded to CPU; VRAM freed.")


def _from_pretrained_compat(pipeline_cls: Any, token: str) -> Any:
    """Call ``pipeline_cls.from_pretrained(_MODEL_NAME, ...)`` across pyannote versions.

    ``token=`` is tried first. If that keyword is rejected with a TypeError
    mentioning "token", the call is retried with ``use_auth_token=``. That
    retry presumably serves pyannote.audio < 4, which spells it that way;
    verify against the installed version. Any other error propagates.
    """
    try:
        return pipeline_cls.from_pretrained(_MODEL_NAME, token=token)
    except TypeError as exc:
        if "token" not in str(exc):
            raise
    return pipeline_cls.from_pretrained(_MODEL_NAME, use_auth_token=token)


def _load_pipeline_on(device: str) -> Any | None:
    """Load `_MODEL_NAME` and, if requested, move it to CUDA.

    Args:
        device: ``"cuda"`` to attempt a GPU move after loading; any other
            value leaves the pipeline where ``from_pretrained`` placed it
            (CPU by default, per the note below).

    Returns:
        The pipeline, or None when HF_TOKEN is unset, pyannote.audio cannot
        be imported, or the model fails to load / is not accessible.

    Raises:
        Whatever non-CUDA exception the CUDA move raises. CUDA errors instead
        pin the process to CPU and the CPU pipeline is returned.
    """
    global _pipeline_device
    token = _hf_token()
    if not token:
        logger.warning(
            "HF_TOKEN not set; speaker diarization disabled. "
            "All words will be labeled as a single speaker."
        )
        return None

    try:
        from pyannote.audio import Pipeline
    except Exception as exc:
        logger.warning("pyannote.audio import failed (%s); diarization disabled.", exc)
        return None

    try:
        pipeline = _from_pretrained_compat(Pipeline, token)
    except Exception as exc:
        logger.warning("Failed to load %s (%s); diarization disabled.", _MODEL_NAME, exc)
        return None
    if pipeline is None:
        # pyannote returns None when the model cannot be loaded (e.g. accepted
        # license missing for the gated repo).
        logger.warning(
            "%s could not be loaded (token may lack access to the gated model); "
            "diarization disabled.",
            _MODEL_NAME,
        )
        return None

    loaded_on = "cpu"
    if device == "cuda":
        try:
            import torch

            pipeline.to(torch.device("cuda"))
        except Exception as exc:
            if not _is_cuda_error(exc):
                raise
            _pin_cpu(str(exc))
            # from_pretrained loads on CPU by default, but a .to() that failed
            # part-way may have moved some weights already; pull them back so
            # the pipeline is not split across devices.
            try:
                import torch

                pipeline.to(torch.device("cpu"))
            except Exception as restore_exc:
                logger.debug("pyannote rollback to CPU failed (%s).", restore_exc)
            logger.warning("pyannote CUDA move failed (%s); using CPU.", exc)
        else:
            loaded_on = "cuda"
            _pipeline_device = "cuda"
    # Log the device actually in use, not merely the one that was requested.
    logger.info("Loaded pyannote pipeline %s on device=%s", _MODEL_NAME, loaded_on)
    return pipeline


def get_pipeline() -> Any | None:
    """Return the process-wide pyannote pipeline, loading it on first use.

    Only the first call ever attempts a load; the attempt is recorded before
    loading starts, so later calls return whatever it produced (even if it
    raised). None means HF_TOKEN was missing or loading failed — callers MUST
    treat None as "skip diarization".
    """
    global _PIPELINE, _PIPELINE_LOAD_ATTEMPTED
    with _PIPELINE_LOCK:
        if not _PIPELINE_LOAD_ATTEMPTED:
            _PIPELINE_LOAD_ATTEMPTED = True
            _ensure_ffmpeg_on_path()
            _PIPELINE = _load_pipeline_on(_device_preference())
        return _PIPELINE
def diarize_file(file_path: str) -> list[tuple[float, float, str]]:
    """Diarize one audio file.

    Returns ``(start_s, end_s, label)`` tuples sorted by start, or ``[]``
    when no pipeline is available or any step fails.

    Only wav and flac are handed to pyannote directly. Any other container
    (webm, mp3, m4a, ogg, ...) is first transcoded by ffmpeg into a 16 kHz mono
    wav inside a temporary directory that is removed afterwards.
    """
    pipeline = get_pipeline()
    if pipeline is None:
        return []
    if Path(file_path).suffix.lower() not in _NATIVE_AUDIO_EXTS:
        with tempfile.TemporaryDirectory(prefix="autorag-diarize-") as scratch_dir:
            converted = str(Path(scratch_dir) / "diarize_input.wav")
            if not _transcode_to_wav(file_path, converted):
                return []
            return _run_diarization(pipeline, converted)
    return _run_diarization(pipeline, file_path)
def _run_diarization(pipeline: Any, audio_path: str) -> list[tuple[float, float, str]]: global _pipeline_device tracer = get_tracer("autorag.pyannote") _ensure_pipeline_on_cuda(pipeline) try: with warnings.catch_warnings(): # pyannote's StatsPool calls std(correction=1) on single-frame segments # (dof=0 → NaN), which it handles internally. Suppress the noise. warnings.filterwarnings( "ignore", message=r"std\(\).*degrees of freedom", category=UserWarning, ) with tracer.start_as_current_span( "autorag.pyannote.inference", attributes={ "audio.path": Path(audio_path).name, "device": _pipeline_device, }, ): diarization = pipeline(audio_path) except Exception as exc: # pragma: no cover - hardware-dependent if _is_cuda_error(exc) and not _cpu_pinned: logger.warning("pyannote CUDA failure on %s (%s); retrying on CPU.", audio_path, exc) _pin_cpu(str(exc)) try: import torch pipeline.to(torch.device("cpu")) _pipeline_device = "cpu" with tracer.start_as_current_span( "autorag.pyannote.inference", attributes={"audio.path": Path(audio_path).name, "device": "cpu"}, ): diarization = pipeline(audio_path) except Exception as exc2: logger.warning("pyannote CPU retry failed (%s); skipping diarization.", exc2) return [] else: logger.warning("pyannote diarization failed on %s (%s); skipping.", audio_path, exc) return [] # pyannote 4.x wraps the Annotation in a DiarizeOutput; older versions # return the Annotation directly. Unwrap if needed. 
annotation = getattr(diarization, "speaker_diarization", diarization) turns: list[tuple[float, float, str]] = [] try: for turn, _, speaker in annotation.itertracks(yield_label=True): turns.append((float(turn.start), float(turn.end), str(speaker))) except Exception as exc: logger.warning("Failed to extract pyannote tracks (%s); skipping.", exc) return [] turns.sort(key=lambda t: t[0]) result = _normalize_labels(turns) _offload_pipeline(pipeline) return result def _normalize_labels( turns: list[tuple[float, float, str]], ) -> list[tuple[float, float, str]]: """Map pyannote labels (SPEAKER_00, SPEAKER_01, ...) to '0', '1', ... in first-appearance order. Keeps prompt output clean and stable.""" mapping: dict[str, str] = {} out: list[tuple[float, float, str]] = [] for s, e, raw in turns: if raw not in mapping: mapping[raw] = str(len(mapping)) out.append((s, e, mapping[raw])) return out
def _as_seconds(value: Any) -> float | None:
    """Convert `value` to a finite float, or return None if that is not possible."""
    import math

    try:
        seconds = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return seconds if math.isfinite(seconds) else None


def _word_bounds(word: dict[str, Any]) -> tuple[float, float]:
    """Return a word's ``(start, end)`` in seconds, guaranteeing ``end >= start``.

    A missing "s" counts as 0.0. Otherwise an unusable value on either side
    borrows the other side, and if both are unusable the word sits at 0.0.
    """
    start = _as_seconds(word.get("s", 0.0))
    end = _as_seconds(word.get("e"))
    if start is None:
        start = 0.0 if end is None else end
    if end is None or end < start:
        end = start
    return start, end


def _speaker_for_interval(
    start: float, end: float, turns: list[tuple[float, float, str]]
) -> str:
    """Pick the label of the turn best matching ``[start, end]`` (non-empty `turns`)."""
    best_overlap = 0.0
    best_label: str | None = None
    for ts, te, label in turns:
        overlap = max(0.0, min(end, te) - max(start, ts))
        if overlap > best_overlap:  # strict: on equal overlap the earlier turn wins
            best_overlap = overlap
            best_label = label
    if best_label is None:
        # No positive overlap. This always happens for zero-length words, even
        # ones lying inside a turn. Pick the turn with the smallest edge gap
        # (0 when the word touches or sits inside it), tie-breaking on
        # midpoint distance; min() keeps the first of any remaining ties.
        mid = (start + end) / 2.0
        _, _, best_label = min(
            turns,
            key=lambda t: (
                max(0.0, t[0] - end, start - t[1]),
                abs(mid - (t[0] + t[1]) / 2.0),
            ),
        )
    return best_label or "0"


def assign_speakers(
    words: list[dict[str, Any]],
    turns: list[tuple[float, float, str]],
) -> list[str]:
    """Assign a speaker label to each word.

    Strategy:

    1. Pick the turn with maximum temporal overlap with the word's ``[s, e]``
       interval; on equal overlap the earlier turn in `turns` wins.
    2. If no turn overlaps (always the case for zero-length words), pick the
       turn whose edges are nearest the word. The distance is 0 if the word
       lies inside or touches the turn; ties are broken by midpoint distance.

    Unusable times are repaired rather than zeroed:

    - a bad or missing "e" becomes the start;
    - a bad "s" becomes the end;
    - a word with neither usable sits at 0.0.

    Args:
        words: Word dicts; only the "s" and "e" keys (seconds) are read.
        turns: ``(start_s, end_s, label)`` tuples, e.g. from `diarize_file`.

    Returns:
        One label per word, parallel to `words`. If `turns` is empty every
        word becomes "0".
    """
    if not turns:
        return ["0"] * len(words)
    return [_speaker_for_interval(*_word_bounds(w), turns) for w in words]
# Public surface of this module; everything else is an implementation detail.
__all__ = [
    "assign_speakers",
    "diarize_file",
    "get_pipeline",
]