242 lines
9.4 KiB
Python
242 lines
9.4 KiB
Python
"""
Voice Enrollment & Speaker Identification endpoints.

POST   /api/v2/enroll               — enroll a new voice profile
POST   /api/v2/enroll/from-meeting  — enroll from meeting segments
GET    /api/v2/speakers             — list all profiles
GET    /api/v2/speakers/{id}        — get profile details
PUT    /api/v2/speakers/{id}        — update profile
DELETE /api/v2/speakers/{id}        — delete profile
POST   /api/v2/identify             — identify a speaker from audio
"""
|
|
import os
|
|
import json
|
|
import logging
|
|
import tempfile
|
|
from typing import List, Optional
|
|
|
|
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
from services.enrollment import (
|
|
create_profile,
|
|
get_profile,
|
|
list_profiles,
|
|
update_profile,
|
|
delete_profile,
|
|
add_embedding,
|
|
identify_speaker,
|
|
)
|
|
from services.diarization import extract_embedding
|
|
|
|
logger = logging.getLogger(__name__)
|
|
router = APIRouter()
|
|
|
|
|
|
# ── Schemas ───────────────────────────────────────────────────────────────────
|
|
class VoiceProfileResponse(BaseModel):
    """Response schema for a single enrolled voice profile."""

    # Profile identifier and the speaker's display name (both required).
    id: str
    name: str

    # Optional descriptive fields; each defaults to None when not provided.
    email: Optional[str] = None
    metadata: Optional[str] = None

    # How many embeddings are attached to the profile (required).
    embeddings_count: int

    # Creation / last-update markers, carried as plain strings.
    created_at: str
    updated_at: str
|
|
|
|
|
|
class UpdateProfileRequest(BaseModel):
    """Request body for updating a voice profile.

    Every field is optional and defaults to None, so a client may send any
    subset of them.
    """

    name: Optional[str] = None
    email: Optional[str] = None
    metadata: Optional[str] = None
|
|
|
|
|
|
class IdentifyResponse(BaseModel):
    """Response schema for a speaker-identification attempt."""

    # Whether an enrolled profile was matched.
    matched: bool

    # Details of the matched speaker; both default to None.
    speaker_id: Optional[str] = None
    speaker_name: Optional[str] = None

    # Score of the comparison and the threshold it was evaluated against.
    confidence: float
    threshold: float
|
|
|
|
|
|
# ── Helper ────────────────────────────────────────────────────────────────────
|
|
async def _save_and_embed(file_bytes: bytes, suffix: str) -> object:
    """Write audio bytes to a temporary file and extract an embedding from it.

    Args:
        file_bytes: Raw audio content to persist.
        suffix: File extension (including the dot) given to the temp file.

    Returns:
        Whatever ``extract_embedding`` returns for the temporary file path.

    Raises:
        Any exception raised while writing the file or by
        ``extract_embedding``; the temporary file is removed in either case.
    """
    # delete=False so the file can be reopened by path after it is closed.
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp_path = tmp.name
    try:
        # The write lives inside the try so that a failed write (e.g. disk
        # full) still reaches the cleanup below instead of leaking the file.
        with tmp:
            tmp.write(file_bytes)
        return extract_embedding(tmp_path)
    finally:
        # Best-effort cleanup: never mask the primary outcome, but do not
        # hide the failure entirely either.
        try:
            os.unlink(tmp_path)
        except OSError as exc:
            logger.warning("Could not remove temp file %s: %s", tmp_path, exc)
|
|
|
|
|
|
# ── POST /enroll ──────────────────────────────────────────────────────────────
|
|
@router.post(
    "/enroll",
    response_model=VoiceProfileResponse,
    status_code=201,
    summary="Enroll a new voice profile",
    description=(
        "Submit an audio file (minimum 30 seconds recommended) to enroll a new speaker. "
        "Extracts a vocal embedding and saves the profile to the local SQLite database."
    ),
)
async def enroll(
    file: UploadFile = File(..., description="Audio file with the speaker's voice (min 30s recommended)"),
    name: str = Form(..., description="Speaker full name"),
    email: Optional[str] = Form(None, description="Speaker email (optional)"),
    metadata: Optional[str] = Form(None, description="Extra JSON metadata (optional)"),
):
    """Create a voice profile from an uploaded audio file.

    The embedding is extracted before the profile is created, so an
    extraction failure leaves no profile behind.

    Args:
        file: Uploaded audio file.
        name: Name stored on the new profile.
        email: Optional email stored on the new profile.
        metadata: Optional metadata string stored on the new profile.

    Returns:
        The created profile with ``embeddings_count`` set to 1.

    Raises:
        HTTPException: 422 when the uploaded file is empty.
    """
    audio_bytes = await file.read()
    if not audio_bytes:
        # Fail fast with a client error instead of handing an empty file to
        # the embedding extractor.
        raise HTTPException(status_code=422, detail="Uploaded audio file is empty")

    # Keep the upload's extension for the temp file; fall back to ".wav" when
    # the filename is missing or has no extension.
    suffix = os.path.splitext(file.filename or ".wav")[-1].lower() or ".wav"

    embedding = await _save_and_embed(audio_bytes, suffix)

    profile = create_profile(name=name, email=email, metadata=metadata)
    try:
        add_embedding(profile["id"], embedding)
    except Exception:
        # Roll back so a failed store does not leave a profile with no
        # embeddings; the original error is re-raised unchanged.
        delete_profile(profile["id"])
        raise

    profile["embeddings_count"] = 1
    return VoiceProfileResponse(**profile)
|
|
|
|
|
|
# ── POST /enroll/from-meeting ─────────────────────────────────────────────────
|
|
@router.post(
    "/enroll/from-meeting",
    response_model=VoiceProfileResponse,
    status_code=201,
    summary="Enroll speaker from meeting audio segments",
    description=(
        "Extract voice embedding from specific speaker segments of an already processed meeting. "
        "Creates a new profile or adds embeddings to an existing one (selected via existing_profile_id)."
    ),
)
async def enroll_from_meeting(
    meeting_segments: str = Form(
        ...,
        description='JSON array of {speaker_label, audio_file_path}: [{"speaker_label":"SPEAKER_01","audio_file_path":"/path/audio.wav"}]',
    ),
    speaker_name: str = Form(..., description="Name to assign to this speaker"),
    speaker_label: str = Form(..., description="Speaker label to extract (e.g. SPEAKER_01)"),
    existing_profile_id: Optional[str] = Form(None, description="Existing profile ID to add embeddings to (optional)"),
):
    """Enroll a speaker from audio segments of a processed meeting.

    Segments whose ``speaker_label`` equals the requested label are used.
    Embeddings are added to the profile identified by ``existing_profile_id``
    when given; otherwise a new profile named ``speaker_name`` is created.

    Args:
        meeting_segments: JSON-encoded array of objects with
            ``speaker_label`` and ``audio_file_path`` keys.
        speaker_name: Name for a newly created profile (unused when
            ``existing_profile_id`` is supplied).
        speaker_label: Label selecting which segments to process.
        existing_profile_id: Optional ID of a profile to extend.

    Returns:
        The profile as returned by ``get_profile`` after embeddings were added.

    Raises:
        HTTPException: 422 for malformed ``meeting_segments`` or when no
            segment could be processed; 404 when no segment matches the
            label or the existing profile does not exist.
    """
    try:
        segments = json.loads(meeting_segments)
    except ValueError as exc:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(
            status_code=422, detail=f"Invalid meeting_segments JSON: {exc}"
        ) from exc

    # Valid JSON is not necessarily the expected shape; reject anything that
    # is not a list of objects instead of crashing on `.get` below.
    if not isinstance(segments, list) or not all(isinstance(s, dict) for s in segments):
        raise HTTPException(
            status_code=422,
            detail="meeting_segments must be a JSON array of objects",
        )

    # Filter segments matching the requested speaker
    target_segments = [s for s in segments if s.get("speaker_label") == speaker_label]
    if not target_segments:
        raise HTTPException(
            status_code=404,
            detail=f"No segments found for speaker_label='{speaker_label}'",
        )

    # Get or create profile
    created_here = not existing_profile_id
    if existing_profile_id:
        profile = get_profile(existing_profile_id)
        if profile is None:
            raise HTTPException(status_code=404, detail=f"Profile '{existing_profile_id}' not found")
    else:
        profile = create_profile(name=speaker_name, email=None, metadata=None)

    # Extract embeddings for each segment
    emb_count = 0
    for seg in target_segments:
        audio_path = seg.get("audio_file_path")
        # SECURITY NOTE(review): audio_file_path is supplied by the client and
        # is opened on the server as-is. Consider restricting it to a known
        # media directory before this endpoint is exposed to untrusted callers.
        # The isinstance check also keeps ints (file descriptors) and None
        # away from os.path.isfile.
        if not isinstance(audio_path, str) or not os.path.isfile(audio_path):
            logger.warning("Audio file not found: %s", audio_path)
            continue
        try:
            embedding = extract_embedding(audio_path)
            add_embedding(profile["id"], embedding)
        except Exception as exc:
            # Best effort per segment: one bad file must not abort the rest.
            logger.warning("Failed to extract embedding from %s: %s", audio_path, exc)
        else:
            emb_count += 1

    if emb_count == 0:
        if created_here:
            # Do not leave behind an empty profile that this request created.
            delete_profile(profile["id"])
        raise HTTPException(
            status_code=422,
            detail="No valid audio segments could be processed. Check that audio_file_path values exist on the server.",
        )

    # Re-read the profile so the response reflects the stored state.
    return VoiceProfileResponse(**get_profile(profile["id"]))
|
|
|
|
|
|
# ── GET /speakers ─────────────────────────────────────────────────────────────
|
|
@router.get(
    "/speakers",
    response_model=List[VoiceProfileResponse],
    summary="List all enrolled voice profiles",
)
async def list_speakers():
    # Return every profile from list_profiles() wrapped in the response schema.
    # NOTE: documented with comments rather than a docstring on purpose —
    # FastAPI would publish a docstring as this endpoint's description.
    responses = []
    for record in list_profiles():
        responses.append(VoiceProfileResponse(**record))
    return responses
|
|
|
|
|
|
# ── GET /speakers/{id} ────────────────────────────────────────────────────────
|
|
@router.get(
    "/speakers/{profile_id}",
    response_model=VoiceProfileResponse,
    summary="Get a voice profile by ID",
)
async def get_speaker(profile_id: str):
    # Look up one profile; respond 404 when get_profile() yields None.
    # NOTE: documented with comments rather than a docstring on purpose —
    # FastAPI would publish a docstring as this endpoint's description.
    record = get_profile(profile_id)
    if record is not None:
        return VoiceProfileResponse(**record)
    raise HTTPException(status_code=404, detail="Voice profile not found")
|
|
|
|
|
|
# ── PUT /speakers/{id} ────────────────────────────────────────────────────────
|
|
@router.put(
    "/speakers/{profile_id}",
    response_model=VoiceProfileResponse,
    summary="Update a voice profile (name, email, metadata — all optional)",
)
async def update_speaker(profile_id: str, data: UpdateProfileRequest):
    # Forward the request fields to update_profile(); respond 404 when it
    # reports no profile (returns None).
    # NOTE: documented with comments rather than a docstring on purpose —
    # FastAPI would publish a docstring as this endpoint's description.
    changes = {
        "name": data.name,
        "email": data.email,
        "metadata": data.metadata,
    }
    updated = update_profile(profile_id, **changes)
    if updated is not None:
        return VoiceProfileResponse(**updated)
    raise HTTPException(status_code=404, detail="Voice profile not found")
|
|
|
|
|
|
# ── DELETE /speakers/{id} ─────────────────────────────────────────────────────
|
|
@router.delete(
    "/speakers/{profile_id}",
    status_code=204,
    summary="Delete a voice profile and all its embeddings",
)
async def delete_speaker(profile_id: str):
    # Delete the profile; a falsy result from delete_profile() becomes a 404.
    # NOTE: documented with comments rather than a docstring on purpose —
    # FastAPI would publish a docstring as this endpoint's description.
    removed = delete_profile(profile_id)
    if removed:
        return None
    raise HTTPException(status_code=404, detail="Voice profile not found")
|
|
|
|
|
|
# ── POST /identify ────────────────────────────────────────────────────────────
|
|
@router.post(
    "/identify",
    response_model=IdentifyResponse,
    summary="Identify the speaker in an audio clip",
    description=(
        "Extracts a vocal embedding from the submitted audio and compares it against all "
        "enrolled profiles. Returns the best match if confidence >= threshold."
    ),
)
async def identify(
    file: UploadFile = File(..., description="Short audio clip to identify the speaker"),
    threshold: Optional[float] = Form(None, description="Confidence threshold (default: 0.75)"),
):
    """Identify the speaker of an uploaded audio clip.

    Args:
        file: Uploaded audio clip.
        threshold: Optional confidence threshold, forwarded unchanged to
            ``identify_speaker`` (None when the client omits it).

    Returns:
        An ``IdentifyResponse`` built from the ``identify_speaker`` result.

    Raises:
        HTTPException: 422 when the uploaded file is empty.
    """
    audio_bytes = await file.read()
    if not audio_bytes:
        # Fail fast with a client error instead of handing an empty file to
        # the embedding extractor.
        raise HTTPException(status_code=422, detail="Uploaded audio file is empty")

    # Keep the upload's extension for the temp file; fall back to ".wav" when
    # the filename is missing or has no extension.
    suffix = os.path.splitext(file.filename or ".wav")[-1].lower() or ".wav"

    embedding = await _save_and_embed(audio_bytes, suffix)
    result = identify_speaker(embedding, threshold=threshold)
    return IdentifyResponse(**result)
|