""" POST /api/v2/transcribe — Transcription with speaker diarization. Calls the original Whisper server for transcription (with timestamps), then runs speaker diarization and aligns the results. """ import os import logging import tempfile from typing import List, Optional import httpx from fastapi import APIRouter, UploadFile, File, Form, HTTPException from pydantic import BaseModel from services.diarization import diarize, align_transcription_with_diarization, extract_embedding logger = logging.getLogger(__name__) router = APIRouter() WHISPER_ORIGIN = os.getenv("WHISPER_ORIGIN", "http://10.100.16.13:5003") # ── Response model ──────────────────────────────────────────────────────────── class TranscriptionSegment(BaseModel): start: float end: float speaker: str text: str class DiarizedTranscriptionResponse(BaseModel): segments: List[TranscriptionSegment] speakers_count: int audio_duration: float language: str full_text: str diarization_method: str # ── Endpoint ────────────────────────────────────────────────────────────────── @router.post( "/transcribe", response_model=DiarizedTranscriptionResponse, summary="Transcription + Speaker Diarization", description=( "Transcribes the audio using the Whisper.cpp GPU server and adds speaker diarization. " "Returns segments labelled by speaker (SPEAKER_01, SPEAKER_02, …)." ), ) async def v2_transcribe( audio: UploadFile = File(..., description="Audio file (WAV, MP3, M4A, OGG, FLAC)"), language: str = Form("pt", description="Language code (default: pt)"), num_speakers: Optional[int] = Form(None, description="Expected number of speakers (optional, auto-detected if omitted)"), ): audio_bytes = await audio.read() filename = audio.filename or "audio.wav" # Step 1 — call original Whisper server for timestamps transcription_segments = await _call_whisper_segments(audio_bytes, filename, language) # Step 2 — save to temp file for diarization suffix = os.path.splitext(filename)[-1].lower() or ".wav" with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: tmp.write(audio_bytes) tmp_path = tmp.name try: # Step 3 — diarize diarization_segs, method = diarize(tmp_path, num_speakers=num_speakers) # Step 4 — align aligned = align_transcription_with_diarization(transcription_segments, diarization_segs) # Auto-identify speakers against enrolled profiles try: import librosa import numpy as np import soundfile as sf from services.enrollment import identify_speaker y, sr = librosa.load(tmp_path, sr=16000, mono=True) speaker_labels = sorted({s["speaker"] for s in aligned}) for speaker_label in speaker_labels: chunks = [] for dseg in diarization_segs: if dseg["speaker"] == speaker_label: s_i = int(dseg["start"] * sr) e_i = int(dseg["end"] * sr) if e_i > s_i + int(0.5 * sr): chunks.append(y[s_i:min(e_i, len(y))]) if not chunks: continue speaker_audio = np.concatenate(chunks) if len(speaker_audio) < sr * 3: continue spk_path = None try: with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as spk_tmp: sf.write(spk_tmp.name, speaker_audio, sr) spk_path = spk_tmp.name embedding = extract_embedding(spk_path) id_result = identify_speaker(embedding) if id_result["matched"]: identified_name = id_result["speaker_name"] logger.info(f"Auto-identified {speaker_label} as '{identified_name}' (conf={id_result['confidence']:.2f})") for seg in aligned: if seg["speaker"] == speaker_label: seg["speaker"] = identified_name finally: if spk_path: try: os.unlink(spk_path) except Exception: pass except Exception as auto_id_err: logger.warning(f"Auto-identification failed (non-fatal): {auto_id_err}") # Step 5 — compute metadata speakers = sorted({s["speaker"] for s in aligned}) full_text = " ".join(s["text"] for s in aligned if s.get("text")) # Audio duration from diarization or transcription if diarization_segs: audio_duration = max(s["end"] for s in diarization_segs) elif transcription_segments: audio_duration = max(s.get("end", 0) for s in transcription_segments) else: audio_duration = 0.0 return DiarizedTranscriptionResponse( segments=[TranscriptionSegment(**s) for s in aligned], speakers_count=len(speakers), audio_duration=round(audio_duration, 2), language=language, full_text=full_text.strip(), diarization_method=method, ) finally: try: os.unlink(tmp_path) except Exception: pass # ── Helper: call Whisper for transcription with timestamps ──────────────────── async def _call_whisper_segments(audio_bytes: bytes, filename: str, language: str) -> list: """ Calls the original Whisper server at /transcribe-segments and returns the segments list [{start, end, text}, ...]. """ try: async with httpx.AsyncClient(timeout=600.0) as client: files = {"file": (filename, audio_bytes, "application/octet-stream")} data = {"language": language} resp = await client.post( f"{WHISPER_ORIGIN}/transcribe-segments", files=files, data=data, ) resp.raise_for_status() result = resp.json() # Expected: {"text": "...", "language": "...", "segments": [...]} segments = result.get("segments", []) if not segments and result.get("text"): # Fallback: no segments, create a single segment segments = [{"start": 0.0, "end": 0.0, "text": result["text"]}] return segments except httpx.ConnectError as e: logger.error(f"Cannot connect to Whisper server: {e}") raise HTTPException( status_code=503, detail=f"Whisper server unavailable: {str(e)}", ) except httpx.HTTPStatusError as e: logger.error(f"Whisper server error {e.response.status_code}: {e.response.text}") raise HTTPException( status_code=502, detail=f"Whisper server returned {e.response.status_code}", ) except Exception as e: logger.error(f"Transcription error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")