# Maia/diarization/services/enrollment.py (Python, 229 lines, 8.1 KiB)

"""
Voice enrollment and speaker identification service.
Uses SQLite at /opt/Backend/whisper-diarization-api/voice_profiles.db
"""
import os
import uuid
import pickle
import logging
import sqlite3
import numpy as np
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any
logger = logging.getLogger(__name__)

# SQLite file holding voice profiles and embeddings; override with VOICE_DB_PATH.
DB_PATH = os.environ.get("VOICE_DB_PATH", "/opt/Backend/whisper-diarization-api/voice_profiles.db")

# Default minimum score identify_speaker() requires for a match;
# override with SPEAKER_CONFIDENCE_THRESHOLD.
CONFIDENCE_THRESHOLD = float(os.environ.get("SPEAKER_CONFIDENCE_THRESHOLD", "0.75"))
class _ClosingConnection(sqlite3.Connection):
    """sqlite3 connection whose ``with`` block also closes the connection.

    A plain ``sqlite3.Connection`` used as a context manager only commits or
    rolls back the transaction; it does not close the connection. This module
    opens a fresh connection for every ``with _conn() as conn:`` block, so
    without this the handles would linger until garbage collection.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            # Commit on success / roll back on error, exactly as the base class does.
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            self.close()


def _conn(db_path: Optional[str] = None) -> sqlite3.Connection:
    """Open a connection to the voice-profile database.

    Args:
        db_path: Database file to open. Defaults to the module-level DB_PATH.

    Returns:
        A connection that returns ``sqlite3.Row`` rows and has foreign-key
        enforcement enabled. Leaving its ``with`` block commits (or rolls
        back) and then closes it, so each connection serves a single ``with``.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path, factory=_ClosingConnection)
    conn.row_factory = sqlite3.Row
    # SQLite ignores REFERENCES / ON DELETE CASCADE unless this per-connection
    # pragma is on. Without it, deleting a profile leaves its embeddings behind.
    conn.execute("PRAGMA foreign_keys = ON")
    return conn
# ── Schema ───────────────────────────────────────────────────────────────────
def init_db():
    """Create the voice-profile tables and supporting index if they are missing.

    Safe to call repeatedly, because every statement uses ``IF NOT EXISTS``.
    When DB_PATH has a directory component, that directory is created first.
    """
    db_dir = os.path.dirname(DB_PATH)
    # A bare filename (e.g. VOICE_DB_PATH=voice_profiles.db or ":memory:") has
    # no directory part, and os.makedirs("") raises FileNotFoundError.
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)
    with _conn() as conn:
        conn.executescript("""
        CREATE TABLE IF NOT EXISTS voice_profiles (
            id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            email TEXT,
            metadata TEXT,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS voice_embeddings (
            id TEXT PRIMARY KEY,
            profile_id TEXT NOT NULL REFERENCES voice_profiles(id) ON DELETE CASCADE,
            embedding BLOB NOT NULL,
            created_at TEXT NOT NULL
        );
        -- Per-profile counts, joins and deletes all filter on profile_id.
        CREATE INDEX IF NOT EXISTS idx_voice_embeddings_profile_id
            ON voice_embeddings(profile_id);
        """)
    logger.info("Voice profiles DB initialised at %s", DB_PATH)
# ── Helpers ───────────────────────────────────────────────────────────────────
def _now() -> str:
    """Return the current UTC time as an ISO-8601 string (with a +00:00 offset)."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def _serialize(arr: np.ndarray) -> bytes:
    """Encode an embedding as a pickle blob for the voice_embeddings.embedding column."""
    blob = pickle.dumps(arr)
    return blob
def _deserialize(blob: bytes) -> np.ndarray:
    """Decode a pickle blob written by _serialize back into the stored object."""
    # NOTE(review): pickle.loads can execute arbitrary code, so anyone able to
    # write this database file can run code in this process. Consider a
    # data-only format (e.g. np.save / np.load(allow_pickle=False)) plus a
    # migration for existing rows.
    restored = pickle.loads(blob)
    return restored
def _profile_row_to_dict(row: sqlite3.Row, embeddings_count: int = 0) -> Dict[str, Any]:
    """Convert a voice_profiles row into the public profile dict.

    Args:
        row: Row with id, name, email, metadata, created_at and updated_at columns.
        embeddings_count: Number of stored embeddings to report for the profile.
    """
    return dict(
        id=row["id"],
        name=row["name"],
        email=row["email"],
        metadata=row["metadata"],
        embeddings_count=embeddings_count,
        created_at=row["created_at"],
        updated_at=row["updated_at"],
    )
# ── CRUD: Voice Profiles ──────────────────────────────────────────────────────
def create_profile(name: str, email: Optional[str], metadata: Optional[str]) -> Dict[str, Any]:
    """Insert a new voice profile and return it as a profile dict.

    The profile gets a random UUID4 id, identical created_at/updated_at
    timestamps, and an embeddings_count of 0.
    """
    new_id = str(uuid.uuid4())
    timestamp = _now()
    record: Dict[str, Any] = {
        "id": new_id,
        "name": name,
        "email": email,
        "metadata": metadata,
        "embeddings_count": 0,
        "created_at": timestamp,
        "updated_at": timestamp,
    }
    with _conn() as conn:
        conn.execute(
            "INSERT INTO voice_profiles (id, name, email, metadata, created_at, updated_at) VALUES (?,?,?,?,?,?)",
            (new_id, name, email, metadata, timestamp, timestamp),
        )
    return record
def get_profile(profile_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one voice profile with its embedding count, or None if no row has profile_id."""
    with _conn() as conn:
        profile_row = conn.execute(
            "SELECT * FROM voice_profiles WHERE id = ?", (profile_id,)
        ).fetchone()
        if profile_row is not None:
            (n_embeddings,) = conn.execute(
                "SELECT COUNT(*) FROM voice_embeddings WHERE profile_id = ?", (profile_id,)
            ).fetchone()
            return _profile_row_to_dict(profile_row, n_embeddings)
    return None
def list_profiles() -> List[Dict[str, Any]]:
    """Return every voice profile, newest first, each with its embedding count.

    Counts come from a correlated subquery inside a single statement, instead
    of issuing a separate COUNT query from Python for every profile.
    """
    with _conn() as conn:
        rows = conn.execute(
            """
            SELECT vp.*,
                   (SELECT COUNT(*) FROM voice_embeddings ve WHERE ve.profile_id = vp.id)
                       AS embeddings_count
            FROM voice_profiles vp
            ORDER BY vp.created_at DESC
            """
        ).fetchall()
    return [_profile_row_to_dict(row, row["embeddings_count"]) for row in rows]
def update_profile(
    profile_id: str,
    name: Optional[str] = None,
    email: Optional[str] = None,
    metadata: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Partially update a voice profile.

    Any argument left as ``None`` keeps the stored value, so a field cannot be
    cleared to NULL through this function. ``updated_at`` is always refreshed.

    Returns:
        The updated profile dict, or ``None`` if no profile has ``profile_id``.
    """
    with _conn() as conn:
        # COALESCE keeps the current column value when the parameter is NULL.
        # Merging inside one UPDATE avoids a read-then-write race, in which a
        # concurrent change to another field would be overwritten by a value
        # read earlier on a different connection.
        cur = conn.execute(
            "UPDATE voice_profiles "
            "SET name=COALESCE(?, name), email=COALESCE(?, email), "
            "metadata=COALESCE(?, metadata), updated_at=? "
            "WHERE id=?",
            (name, email, metadata, _now(), profile_id),
        )
        updated = cur.rowcount > 0
    if not updated:
        return None
    return get_profile(profile_id)
def delete_profile(profile_id: str) -> bool:
    """Delete a voice profile together with all of its stored embeddings.

    Returns:
        True if a profile row was deleted, False if no profile had ``profile_id``.
    """
    with _conn() as conn:
        # Remove embeddings explicitly rather than relying on the schema's
        # ON DELETE CASCADE. SQLite only enforces foreign keys on connections
        # that ran "PRAGMA foreign_keys = ON"; otherwise the embeddings are
        # orphaned. Both deletes share one transaction.
        conn.execute("DELETE FROM voice_embeddings WHERE profile_id = ?", (profile_id,))
        cur = conn.execute("DELETE FROM voice_profiles WHERE id = ?", (profile_id,))
        deleted = cur.rowcount > 0
    return deleted
# ── Embeddings ────────────────────────────────────────────────────────────────
def add_embedding(profile_id: str, embedding: np.ndarray) -> str:
    """Store one voice embedding for an existing profile.

    The profile's ``updated_at`` is bumped in the same transaction as the insert.

    Returns:
        The id (UUID4 string) of the new embedding row.

    Raises:
        ValueError: If no voice profile has ``profile_id``. Nothing is written.
    """
    emb_id = str(uuid.uuid4())
    now = _now()
    with _conn() as conn:
        # Touch the parent row first. A zero rowcount means the profile does not
        # exist, and inserting anyway would create an embedding that the JOIN
        # in get_all_embeddings never returns.
        touched = conn.execute(
            "UPDATE voice_profiles SET updated_at=? WHERE id=?", (now, profile_id)
        ).rowcount
        if touched == 0:
            raise ValueError(f"Voice profile {profile_id!r} does not exist")
        conn.execute(
            "INSERT INTO voice_embeddings (id, profile_id, embedding, created_at) VALUES (?,?,?,?)",
            (emb_id, profile_id, _serialize(embedding), now),
        )
    return emb_id
def get_all_embeddings() -> List[Dict[str, Any]]:
    """Load every stored embedding together with its owning profile's name.

    Embeddings whose profile row is missing are dropped by the inner JOIN.
    Each item has the keys embedding_id, profile_id, speaker_name and
    embedding (the deserialized object).
    """
    with _conn() as conn:
        records = conn.execute("""
        SELECT ve.id AS embedding_id, ve.profile_id AS profile_id,
               vp.name AS speaker_name, ve.embedding AS embedding
        FROM voice_embeddings ve
        JOIN voice_profiles vp ON ve.profile_id = vp.id
        """).fetchall()

    def as_item(rec: sqlite3.Row) -> Dict[str, Any]:
        # Column aliases already carry the output key names, in output order.
        item = dict(zip(rec.keys(), rec))
        item["embedding"] = _deserialize(item["embedding"])
        return item

    return [as_item(rec) for rec in records]
# ── Speaker identification ────────────────────────────────────────────────────
def identify_speaker(query_embedding: np.ndarray, threshold: Optional[float] = None) -> Dict[str, Any]:
    """
    Compare query_embedding against every enrolled speaker.

    Each profile's score is the mean cosine similarity between the query and
    each of that profile's stored embeddings; the embeddings themselves are
    not averaged. The highest-scoring profile is reported as the candidate,
    and ``matched`` is True only when its score is >= ``threshold``.

    Args:
        query_embedding: Embedding to identify. Presumably produced by the same
            model as the enrolled embeddings (TODO confirm against callers).
        threshold: Minimum score for a match. Defaults to CONFIDENCE_THRESHOLD.

    Returns:
        Dict with keys matched, speaker_id, speaker_name, confidence (the score
        rounded to 4 decimals, or 0.0 when there is no candidate) and threshold.
        On a miss, speaker_id and speaker_name still name the closest profile.
    """
    # Imported lazily; NOTE(review): presumably to avoid an import cycle with services.diarization.
    from services.diarization import cosine_similarity

    if threshold is None:
        threshold = CONFIDENCE_THRESHOLD

    # Group stored embeddings by profile, remembering each profile's name.
    names: Dict[str, str] = {}
    grouped: Dict[str, List[np.ndarray]] = {}
    for item in get_all_embeddings():
        pid = item["profile_id"]
        names.setdefault(pid, item["speaker_name"])
        grouped.setdefault(pid, []).append(item["embedding"])

    best_profile_id: Optional[str] = None
    # Start below any real score (not at -1.0), so a profile is selected even
    # when every score sits at the bottom of the cosine range.
    best_score = float("-inf")
    for pid, embeddings in grouped.items():
        avg_score = float(np.mean([cosine_similarity(query_embedding, e) for e in embeddings]))
        if avg_score > best_score:
            best_score = avg_score
            best_profile_id = pid

    if best_profile_id is None:
        # No enrolled embeddings, or no score compared greater than -inf
        # (e.g. NaN scores, if cosine_similarity can produce them).
        return {"matched": False, "speaker_id": None, "speaker_name": None, "confidence": 0.0, "threshold": threshold}

    return {
        "matched": best_score >= threshold,
        "speaker_id": best_profile_id,
        "speaker_name": names[best_profile_id],
        "confidence": round(best_score, 4),
        "threshold": threshold,
    }