# Listing metadata: 193 lines · 7.4 KiB · Python
"""
POST /api/v2/transcribe — Transcription with speaker diarization.

Calls the original Whisper server for transcription (with timestamps),
then runs speaker diarization and aligns the results.
"""
|
|
import os
|
|
import logging
|
|
import tempfile
|
|
from typing import List, Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
from services.diarization import diarize, align_transcription_with_diarization, extract_embedding
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
# Router on which the /transcribe endpoint in this module is registered.
router = APIRouter()
|
|
|
|
# Base URL of the upstream Whisper server; overridable through the
# WHISPER_ORIGIN environment variable, otherwise the hard-coded default is used.
WHISPER_ORIGIN = os.getenv("WHISPER_ORIGIN", "http://10.100.16.13:5003")
|
|
|
|
|
|
# ── Response model ────────────────────────────────────────────────────────────
|
|
class TranscriptionSegment(BaseModel):
    # One aligned piece of the transcript, attributed to a single speaker.
    # NOTE: documented with comments rather than a docstring so the generated
    # schema for this model is left untouched.

    # Segment boundaries (presumably seconds from the start of the audio —
    # confirm against services.diarization).
    start: float
    end: float

    # Speaker label; the endpoint may overwrite the diarization label with an
    # enrolled speaker's name when auto-identification finds a match.
    speaker: str

    # Transcribed text for this segment.
    text: str
|
|
|
|
|
|
class DiarizedTranscriptionResponse(BaseModel):
    # Response body of the /transcribe endpoint.
    # NOTE: documented with comments rather than a docstring so the generated
    # schema for this model is left untouched.

    # Aligned, speaker-labelled segments.
    segments: List[TranscriptionSegment]

    # Number of distinct speaker labels found across `segments`.
    speakers_count: int

    # Largest segment end time, rounded to two decimals by the endpoint.
    audio_duration: float

    # Language code the request was made with (echoed back unchanged).
    language: str

    # Non-empty segment texts joined with single spaces.
    full_text: str

    # Method name reported by the diarization service.
    diarization_method: str
|
|
|
|
|
|
# ── Endpoint ──────────────────────────────────────────────────────────────────
|
|
@router.post(
    "/transcribe",
    response_model=DiarizedTranscriptionResponse,
    summary="Transcription + Speaker Diarization",
    description=(
        "Transcribes the audio using the Whisper.cpp GPU server and adds speaker diarization. "
        "Returns segments labelled by speaker (SPEAKER_01, SPEAKER_02, …)."
    ),
)
async def v2_transcribe(
    audio: UploadFile = File(..., description="Audio file (WAV, MP3, M4A, OGG, FLAC)"),
    language: str = Form("pt", description="Language code (default: pt)"),
    num_speakers: Optional[int] = Form(None, description="Expected number of speakers (optional, auto-detected if omitted)"),
):
    """Transcribe an uploaded audio file and label each segment by speaker.

    Pipeline:
      1. Send the audio to the Whisper server to obtain timestamped segments.
      2. Persist the audio to a temporary file for the diarization service.
      3. Run diarization and align it with the transcription segments.
      4. Best-effort: replace diarization labels with enrolled speaker names.
      5. Build the response metadata.

    Args:
        audio: Uploaded audio file.
        language: Language code forwarded to the Whisper server and echoed
            back in the response.
        num_speakers: Optional speaker-count hint passed to ``diarize``.

    Returns:
        DiarizedTranscriptionResponse with the aligned segments and metadata.

    Raises:
        HTTPException: 400 when the upload is empty; any HTTPException raised
            by ``_call_whisper_segments`` propagates unchanged.
    """
    audio_bytes = await audio.read()
    if not audio_bytes:
        # Reject early instead of forwarding an empty payload upstream.
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
    filename = audio.filename or "audio.wav"

    # Step 1 — call original Whisper server for timestamps
    transcription_segments = await _call_whisper_segments(audio_bytes, filename, language)

    # Step 2 — save to temp file for diarization
    suffix = os.path.splitext(filename)[-1].lower() or ".wav"
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            # Record the path before writing so a failed write is still
            # cleaned up by the ``finally`` below.
            tmp_path = tmp.name
            tmp.write(audio_bytes)

        # NOTE(review): the steps below are synchronous and run on the event
        # loop; consider offloading them to a worker thread once the
        # thread-safety of the diarization service has been confirmed.

        # Step 3 — diarize
        diarization_segs, method = diarize(tmp_path, num_speakers=num_speakers)

        # Step 4 — align
        aligned = align_transcription_with_diarization(transcription_segments, diarization_segs)

        # Auto-identify speakers against enrolled profiles. Any failure here
        # (including a missing optional dependency) is logged and ignored.
        try:
            _auto_identify_speakers(tmp_path, aligned, diarization_segs)
        except Exception as auto_id_err:
            logger.warning("Auto-identification failed (non-fatal): %s", auto_id_err)

        # Step 5 — compute metadata
        speakers = sorted({s["speaker"] for s in aligned})
        full_text = " ".join(s["text"] for s in aligned if s.get("text"))

        # Audio duration from diarization or transcription
        if diarization_segs:
            audio_duration = max(s["end"] for s in diarization_segs)
        elif transcription_segments:
            audio_duration = max(s.get("end", 0) for s in transcription_segments)
        else:
            audio_duration = 0.0

        return DiarizedTranscriptionResponse(
            segments=[TranscriptionSegment(**s) for s in aligned],
            speakers_count=len(speakers),
            audio_duration=round(audio_duration, 2),
            language=language,
            full_text=full_text.strip(),
            diarization_method=method,
        )
    finally:
        _remove_quietly(tmp_path)


# Tunables for speaker auto-identification.
_ID_SAMPLE_RATE = 16000      # sample rate requested from librosa.load
_MIN_CHUNK_SECONDS = 0.5     # diarization segments not longer than this are skipped
_MIN_SPEAKER_SECONDS = 3     # minimum pooled audio per speaker before identifying


def _remove_quietly(path):
    """Delete ``path`` if set; log (rather than raise) when deletion fails."""
    if not path:
        return
    try:
        os.unlink(path)
    except OSError as exc:
        logger.debug("Could not remove temporary file %s: %s", path, exc)


def _auto_identify_speakers(audio_path, aligned, diarization_segs):
    """Rename diarization labels in ``aligned`` to enrolled speaker names.

    For every speaker label present in ``aligned``, the matching diarization
    segments are cut out of the audio, pooled, written to a temporary WAV
    file, and passed through ``extract_embedding`` / ``identify_speaker``.
    When the result reports a match, every segment in ``aligned`` carrying
    that label is updated in place.

    Exceptions are not handled here; the caller treats them as non-fatal.
    """
    # Imported lazily so that a missing optional dependency only disables
    # auto-identification instead of breaking the whole module.
    import librosa
    import numpy as np
    import soundfile as sf
    from services.enrollment import identify_speaker

    y, sr = librosa.load(audio_path, sr=_ID_SAMPLE_RATE, mono=True)
    total_samples = len(y)
    min_chunk_samples = int(_MIN_CHUNK_SECONDS * sr)

    # Single pass over the diarization output, grouping usable chunks per
    # speaker while preserving their original order.
    chunks_by_speaker = {}
    for dseg in diarization_segs:
        # Clamp to the decoded audio *before* the length check so segments
        # that run past the end cannot contribute short or empty chunks.
        start_idx = max(0, int(dseg["start"] * sr))
        end_idx = min(int(dseg["end"] * sr), total_samples)
        if end_idx > start_idx + min_chunk_samples:
            chunks_by_speaker.setdefault(dseg["speaker"], []).append(y[start_idx:end_idx])

    for speaker_label in sorted({seg["speaker"] for seg in aligned}):
        chunks = chunks_by_speaker.get(speaker_label)
        if not chunks:
            continue

        speaker_audio = np.concatenate(chunks)
        if len(speaker_audio) < sr * _MIN_SPEAKER_SECONDS:
            continue

        spk_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as spk_tmp:
                # Capture the name first: if the write below fails the file
                # is still removed in ``finally``.
                spk_path = spk_tmp.name
            # Write after the handle is closed so the path can be reopened
            # on every platform.
            sf.write(spk_path, speaker_audio, sr)

            embedding = extract_embedding(spk_path)
            id_result = identify_speaker(embedding)
        finally:
            _remove_quietly(spk_path)

        if not id_result["matched"]:
            continue

        identified_name = id_result["speaker_name"]
        logger.info(
            "Auto-identified %s as '%s' (conf=%.2f)",
            speaker_label,
            identified_name,
            id_result["confidence"],
        )
        for seg in aligned:
            if seg["speaker"] == speaker_label:
                seg["speaker"] = identified_name
|
|
|
|
|
|
# ── Helper: call Whisper for transcription with timestamps ────────────────────
|
|
async def _call_whisper_segments(audio_bytes: bytes, filename: str, language: str) -> list:
    """Fetch timestamped segments from the Whisper server.

    Posts the audio to ``{WHISPER_ORIGIN}/transcribe-segments`` and returns
    the ``segments`` list from the JSON reply. When the reply has text but no
    segments, a single zero-length segment holding the whole text is returned.

    Args:
        audio_bytes: Raw bytes of the uploaded audio.
        filename: File name sent along with the multipart upload.
        language: Language code sent as the ``language`` form field.

    Returns:
        A list of segment dicts (expected shape: ``{start, end, text}``);
        empty when the server returned neither segments nor text.

    Raises:
        HTTPException: 503 when the server cannot be reached, 504 when the
            request times out, 502 when the server answers with an error
            status or an unusable body, 500 for any other failure.
    """
    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            resp = await client.post(
                f"{WHISPER_ORIGIN}/transcribe-segments",
                files={"file": (filename, audio_bytes, "application/octet-stream")},
                data={"language": language},
            )
            resp.raise_for_status()
    except httpx.ConnectError as e:
        logger.error("Cannot connect to Whisper server: %s", e)
        raise HTTPException(
            status_code=503,
            detail=f"Whisper server unavailable: {str(e)}",
        ) from e
    except httpx.TimeoutException as e:
        # Previously reported as a generic 500; a gateway timeout is the
        # accurate status for an upstream that did not answer in time.
        logger.error("Whisper server timed out: %s", e)
        raise HTTPException(
            status_code=504,
            detail="Whisper server timed out",
        ) from e
    except httpx.HTTPStatusError as e:
        logger.error("Whisper server error %s: %s", e.response.status_code, e.response.text)
        raise HTTPException(
            status_code=502,
            detail=f"Whisper server returned {e.response.status_code}",
        ) from e
    except Exception as e:
        logger.error("Transcription error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}") from e

    # Parsing happens outside the request ``try`` so the HTTPExceptions raised
    # below are not re-wrapped by the catch-all handler above.
    try:
        result = resp.json()
    except ValueError as e:
        logger.error("Whisper server returned a non-JSON body: %s", e)
        raise HTTPException(
            status_code=502,
            detail="Whisper server returned an invalid response",
        ) from e

    # Expected: {"text": "...", "language": "...", "segments": [...]}
    if not isinstance(result, dict):
        logger.error("Unexpected Whisper response type: %s", type(result).__name__)
        raise HTTPException(
            status_code=502,
            detail="Whisper server returned an invalid response",
        )

    # ``or []`` also covers an explicit ``"segments": null``, which
    # ``.get("segments", [])`` would have passed through as None.
    segments = result.get("segments") or []
    if not segments and result.get("text"):
        # Fallback: no segments, create a single segment
        segments = [{"start": 0.0, "end": 0.0, "text": result["text"]}]
    return segments
|