Maia/diarization/routers/v2_transcribe.py

193 lines
7.4 KiB
Python

"""
POST /api/v2/transcribe — Transcription with speaker diarization.
Calls the original Whisper server for transcription (with timestamps),
then runs speaker diarization and aligns the results.
"""
import os
import logging
import tempfile
from typing import List, Optional
import httpx
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from pydantic import BaseModel
from services.diarization import diarize, align_transcription_with_diarization, extract_embedding
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router the transcribe endpoint in this module is registered on.
# NOTE(review): the module docstring advertises /api/v2/transcribe, so the
# "/api/v2" prefix is presumably added where this router is mounted — confirm.
router = APIRouter()
# Base URL of the upstream Whisper server. Taken from the WHISPER_ORIGIN
# environment variable; the literal address is only the fallback default.
WHISPER_ORIGIN = os.getenv("WHISPER_ORIGIN", "http://10.100.16.13:5003")
# ── Response model ────────────────────────────────────────────────────────────
# One aligned transcript segment: a span of text attributed to one speaker.
# (Documented with comments rather than a class docstring so the generated
# schema is left exactly as it was.)
class TranscriptionSegment(BaseModel):
    # Segment start — presumably seconds from the beginning of the audio
    # (diarization segment bounds are multiplied by the sample rate in the
    # endpoint); TODO confirm for aligned segments.
    start: float
    # Segment end, same unit as `start`.
    end: float
    # Diarization label (e.g. SPEAKER_01), or the enrolled speaker's name when
    # auto-identification matched a profile.
    speaker: str
    # Transcribed text for this span.
    text: str
# Response body returned by the transcribe endpoint in this module.
class DiarizedTranscriptionResponse(BaseModel):
    # Transcript segments after alignment with the diarization output.
    segments: List[TranscriptionSegment]
    # Number of distinct `speaker` values found across `segments`.
    speakers_count: int
    # Rounded to 2 decimals. Largest `end` among the diarization segments;
    # falls back to the transcription segments, then to 0.0.
    audio_duration: float
    # Language code echoed back from the request's `language` form field.
    language: str
    # All non-empty segment texts joined with single spaces, then stripped.
    full_text: str
    # Second value returned by `diarize()`; its possible values are defined in
    # services.diarization, not in this module.
    diarization_method: str
# ── Endpoint ──────────────────────────────────────────────────────────────────
@router.post(
    "/transcribe",
    response_model=DiarizedTranscriptionResponse,
    summary="Transcription + Speaker Diarization",
    description=(
        "Transcribes the audio using the Whisper.cpp GPU server and adds speaker diarization. "
        "Returns segments labelled by speaker (SPEAKER_01, SPEAKER_02, …)."
    ),
)
async def v2_transcribe(
    audio: UploadFile = File(..., description="Audio file (WAV, MP3, M4A, OGG, FLAC)"),
    language: str = Form("pt", description="Language code (default: pt)"),
    num_speakers: Optional[int] = Form(None, description="Expected number of speakers (optional, auto-detected if omitted)"),
):
    """Transcribe an uploaded audio file and label each segment with a speaker.

    Pipeline:
      1. Send the raw bytes to the Whisper server for timestamped segments.
      2. Spill the bytes to a temporary file (the diarizer takes a path).
      3. Run diarization on that file.
      4. Align transcription segments with diarization segments.
      5. Best-effort: replace diarization labels with enrolled speaker names.
      6. Build the response metadata.

    Args:
        audio: The uploaded audio file.
        language: Language code forwarded to Whisper and echoed in the response.
        num_speakers: Optional speaker-count hint forwarded to ``diarize``.

    Returns:
        DiarizedTranscriptionResponse with the aligned segments and metadata.

    Raises:
        HTTPException: 400 when the upload is empty; any HTTPException raised
            by ``_call_whisper_segments`` propagates unchanged.
    """
    audio_bytes = await audio.read()
    if not audio_bytes:
        # Nothing to transcribe or diarize; fail fast instead of forwarding an
        # empty payload to the Whisper server.
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
    filename = audio.filename or "audio.wav"

    # Step 1 — call original Whisper server for timestamps
    transcription_segments = await _call_whisper_segments(audio_bytes, filename, language)

    # Step 2 — save to temp file for diarization
    suffix = os.path.splitext(filename)[-1].lower() or ".wav"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            # Record the path before writing so the `finally` below removes the
            # file even when the write itself fails (e.g. disk full).
            tmp_path = tmp.name
            tmp.write(audio_bytes)

        # NOTE(review): diarize() and the auto-identification step below are
        # synchronous and run on the event loop. Offloading them to a worker
        # thread would keep the server responsive, but only if the diarization
        # services are safe to run concurrently — confirm before changing.

        # Step 3 — diarize
        diarization_segs, method = diarize(tmp_path, num_speakers=num_speakers)

        # Step 4 — align
        aligned = align_transcription_with_diarization(transcription_segments, diarization_segs)

        # Auto-identify speakers against enrolled profiles (never fatal)
        _auto_identify_speakers(tmp_path, diarization_segs, aligned)

        # Step 5 — compute metadata
        speakers = {seg["speaker"] for seg in aligned}
        full_text = " ".join(seg["text"] for seg in aligned if seg.get("text"))

        # Audio duration from diarization or transcription
        if diarization_segs:
            audio_duration = max(seg["end"] for seg in diarization_segs)
        elif transcription_segments:
            audio_duration = max(seg.get("end", 0) for seg in transcription_segments)
        else:
            audio_duration = 0.0

        return DiarizedTranscriptionResponse(
            segments=[TranscriptionSegment(**seg) for seg in aligned],
            speakers_count=len(speakers),
            audio_duration=round(audio_duration, 2),
            language=language,
            full_text=full_text.strip(),
            diarization_method=method,
        )
    finally:
        _remove_quietly(tmp_path)


# ── Helpers: speaker auto-identification and temp-file cleanup ────────────────
def _auto_identify_speakers(audio_path: str, diarization_segs: list, aligned: list) -> None:
    """Best-effort: rename diarization labels in ``aligned`` to enrolled speaker names.

    For every speaker label present in ``aligned``, the diarization segments of
    that speaker longer than 0.5 s are concatenated; if at least 3 s of audio
    result, an embedding is extracted and matched against enrolled profiles.
    ``aligned`` is modified in place. Any failure is logged as a warning and
    never propagates; renames collected before the failure are still applied.
    """
    renames = {}
    try:
        # Imported inside the guard so that a missing optional dependency only
        # disables auto-identification instead of failing the whole request.
        import librosa
        import numpy as np
        import soundfile as sf
        from services.enrollment import identify_speaker

        samples, sr = librosa.load(audio_path, sr=16000, mono=True)
        total_samples = len(samples)
        min_chunk_samples = int(0.5 * sr)

        for label in sorted({seg["speaker"] for seg in aligned}):
            chunks = []
            for dseg in diarization_segs:
                if dseg["speaker"] != label:
                    continue
                first = int(dseg["start"] * sr)
                last = int(dseg["end"] * sr)
                if last > first + min_chunk_samples:
                    chunks.append(samples[first:min(last, total_samples)])
            if not chunks:
                continue

            speaker_audio = np.concatenate(chunks)
            if len(speaker_audio) < sr * 3:
                continue

            spk_path = None
            try:
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as spk_tmp:
                    # Record the path first so cleanup runs even if the write
                    # below raises.
                    spk_path = spk_tmp.name
                # The handle is closed before soundfile reopens the path by
                # name (an open NamedTemporaryFile cannot be reopened on every
                # platform).
                sf.write(spk_path, speaker_audio, sr)
                id_result = identify_speaker(extract_embedding(spk_path))
                if id_result["matched"]:
                    identified_name = id_result["speaker_name"]
                    logger.info(
                        "Auto-identified %s as '%s' (conf=%.2f)",
                        label,
                        identified_name,
                        id_result["confidence"],
                    )
                    renames[label] = identified_name
            finally:
                _remove_quietly(spk_path)
    except Exception as auto_id_err:
        logger.warning("Auto-identification failed (non-fatal): %s", auto_id_err)

    # Apply all renames in one pass, keyed on the ORIGINAL label, so a profile
    # whose name equals another diarization label cannot be renamed twice.
    if renames:
        for seg in aligned:
            seg["speaker"] = renames.get(seg["speaker"], seg["speaker"])


def _remove_quietly(path: Optional[str]) -> None:
    """Delete ``path`` if set; a failed cleanup is logged, never raised."""
    if not path:
        return
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass  # already gone — nothing to clean up
    except OSError as err:
        logger.warning("Could not remove temp file %s: %s", path, err)
# ── Helper: call Whisper for transcription with timestamps ────────────────────
async def _call_whisper_segments(audio_bytes: bytes, filename: str, language: str) -> list:
    """Fetch timestamped transcription segments from the Whisper server.

    POSTs the audio as multipart form data to
    ``{WHISPER_ORIGIN}/transcribe-segments`` and returns the ``segments`` list
    from the JSON reply (``[{start, end, text}, ...]``).

    Args:
        audio_bytes: Raw bytes of the uploaded audio.
        filename: Name sent in the multipart ``file`` part.
        language: Language code sent as the ``language`` form field.

    Returns:
        The list of segments. When the reply carries text but no segments, a
        single segment spanning 0.0–0.0 with the whole text is returned; when
        it carries neither, an empty list.

    Raises:
        HTTPException: 503 when the server cannot be reached, 504 when the
            request times out, 502 when the server answers with an error
            status, 500 for any other failure (including an unparsable reply).
    """
    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            resp = await client.post(
                f"{WHISPER_ORIGIN}/transcribe-segments",
                files={"file": (filename, audio_bytes, "application/octet-stream")},
                data={"language": language},
            )
            resp.raise_for_status()
            result = resp.json()

            # Expected: {"text": "...", "language": "...", "segments": [...]}
            # `or []` also covers an explicit `"segments": null`, which
            # `.get("segments", [])` would hand back to the caller as None.
            segments = result.get("segments") or []
            if not segments and result.get("text"):
                # Fallback: no segments, create a single segment
                segments = [{"start": 0.0, "end": 0.0, "text": result["text"]}]
            return segments
    except httpx.ConnectError as e:
        logger.error("Cannot connect to Whisper server: %s", e)
        raise HTTPException(
            status_code=503,
            detail=f"Whisper server unavailable: {e}",
        ) from e
    except httpx.TimeoutException as e:
        # Covers connect/read/write/pool timeouts; none of these derive from
        # ConnectError, so they previously surfaced as a generic 500.
        logger.error("Whisper server timed out: %s", e)
        raise HTTPException(
            status_code=504,
            detail="Whisper server timed out",
        ) from e
    except httpx.HTTPStatusError as e:
        logger.error("Whisper server error %s: %s", e.response.status_code, e.response.text)
        raise HTTPException(
            status_code=502,
            detail=f"Whisper server returned {e.response.status_code}",
        ) from e
    except Exception as e:
        logger.error("Transcription error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=f"Transcription failed: {e}") from e