""" Voice Enrollment & Speaker Identification endpoints. POST /api/v2/enroll — enroll a new voice profile POST /api/v2/enroll/from-meeting — enroll from meeting segments GET /api/v2/speakers — list all profiles GET /api/v2/speakers/{id} — get profile details PUT /api/v2/speakers/{id} — update profile DELETE /api/v2/speakers/{id} — delete profile POST /api/v2/identify — identify a speaker from audio """ import os import json import logging import tempfile from typing import List, Optional from fastapi import APIRouter, UploadFile, File, Form, HTTPException from pydantic import BaseModel from services.enrollment import ( create_profile, get_profile, list_profiles, update_profile, delete_profile, add_embedding, identify_speaker, ) from services.diarization import extract_embedding logger = logging.getLogger(__name__) router = APIRouter() # ── Schemas ─────────────────────────────────────────────────────────────────── class VoiceProfileResponse(BaseModel): id: str name: str email: Optional[str] = None metadata: Optional[str] = None embeddings_count: int created_at: str updated_at: str class UpdateProfileRequest(BaseModel): name: Optional[str] = None email: Optional[str] = None metadata: Optional[str] = None class IdentifyResponse(BaseModel): matched: bool speaker_id: Optional[str] = None speaker_name: Optional[str] = None confidence: float threshold: float # ── Helper ──────────────────────────────────────────────────────────────────── async def _save_and_embed(file_bytes: bytes, suffix: str) -> object: """Save audio bytes to a temp file and extract embedding.""" with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: tmp.write(file_bytes) tmp_path = tmp.name try: embedding = extract_embedding(tmp_path) return embedding finally: try: os.unlink(tmp_path) except Exception: pass # ── POST /enroll ────────────────────────────────────────────────────────────── @router.post( "/enroll", response_model=VoiceProfileResponse, status_code=201, summary="Enroll a new voice profile", description=( "Submit an audio file (minimum 30 seconds recommended) to enroll a new speaker. " "Extracts a vocal embedding and saves the profile to the local SQLite database." ), ) async def enroll( file: UploadFile = File(..., description="Audio file with the speaker's voice (min 30s recommended)"), name: str = Form(..., description="Speaker full name"), email: Optional[str] = Form(None, description="Speaker email (optional)"), metadata: Optional[str] = Form(None, description="Extra JSON metadata (optional)"), ): audio_bytes = await file.read() suffix = os.path.splitext(file.filename or ".wav")[-1].lower() or ".wav" embedding = await _save_and_embed(audio_bytes, suffix) profile = create_profile(name=name, email=email, metadata=metadata) add_embedding(profile["id"], embedding) profile["embeddings_count"] = 1 return VoiceProfileResponse(**profile) # ── POST /enroll/from-meeting ───────────────────────────────────────────────── @router.post( "/enroll/from-meeting", response_model=VoiceProfileResponse, status_code=201, summary="Enroll speaker from meeting audio segments", description=( "Extract voice embedding from specific speaker segments of an already processed meeting. " "Creates a new profile or adds embeddings to an existing one (matched by name)." ), ) async def enroll_from_meeting( meeting_segments: str = Form( ..., description='JSON array of {speaker_label, audio_file_path}: [{"speaker_label":"SPEAKER_01","audio_file_path":"/path/audio.wav"}]', ), speaker_name: str = Form(..., description="Name to assign to this speaker"), speaker_label: str = Form(..., description="Speaker label to extract (e.g. SPEAKER_01)"), existing_profile_id: Optional[str] = Form(None, description="Existing profile ID to add embeddings to (optional)"), ): try: segments = json.loads(meeting_segments) except (json.JSONDecodeError, ValueError) as e: raise HTTPException(status_code=422, detail=f"Invalid meeting_segments JSON: {e}") # Filter segments matching the requested speaker target_segments = [s for s in segments if s.get("speaker_label") == speaker_label] if not target_segments: raise HTTPException( status_code=404, detail=f"No segments found for speaker_label='{speaker_label}'", ) # Get or create profile if existing_profile_id: profile = get_profile(existing_profile_id) if profile is None: raise HTTPException(status_code=404, detail=f"Profile '{existing_profile_id}' not found") else: profile = create_profile(name=speaker_name, email=None, metadata=None) # Extract embeddings for each segment emb_count = 0 for seg in target_segments: audio_path = seg.get("audio_file_path", "") if not os.path.isfile(audio_path): logger.warning(f"Audio file not found: {audio_path}") continue try: embedding = extract_embedding(audio_path) add_embedding(profile["id"], embedding) emb_count += 1 except Exception as e: logger.warning(f"Failed to extract embedding from {audio_path}: {e}") if emb_count == 0: raise HTTPException( status_code=422, detail="No valid audio segments could be processed. Check that audio_file_path values exist on the server.", ) return VoiceProfileResponse(**get_profile(profile["id"])) # ── GET /speakers ───────────────────────────────────────────────────────────── @router.get( "/speakers", response_model=List[VoiceProfileResponse], summary="List all enrolled voice profiles", ) async def list_speakers(): return [VoiceProfileResponse(**p) for p in list_profiles()] # ── GET /speakers/{id} ──────────────────────────────────────────────────────── @router.get( "/speakers/{profile_id}", response_model=VoiceProfileResponse, summary="Get a voice profile by ID", ) async def get_speaker(profile_id: str): profile = get_profile(profile_id) if profile is None: raise HTTPException(status_code=404, detail="Voice profile not found") return VoiceProfileResponse(**profile) # ── PUT /speakers/{id} ──────────────────────────────────────────────────────── @router.put( "/speakers/{profile_id}", response_model=VoiceProfileResponse, summary="Update a voice profile (name, email, metadata — all optional)", ) async def update_speaker(profile_id: str, data: UpdateProfileRequest): profile = update_profile( profile_id, name=data.name, email=data.email, metadata=data.metadata, ) if profile is None: raise HTTPException(status_code=404, detail="Voice profile not found") return VoiceProfileResponse(**profile) # ── DELETE /speakers/{id} ───────────────────────────────────────────────────── @router.delete( "/speakers/{profile_id}", status_code=204, summary="Delete a voice profile and all its embeddings", ) async def delete_speaker(profile_id: str): if not delete_profile(profile_id): raise HTTPException(status_code=404, detail="Voice profile not found") # ── POST /identify ──────────────────────────────────────────────────────────── @router.post( "/identify", response_model=IdentifyResponse, summary="Identify the speaker in an audio clip", description=( "Extracts a vocal embedding from the submitted audio and compares it against all " "enrolled profiles. Returns the best match if confidence >= threshold." ), ) async def identify( file: UploadFile = File(..., description="Short audio clip to identify the speaker"), threshold: Optional[float] = Form(None, description="Confidence threshold (default: 0.75)"), ): audio_bytes = await file.read() suffix = os.path.splitext(file.filename or ".wav")[-1].lower() or ".wav" embedding = await _save_and_embed(audio_bytes, suffix) result = identify_speaker(embedding, threshold=threshold) return IdentifyResponse(**result)