229 lines
8.1 KiB
Python
229 lines
8.1 KiB
Python
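"""
Voice enrollment and speaker identification service.

Stores voice profiles and their voice embeddings in SQLite. The database path
is read from the VOICE_DB_PATH environment variable and defaults to
/opt/Backend/whisper-diarization-api/voice_profiles.db.
"""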
|
|
import os
|
|
import uuid
|
|
import pickle
|
|
import logging
|
|
import sqlite3
|
|
import numpy as np
|
|
from datetime import datetime, timezone
|
|
from typing import List, Optional, Dict, Any
|
|
|
|
logger = logging.getLogger(__name__)

# Location of the SQLite database file; the VOICE_DB_PATH environment variable
# overrides the built-in default.
DB_PATH = os.environ.get(
    "VOICE_DB_PATH",
    "/opt/Backend/whisper-diarization-api/voice_profiles.db",
)

# Default score identify_speaker() requires before it reports a match. Read
# once at import time from SPEAKER_CONFIDENCE_THRESHOLD and parsed as a float.
CONFIDENCE_THRESHOLD = float(
    os.environ.get("SPEAKER_CONFIDENCE_THRESHOLD", "0.75")
)
|
|
|
|
|
|
class _ClosingConnection(sqlite3.Connection):
    """sqlite3 connection whose ``with`` block also closes the connection.

    The stock ``sqlite3.Connection`` context manager only commits (on success)
    or rolls back (on error); it leaves the connection open. The callers in
    this module use ``with _conn() as conn:``, so closing in ``__exit__``
    stops each call from leaking an open connection until garbage collection.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            # Commit / rollback exactly as the base class does.
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            self.close()


def _conn(path: Optional[str] = None) -> sqlite3.Connection:
    """Open a connection to the voice-profile database.

    Args:
        path: Database file to open. Defaults to ``DB_PATH``.

    Returns:
        A connection whose rows are ``sqlite3.Row`` objects and that enforces
        foreign keys. Used as a context manager, it commits (or rolls back)
        and then closes itself.
    """
    conn = sqlite3.connect(
        DB_PATH if path is None else path,
        factory=_ClosingConnection,
    )
    conn.row_factory = sqlite3.Row
    # SQLite ignores REFERENCES / ON DELETE CASCADE unless this pragma is
    # enabled on every new connection (it is off by default).
    conn.execute("PRAGMA foreign_keys = ON")
    return conn
|
|
|
|
|
|
# ── Schema ───────────────────────────────────────────────────────────────────
|
|
def init_db():
    """Create the database directory (when the path has one) and the schema.

    Idempotent: every statement uses IF NOT EXISTS, so it is safe to call on
    every startup.
    """
    db_dir = os.path.dirname(DB_PATH)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when
    # DB_PATH actually has one (e.g. VOICE_DB_PATH=voice_profiles.db has none).
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)

    with _conn() as conn:
        conn.executescript(
            """
            CREATE TABLE IF NOT EXISTS voice_profiles (
                id TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                email TEXT,
                metadata TEXT,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            );

            CREATE TABLE IF NOT EXISTS voice_embeddings (
                id TEXT PRIMARY KEY,
                profile_id TEXT NOT NULL REFERENCES voice_profiles(id) ON DELETE CASCADE,
                embedding BLOB NOT NULL,
                created_at TEXT NOT NULL
            );

            -- Embeddings are always looked up, counted and deleted by
            -- profile_id; without this index each of those is a full scan.
            CREATE INDEX IF NOT EXISTS idx_voice_embeddings_profile_id
                ON voice_embeddings (profile_id);
            """
        )

    logger.info("Voice profiles DB initialised at %s", DB_PATH)
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
def _now() -> str:
    """Return the current UTC time as an ISO-8601 string (with a +00:00 offset)."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
|
|
|
|
def _serialize(arr: np.ndarray) -> bytes:
    """Encode an embedding array as bytes for the voice_embeddings BLOB column."""
    # Default pickle protocol; _deserialize reverses this with pickle.loads.
    payload = pickle.dumps(arr)
    return payload
|
|
|
|
|
|
def _deserialize(blob: bytes) -> np.ndarray:
    """Decode a BLOB produced by _serialize back into the stored array."""
    # NOTE(review): pickle.loads can execute arbitrary code embedded in the
    # blob, so this is only safe while the database file itself is trusted.
    restored = pickle.loads(blob)
    return restored
|
|
|
|
|
|
def _profile_row_to_dict(row: sqlite3.Row, embeddings_count: int = 0) -> Dict[str, Any]:
    """Convert a voice_profiles row into the public profile dict.

    Args:
        row: A row exposing the voice_profiles columns by name.
        embeddings_count: Number of embeddings stored for this profile.

    Returns:
        Dict with keys id, name, email, metadata, embeddings_count,
        created_at and updated_at, in that order.
    """
    profile = {column: row[column] for column in ("id", "name", "email", "metadata")}
    profile["embeddings_count"] = embeddings_count
    # Timestamps go last so the key order matches the rest of the API.
    profile.update((column, row[column]) for column in ("created_at", "updated_at"))
    return profile
|
|
|
|
|
|
# ── CRUD: Voice Profiles ──────────────────────────────────────────────────────
|
|
def create_profile(name: str, email: Optional[str], metadata: Optional[str]) -> Dict[str, Any]:
    """Insert a new voice profile and return it as a profile dict.

    The profile gets a fresh UUID4 id, identical created_at / updated_at
    timestamps, and an embeddings_count of 0.
    """
    new_id = str(uuid.uuid4())
    stamp = _now()

    with _conn() as conn:
        conn.execute(
            "INSERT INTO voice_profiles (id, name, email, metadata, created_at, updated_at) VALUES (?,?,?,?,?,?)",
            (new_id, name, email, metadata, stamp, stamp),
        )

    # A brand-new profile has no embeddings yet.
    record: Dict[str, Any] = dict(
        id=new_id,
        name=name,
        email=email,
        metadata=metadata,
        embeddings_count=0,
        created_at=stamp,
        updated_at=stamp,
    )
    return record
|
|
|
|
|
|
def get_profile(profile_id: str) -> Optional[Dict[str, Any]]:
    """Fetch one voice profile by id.

    Returns the profile dict (including its embedding count), or None when
    no profile has that id.
    """
    key = (profile_id,)
    with _conn() as conn:
        profile_row = conn.execute("SELECT * FROM voice_profiles WHERE id = ?", key).fetchone()
        if profile_row is not None:
            n_embeddings = conn.execute(
                "SELECT COUNT(*) FROM voice_embeddings WHERE profile_id = ?", key
            ).fetchone()[0]
            return _profile_row_to_dict(profile_row, n_embeddings)
    return None
|
|
|
|
|
|
def list_profiles() -> List[Dict[str, Any]]:
    """Return every voice profile, newest first, each with its embedding count."""
    with _conn() as conn:
        # One query instead of a separate COUNT(*) per profile: count the
        # embeddings once per profile_id, then LEFT JOIN so profiles without
        # any embeddings still appear with a count of 0.
        rows = conn.execute(
            """
            SELECT vp.*, COALESCE(ec.n, 0) AS embeddings_count
            FROM voice_profiles vp
            LEFT JOIN (
                SELECT profile_id, COUNT(*) AS n
                FROM voice_embeddings
                GROUP BY profile_id
            ) ec ON ec.profile_id = vp.id
            ORDER BY vp.created_at DESC
            """
        ).fetchall()
    return [_profile_row_to_dict(row, row["embeddings_count"]) for row in rows]
|
|
|
|
|
|
def update_profile(
    profile_id: str,
    name: Optional[str] = None,
    email: Optional[str] = None,
    metadata: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Partially update a voice profile.

    Any field passed as None keeps its stored value, so a field cannot be
    cleared to NULL through this function. updated_at is always refreshed.

    Returns:
        The updated profile dict, or None when no profile has that id.
    """
    with _conn() as conn:
        # A single UPDATE with COALESCE keeps the stored value for every field
        # passed as None. Doing it in one statement, rather than reading the
        # row and writing it back, means concurrent updates to different
        # fields cannot overwrite each other with stale values.
        cur = conn.execute(
            "UPDATE voice_profiles "
            "SET name = COALESCE(?, name), email = COALESCE(?, email), "
            "metadata = COALESCE(?, metadata), updated_at = ? "
            "WHERE id = ?",
            (name, email, metadata, _now(), profile_id),
        )
        if cur.rowcount == 0:
            return None
    return get_profile(profile_id)
|
|
|
|
|
|
def delete_profile(profile_id: str) -> bool:
    """Delete a voice profile together with all of its stored embeddings.

    Returns:
        True if a profile row was deleted, False if no profile had that id.
    """
    with _conn() as conn:
        # The schema declares ON DELETE CASCADE, but SQLite only enforces
        # foreign keys when PRAGMA foreign_keys is enabled on the connection,
        # so remove the embeddings explicitly (same transaction) rather than
        # relying on the cascade.
        conn.execute("DELETE FROM voice_embeddings WHERE profile_id = ?", (profile_id,))
        cur = conn.execute("DELETE FROM voice_profiles WHERE id = ?", (profile_id,))
        return cur.rowcount > 0
|
|
|
|
|
|
# ── Embeddings ────────────────────────────────────────────────────────────────
|
|
def add_embedding(profile_id: str, embedding: np.ndarray) -> str:
    """Store one embedding for a profile and return the new embedding's id.

    The embedding's created_at and the profile's refreshed updated_at share
    one timestamp, and both writes happen inside the same ``with`` block.
    """
    embedding_id = str(uuid.uuid4())
    stamp = _now()
    insert_params = (embedding_id, profile_id, _serialize(embedding), stamp)

    with _conn() as conn:
        conn.execute(
            "INSERT INTO voice_embeddings (id, profile_id, embedding, created_at) VALUES (?,?,?,?)",
            insert_params,
        )
        # Touch the owning profile so its updated_at reflects the new sample.
        conn.execute("UPDATE voice_profiles SET updated_at=? WHERE id=?", (stamp, profile_id))

    return embedding_id
|
|
|
|
|
|
def get_all_embeddings() -> List[Dict[str, Any]]:
    """Load every stored embedding together with its owner's id and name.

    Only embeddings whose profile_id matches an existing profile are returned
    (inner JOIN). Each BLOB is decoded back into an array via _deserialize.
    """
    query = """
        SELECT ve.id, ve.profile_id, ve.embedding, vp.name
        FROM voice_embeddings ve
        JOIN voice_profiles vp ON ve.profile_id = vp.id
    """
    with _conn() as conn:
        records = conn.execute(query).fetchall()

    def _as_entry(rec) -> Dict[str, Any]:
        # Rename the SQL columns to the keys callers expect.
        return {
            "embedding_id": rec["id"],
            "profile_id": rec["profile_id"],
            "speaker_name": rec["name"],
            "embedding": _deserialize(rec["embedding"]),
        }

    return list(map(_as_entry, records))
|
|
|
|
|
|
# ── Speaker identification ────────────────────────────────────────────────────
|
|
def identify_speaker(query_embedding: np.ndarray, threshold: Optional[float] = None) -> Dict[str, Any]:
    """Match a query embedding against every enrolled speaker.

    Each profile is scored by the mean cosine similarity between the query
    and each of that profile's stored embeddings; the highest-scoring profile
    is the candidate.

    Args:
        query_embedding: Embedding to identify; compared with
            services.diarization.cosine_similarity.
        threshold: Minimum score for a match. Defaults to CONFIDENCE_THRESHOLD.

    Returns:
        Dict with keys matched, speaker_id, speaker_name, confidence (rounded
        to 4 decimals) and threshold. With no stored embeddings, speaker_id and
        speaker_name are None and confidence is 0.0. When the best score is
        below the threshold, matched is False but the best candidate (if any)
        is still reported.
    """
    # Local import; presumably avoids a circular import with
    # services.diarization — TODO confirm.
    from services.diarization import cosine_similarity

    limit = CONFIDENCE_THRESHOLD if threshold is None else threshold

    stored = get_all_embeddings()
    if not stored:
        return {"matched": False, "speaker_id": None, "speaker_name": None, "confidence": 0.0, "threshold": limit}

    # Group the stored embeddings by profile; the speaker name is taken from
    # the first embedding seen for each profile.
    grouped: Dict[str, Dict[str, Any]] = {}
    for entry in stored:
        bucket = grouped.setdefault(
            entry["profile_id"],
            {"speaker_name": entry["speaker_name"], "embeddings": []},
        )
        bucket["embeddings"].append(entry["embedding"])

    # Score each profile by the MEAN similarity of the query against its
    # stored embeddings (the embeddings themselves are not averaged). The
    # strict ">" keeps the earliest profile on ties.
    winner_id: Optional[str] = None
    winner_score = -1.0
    for candidate_id, bucket in grouped.items():
        similarities = [cosine_similarity(query_embedding, emb) for emb in bucket["embeddings"]]
        mean_score = float(np.mean(similarities))
        if mean_score > winner_score:
            winner_id, winner_score = candidate_id, mean_score

    matched = not (winner_id is None or winner_score < limit)
    speaker_name = grouped[winner_id]["speaker_name"] if (matched or winner_id) else None

    return {
        "matched": matched,
        "speaker_id": winner_id,
        "speaker_name": speaker_name,
        "confidence": round(winner_score, 4),
        "threshold": limit,
    }
|