# Maia/diarization/routers/v2_speakers.py
#
# 242 lines
# 9.4 KiB
# Python

"""
Voice Enrollment & Speaker Identification endpoints.
POST /api/v2/enroll — enroll a new voice profile
POST /api/v2/enroll/from-meeting — enroll from meeting segments
GET /api/v2/speakers — list all profiles
GET /api/v2/speakers/{id} — get profile details
PUT /api/v2/speakers/{id} — update profile
DELETE /api/v2/speakers/{id} — delete profile
POST /api/v2/identify — identify a speaker from audio
"""
import os
import json
import logging
import tempfile
from typing import List, Optional
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from pydantic import BaseModel
from services.enrollment import (
create_profile,
get_profile,
list_profiles,
update_profile,
delete_profile,
add_embedding,
identify_speaker,
)
from services.diarization import extract_embedding
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which every endpoint below is registered through its decorators.
router = APIRouter()
# ── Schemas ───────────────────────────────────────────────────────────────────
# Response body describing one stored voice profile.
# Documented with comments rather than a class docstring so that the
# generated schema for this model is left exactly as it was.
class VoiceProfileResponse(BaseModel):
    # Profile identifier and display name (both required).
    id: str
    name: str

    # Optional contact / free-form fields; omitted values default to None.
    email: Optional[str] = None
    metadata: Optional[str] = None

    # Number of embeddings attached to the profile.
    embeddings_count: int

    # Creation / last-update timestamps, carried as strings.
    created_at: str
    updated_at: str
# Request body for updating a profile. Every field is optional and
# defaults to None when the client leaves it out.
# Documented with comments rather than a class docstring so that the
# generated schema for this model is left exactly as it was.
class UpdateProfileRequest(BaseModel):
    name: Optional[str] = None
    email: Optional[str] = None
    metadata: Optional[str] = None
# Response body of the speaker identification endpoint.
# Documented with comments rather than a class docstring so that the
# generated schema for this model is left exactly as it was.
class IdentifyResponse(BaseModel):
    # Whether a profile was matched.
    matched: bool

    # Matched profile's id and name; both default to None.
    speaker_id: Optional[str] = None
    speaker_name: Optional[str] = None

    # Score of the comparison and the threshold reported alongside it.
    confidence: float
    threshold: float
# ── Helper ────────────────────────────────────────────────────────────────────
async def _save_and_embed(file_bytes: bytes, suffix: str) -> object:
    """Write audio bytes to a temporary file and extract an embedding from it.

    ``extract_embedding`` is called with a file path, so the bytes are first
    written to a named temporary file. That file is removed afterwards whether
    writing or extraction succeeded or failed.

    Args:
        file_bytes: Raw audio content to embed.
        suffix: Suffix (file extension, including the dot) for the temp file.

    Returns:
        Whatever ``extract_embedding`` returns for the temporary file.

    Raises:
        Any exception raised while writing the file or by
        ``extract_embedding``; cleanup errors are logged, never raised.
    """
    tmp_path = None
    try:
        # delete=False: the file is closed before extraction and removed
        # explicitly in the ``finally`` clause below.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            # Record the path before writing so a failed write is cleaned up too.
            tmp_path = tmp.name
            tmp.write(file_bytes)
        return extract_embedding(tmp_path)
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError as exc:
                # Best-effort cleanup: report the leftover file, do not mask
                # the real result or exception of the call.
                logger.warning("Could not remove temp file %s: %s", tmp_path, exc)
# ── POST /enroll ──────────────────────────────────────────────────────────────
@router.post(
    "/enroll",
    response_model=VoiceProfileResponse,
    status_code=201,
    summary="Enroll a new voice profile",
    description=(
        "Submit an audio file (minimum 30 seconds recommended) to enroll a new speaker. "
        "Extracts a vocal embedding and saves the profile to the local SQLite database."
    ),
)
async def enroll(
    file: UploadFile = File(..., description="Audio file with the speaker's voice (min 30s recommended)"),
    name: str = Form(..., description="Speaker full name"),
    email: Optional[str] = Form(None, description="Speaker email (optional)"),
    metadata: Optional[str] = Form(None, description="Extra JSON metadata (optional)"),
):
    """Create a voice profile from an uploaded audio file.

    The embedding is extracted before the profile is created, so a failed
    extraction does not leave a profile behind.

    Args:
        file: Uploaded audio containing the speaker's voice.
        name: Name stored on the new profile.
        email: Optional email stored on the new profile.
        metadata: Optional metadata string stored on the new profile.

    Returns:
        The new profile with ``embeddings_count`` set to 1.

    Raises:
        HTTPException: 422 when the uploaded file contains no data.
    """
    audio_bytes = await file.read()
    if not audio_bytes:
        # Reject early with a client error instead of letting the embedding
        # extractor fail on an empty file.
        raise HTTPException(status_code=422, detail="Uploaded audio file is empty")

    # Reuse the upload's extension for the temp file; fall back to ".wav"
    # when the filename is missing or has no extension.
    suffix = os.path.splitext(file.filename or ".wav")[-1].lower() or ".wav"
    embedding = await _save_and_embed(audio_bytes, suffix)

    profile = create_profile(name=name, email=email, metadata=metadata)
    add_embedding(profile["id"], embedding)
    # Exactly one embedding was just attached to the new profile.
    profile["embeddings_count"] = 1
    return VoiceProfileResponse(**profile)
# ── POST /enroll/from-meeting ─────────────────────────────────────────────────
@router.post(
    "/enroll/from-meeting",
    response_model=VoiceProfileResponse,
    status_code=201,
    summary="Enroll speaker from meeting audio segments",
    description=(
        "Extract voice embedding from specific speaker segments of an already processed meeting. "
        "Creates a new profile or adds embeddings to an existing one (matched by name)."
    ),
)
async def enroll_from_meeting(
    meeting_segments: str = Form(
        ...,
        description='JSON array of {speaker_label, audio_file_path}: [{"speaker_label":"SPEAKER_01","audio_file_path":"/path/audio.wav"}]',
    ),
    speaker_name: str = Form(..., description="Name to assign to this speaker"),
    speaker_label: str = Form(..., description="Speaker label to extract (e.g. SPEAKER_01)"),
    existing_profile_id: Optional[str] = Form(None, description="Existing profile ID to add embeddings to (optional)"),
):
    """Enroll a speaker from audio segments of a processed meeting.

    Segments whose ``speaker_label`` equals the requested label are embedded
    one by one and attached to a profile. When ``existing_profile_id`` is
    given that profile is used; otherwise a new profile is created, but only
    once the first embedding has been extracted successfully, so a request
    with no usable audio does not leave an empty profile behind.

    Args:
        meeting_segments: JSON array of objects with ``speaker_label`` and
            ``audio_file_path`` keys.
        speaker_name: Name given to a newly created profile.
        speaker_label: Label selecting which segments to use.
        existing_profile_id: Optional id of a profile to add embeddings to.

    Returns:
        The profile as returned by ``get_profile`` after the embeddings were added.

    Raises:
        HTTPException: 422 when ``meeting_segments`` is not valid JSON or not
            a JSON array, or when no segment could be processed; 404 when no
            segment carries ``speaker_label`` or the existing profile is missing.
    """
    try:
        segments = json.loads(meeting_segments)
    except (json.JSONDecodeError, ValueError) as e:
        raise HTTPException(status_code=422, detail=f"Invalid meeting_segments JSON: {e}")
    if not isinstance(segments, list):
        # Valid JSON that is not an array (object, number, ...) would
        # otherwise crash below with an unhandled error.
        raise HTTPException(status_code=422, detail="meeting_segments must be a JSON array")

    # Filter segments matching the requested speaker; entries that are not
    # JSON objects cannot carry a label and are ignored.
    target_segments = [
        s for s in segments
        if isinstance(s, dict) and s.get("speaker_label") == speaker_label
    ]
    if not target_segments:
        raise HTTPException(
            status_code=404,
            detail=f"No segments found for speaker_label='{speaker_label}'",
        )

    # Resolve an explicitly requested profile up front so an unknown id fails
    # before any audio is processed.
    profile = None
    if existing_profile_id:
        profile = get_profile(existing_profile_id)
        if profile is None:
            raise HTTPException(status_code=404, detail=f"Profile '{existing_profile_id}' not found")

    # NOTE(review): audio_file_path is supplied by the client and is read from
    # the server's filesystem as-is — consider restricting it to a known
    # directory.
    emb_count = 0
    for seg in target_segments:
        audio_path = seg.get("audio_file_path", "")
        if not isinstance(audio_path, str) or not os.path.isfile(audio_path):
            logger.warning("Audio file not found: %s", audio_path)
            continue
        try:
            embedding = extract_embedding(audio_path)
        except Exception as e:
            # Best effort: one unreadable segment must not abort the others.
            logger.warning("Failed to extract embedding from %s: %s", audio_path, e)
            continue
        if profile is None:
            # Create the profile lazily, now that there is something to store.
            profile = create_profile(name=speaker_name, email=None, metadata=None)
        try:
            add_embedding(profile["id"], embedding)
        except Exception as e:
            logger.warning("Failed to store embedding from %s: %s", audio_path, e)
            continue
        emb_count += 1

    if emb_count == 0:
        raise HTTPException(
            status_code=422,
            detail="No valid audio segments could be processed. Check that audio_file_path values exist on the server.",
        )
    # Re-read the profile so the response reflects the stored state.
    return VoiceProfileResponse(**get_profile(profile["id"]))
# ── GET /speakers ─────────────────────────────────────────────────────────────
@router.get(
    "/speakers",
    response_model=List[VoiceProfileResponse],
    summary="List all enrolled voice profiles",
)
async def list_speakers():
    # Wrap every record returned by list_profiles() in the response schema.
    # (Comment instead of docstring: the route declares no description, and a
    # docstring would be picked up as one.)
    responses = []
    for record in list_profiles():
        responses.append(VoiceProfileResponse(**record))
    return responses
# ── GET /speakers/{id} ────────────────────────────────────────────────────────
@router.get(
    "/speakers/{profile_id}",
    response_model=VoiceProfileResponse,
    summary="Get a voice profile by ID",
)
async def get_speaker(profile_id: str):
    # Look the profile up by id; respond 404 when get_profile() yields None.
    # (Comment instead of docstring: the route declares no description, and a
    # docstring would be picked up as one.)
    record = get_profile(profile_id)
    if record is not None:
        return VoiceProfileResponse(**record)
    raise HTTPException(status_code=404, detail="Voice profile not found")
# ── PUT /speakers/{id} ────────────────────────────────────────────────────────
@router.put(
    "/speakers/{profile_id}",
    response_model=VoiceProfileResponse,
    summary="Update a voice profile (name, email, metadata — all optional)",
)
async def update_speaker(profile_id: str, data: UpdateProfileRequest):
    # Forward the three request fields to update_profile(); a None result
    # is reported as 404.
    # (Comment instead of docstring: the route declares no description, and a
    # docstring would be picked up as one.)
    changes = {
        "name": data.name,
        "email": data.email,
        "metadata": data.metadata,
    }
    updated = update_profile(profile_id, **changes)
    if updated is not None:
        return VoiceProfileResponse(**updated)
    raise HTTPException(status_code=404, detail="Voice profile not found")
# ── DELETE /speakers/{id} ─────────────────────────────────────────────────────
@router.delete(
    "/speakers/{profile_id}",
    status_code=204,
    summary="Delete a voice profile and all its embeddings",
)
async def delete_speaker(profile_id: str):
    # A falsy result from delete_profile() is reported as 404; otherwise the
    # handler returns nothing and the declared 204 status applies.
    # (Comment instead of docstring: the route declares no description, and a
    # docstring would be picked up as one.)
    removed = delete_profile(profile_id)
    if removed:
        return None
    raise HTTPException(status_code=404, detail="Voice profile not found")
# ── POST /identify ────────────────────────────────────────────────────────────
@router.post(
    "/identify",
    response_model=IdentifyResponse,
    summary="Identify the speaker in an audio clip",
    description=(
        "Extracts a vocal embedding from the submitted audio and compares it against all "
        "enrolled profiles. Returns the best match if confidence >= threshold."
    ),
)
async def identify(
    file: UploadFile = File(..., description="Short audio clip to identify the speaker"),
    threshold: Optional[float] = Form(None, description="Confidence threshold (default: 0.75)"),
):
    """Identify the speaker of an uploaded audio clip.

    Args:
        file: Uploaded audio clip.
        threshold: Optional threshold passed through to ``identify_speaker``;
            ``None`` is forwarded unchanged when the client omits it.

    Returns:
        The result of ``identify_speaker`` wrapped in ``IdentifyResponse``.

    Raises:
        HTTPException: 422 when the uploaded file contains no data.
    """
    audio_bytes = await file.read()
    if not audio_bytes:
        # Reject early with a client error instead of letting the embedding
        # extractor fail on an empty file.
        raise HTTPException(status_code=422, detail="Uploaded audio file is empty")

    # Reuse the upload's extension for the temp file; fall back to ".wav"
    # when the filename is missing or has no extension.
    suffix = os.path.splitext(file.filename or ".wav")[-1].lower() or ".wav"
    embedding = await _save_and_embed(audio_bytes, suffix)
    result = identify_speaker(embedding, threshold=threshold)
    return IdentifyResponse(**result)