"""Lazy-loaded whisperX model cache with CUDA-preferred / CPU-fallback device selection.
The main transcription model (CTranslate2 / faster-whisper backend) is removed
from the module cache after each run so Python GC can free VRAM; the smaller
wav2vec2 alignment model is offloaded to CPU after aligning and restored on the
next call (PyTorch .to() round-trip). Both are re-created from local HF cache
on the next pipeline run, which is fast (<1 s for models already downloaded).
"""
from __future__ import annotations
import logging
import os
import shutil
import threading
import time
from pathlib import Path
from typing import Any
from autorag.otel import get_tracer
logger = logging.getLogger(__name__)
# Module-level state -----------------------------------------------------------
_MODEL_LOCK = threading.Lock()
_MODEL_CACHE: dict[tuple[str, str], Any] = {} # (size, device) → whisperx model
_ALIGN_CACHE: dict[tuple[str, str], tuple[Any, Any]] = {} # (lang, device) → (align_model, meta)
_cpu_pinned = False
_device_log_emitted = False
_resolved_device: str | None = None
def _torch_cuda_available() -> bool:
try:
import torch
except Exception:
return False
try:
return bool(torch.cuda.is_available())
except Exception:
return False
def _device_preference() -> str:
"""Resolve the preferred device honoring ``AUTORAG_WHISPER_DEVICE``."""
if _cpu_pinned:
return "cpu"
raw = os.environ.get("AUTORAG_WHISPER_DEVICE", "auto").strip().lower()
if raw == "cpu":
return "cpu"
if raw == "cuda":
return "cuda" if _torch_cuda_available() else "cpu"
return "cuda" if _torch_cuda_available() else "cpu"
def _ensure_ffmpeg_on_path() -> None:
"""Reuse imageio_ffmpeg's bundled binary so whisperX's subprocess finds it."""
ff = shutil.which("ffmpeg")
if not ff:
try:
import imageio_ffmpeg
ff = str(imageio_ffmpeg.get_ffmpeg_exe())
except Exception:
raise RuntimeError("missing ffmpeg") from None
ff_dir = str(Path(ff).parent)
current = os.environ.get("PATH", "")
parts = current.split(os.pathsep) if current else []
if ff_dir and ff_dir not in parts:
os.environ["PATH"] = (ff_dir + os.pathsep + current) if current else ff_dir
[docs]
def resolved_device() -> str:
"""Return the device most recently used (or the preference if nothing loaded yet)."""
if _resolved_device is not None:
return _resolved_device
return _device_preference()
def _compute_type(device: str) -> str:
return "float16" if device == "cuda" else "int8"
def _load_model_on(size: str, device: str) -> Any:
import whisperx
compute = _compute_type(device)
logger.info("Loading whisperX model size=%s device=%s compute_type=%s", size, device, compute)
# Cache-miss-only span: ``get_model`` short-circuits cached entries
# before calling here, so this fires at most once per (size, device)
# per worker lifetime — its presence in any job >= 2 is a regression
# signal for the CT2 residency contract.
tracer = get_tracer("autorag.whisper")
with tracer.start_as_current_span(
"autorag.whisper.load_model",
attributes={
"model.size": size,
"model.device": device,
"model.compute_type": compute,
},
):
return whisperx.load_model(size, device, compute_type=compute)
[docs]
def get_model(size: str, device_hint: str | None = None) -> Any:
"""Return a cached whisperX model for *size*.
``device_hint`` is advisory: ignored when the process is already CPU-pinned.
"""
global _device_log_emitted, _resolved_device
_ensure_ffmpeg_on_path()
device = (device_hint or "").strip().lower() or _device_preference()
if _cpu_pinned:
device = "cpu"
if device not in ("cuda", "cpu"):
device = "cpu"
key = (size, device)
with _MODEL_LOCK:
if not _device_log_emitted:
logger.info("whisperX device preference resolved to: %s", device)
_device_log_emitted = True
cached = _MODEL_CACHE.get(key)
if cached is not None:
_resolved_device = device
return cached
model = _load_model_on(size, device)
_MODEL_CACHE[key] = model
_resolved_device = device
return model
def _pin_cpu(reason: str) -> None:
global _cpu_pinned
if not _cpu_pinned:
logger.warning("Pinning whisperX to CPU for remainder of process: %s", reason)
_cpu_pinned = True
def _is_cuda_error(exc: BaseException) -> bool:
cls = type(exc).__name__.lower()
msg = (str(exc) or "").lower()
if "cuda" in cls or "cuda" in msg:
return True
if "out of memory" in msg:
return True
return "nvml" in msg or ("driver" in msg and "cuda" in msg)
def _get_align_model(language: str, device: str) -> tuple[Any, Any]:
"""Return (align_model, metadata) for *language*, restoring from CPU cache when possible."""
tracer = get_tracer("autorag.whisper")
key = (language, device)
# Read the cache holding only the lock, then emit the span outside —
# ``start_as_current_span`` can call into SpanProcessors / exporters,
# so we never hold ``_MODEL_LOCK`` across a tracer call.
with _MODEL_LOCK:
cached = _ALIGN_CACHE.get(key)
cpu_cached = _ALIGN_CACHE.get((language, "cpu"))
if cached is not None:
with tracer.start_as_current_span(
"autorag.whisper.get_align_model",
attributes={
"align.language": language,
"align.device": device,
"align.restored_from": "cuda",
},
):
pass
return cached
if device == "cuda" and cpu_cached is not None:
model_a, metadata = cpu_cached
try:
import torch
with tracer.start_as_current_span(
"autorag.whisper.get_align_model",
attributes={
"align.language": language,
"align.device": device,
"align.restored_from": "cpu_to_cuda_restore",
},
):
model_a.to(torch.device("cuda"))
with _MODEL_LOCK:
_ALIGN_CACHE[(language, "cuda")] = (model_a, metadata)
_ALIGN_CACHE.pop((language, "cpu"), None)
logger.debug("whisperX align model restored to CUDA.")
return model_a, metadata
except Exception as exc:
logger.warning("whisperX align model CUDA restore failed (%s); reloading.", exc)
import whisperx
logger.info("Loading whisperX align model language=%s device=%s", language, device)
with tracer.start_as_current_span(
"autorag.whisper.get_align_model",
attributes={
"align.language": language,
"align.device": device,
"align.restored_from": "fresh_load",
},
):
model_a, metadata = whisperx.load_align_model(language_code=language, device=device)
with _MODEL_LOCK:
_ALIGN_CACHE[(language, device)] = (model_a, metadata)
return model_a, metadata
def _offload_align_model(language: str) -> None:
"""Move the wav2vec2 alignment model to CPU and free VRAM."""
tracer = get_tracer("autorag.whisper")
with tracer.start_as_current_span(
"autorag.whisper.offload_align",
attributes={"align.language": language},
):
try:
import torch
with _MODEL_LOCK:
cuda_key = (language, "cuda")
cached = _ALIGN_CACHE.get(cuda_key)
if cached is None:
return
model_a, metadata = cached
model_a.to(torch.device("cpu"))
_ALIGN_CACHE[(language, "cpu")] = (model_a, metadata)
del _ALIGN_CACHE[cuda_key]
torch.cuda.empty_cache()
logger.debug("whisperX align model offloaded to CPU; VRAM freed.")
except Exception as exc:
logger.debug("whisperX align model offload failed (%s); continuing.", exc)
[docs]
def transcribe_segment(
model: Any,
file_path: str,
language: str | None,
) -> list[dict[str, Any]]:
"""Transcribe *file_path* and return frame-aligned word dicts.
Each dict: ``{"w": str, "s": float, "e": float, "p": float}``.
The alignment pass uses wav2vec2 for frame-accurate word timestamps; if it
fails the unaligned faster-whisper timestamps are used as a fallback.
"""
import whisperx
tracer = get_tracer("autorag.whisper")
_ensure_ffmpeg_on_path()
with tracer.start_as_current_span(
"autorag.whisper.load_audio",
attributes={"audio.path": Path(file_path).name},
) as load_span:
audio = whisperx.load_audio(file_path)
# whisperx resamples to 16 kHz mono in load_audio; sample-count / SR
# is the source-of-truth for the actual audio length the model sees.
try:
audio_duration_s = float(len(audio)) / 16000.0
load_span.set_attribute("audio.duration_s", audio_duration_s)
except Exception: # pragma: no cover - defensive
audio_duration_s = 0.0
device = resolved_device() or "cpu"
batch_size = 16 if device == "cuda" else 4
transcribe_kwargs: dict[str, Any] = {"batch_size": batch_size}
if language:
transcribe_kwargs["language"] = language
ct2_start = time.perf_counter()
with tracer.start_as_current_span(
"autorag.whisper.ct2_transcribe",
attributes={
"audio.batch_size": batch_size,
"model.device": device,
"language": language or "auto",
},
) as ct2_span:
try:
result: dict[str, Any] = model.transcribe(audio, **transcribe_kwargs)
except Exception as exc: # pragma: no cover - hardware-dependent
if _is_cuda_error(exc) and not _cpu_pinned:
logger.warning(
"whisperX CUDA failure on %s (%s); reloading on CPU and retrying once.",
file_path,
exc,
)
_pin_cpu(str(exc))
size_guess = _current_model_size(model) or "base"
cpu_model = get_model(size_guess, device_hint="cpu")
cpu_kwargs: dict[str, Any] = {"batch_size": 4}
if language:
cpu_kwargs["language"] = language
result = cpu_model.transcribe(audio, **cpu_kwargs)
model = cpu_model
device = "cpu"
ct2_span.set_attribute("model.device", device)
else:
raise
ct2_wall = max(time.perf_counter() - ct2_start, 1e-9)
if audio_duration_s > 0:
ct2_span.set_attribute("transcribe.realtime_factor", audio_duration_s / ct2_wall)
detected_language: str = result.get("language") or language or "en"
try:
model_a, metadata = _get_align_model(detected_language, device)
with tracer.start_as_current_span(
"autorag.whisper.align",
attributes={
"align.language": detected_language,
"align.device": device,
"segments.count": len(result.get("segments", []) or []),
},
):
aligned: dict[str, Any] = whisperx.align(
result["segments"], model_a, metadata, audio, device, return_char_alignments=False
)
segments: list[Any] = aligned.get("segments", result.get("segments", []))
except Exception as exc:
logger.warning("whisperX alignment failed (%s); using unaligned timestamps.", exc)
segments = result.get("segments", [])
words_out: list[dict[str, Any]] = []
for seg in segments:
raw_words = seg.get("words") if isinstance(seg, dict) else None
if not raw_words:
continue
for w in raw_words:
try:
token = str(w.get("word", "") or "")
start_raw = w.get("start")
end_raw = w.get("end")
# whisperX omits start/end for words it could not align — skip them.
if start_raw is None and end_raw is None:
continue
start = float(start_raw or 0.0)
end = float(end_raw or start)
prob = float(w.get("score", w.get("probability", 0.0)) or 0.0)
except (TypeError, ValueError):
continue
if not token.strip():
continue
words_out.append({"w": token, "s": start, "e": end, "p": prob})
# The CT2 model stays resident in ``_MODEL_CACHE`` for the worker's
# lifetime: this per-call destroy is gone (residency fix) and
# ``GpuArbiter._default_offload_whisper`` no longer drops the cache on
# ``whisper -> llm`` flips either (only the torch parts go to CPU). The
# ~0.5 s CUDA-JIT rebuild therefore fires at most once per worker boot.
_offload_align_model(detected_language)
return words_out
def _offload_main_model(model: Any) -> None:
"""Remove the whisperX model from cache; CTranslate2 frees VRAM when the object is GC'd."""
with _MODEL_LOCK:
keys_to_del = [k for k, v in _MODEL_CACHE.items() if v is model]
for k in keys_to_del:
del _MODEL_CACHE[k]
if keys_to_del:
logger.debug("whisperX model removed from cache; VRAM freed on GC.")
def _current_model_size(model: Any) -> str | None:
"""Best-effort reverse lookup of a model's cache key for its size."""
with _MODEL_LOCK:
for (size, _), m in _MODEL_CACHE.items():
if m is model:
return size
return None
__all__ = [
"get_model",
"resolved_device",
"transcribe_segment",
]