Maia/diarization/services/diarization.py

491 lines
18 KiB
Python

"""
Diarization service — speaker identification pipeline.
Priority order:
1. pyannote.audio 3.1 — best quality, requires HuggingFace token with model access
2. SpeechBrain ECAPA-TDNN — good quality (~85%), no license needed, GPU-accelerated
3. Energy + MFCC fallback — basic quality (~60%), pure CPU, always available
Speaker embeddings for enrollment:
Same priority: pyannote → SpeechBrain → MFCC
"""
import os
import io
import logging
import tempfile
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# ── Pyannote (primary) ────────────────────────────────────────────────────────
# Lazy cache for the diarization pipeline. `_pyannote_tried` is set on the
# first load attempt so that a failed load is not retried on every call.
_pyannote_pipeline = None
_pyannote_tried = False
# HuggingFace access token, read from the environment.
# NOTE(review): when HF_TOKEN is unset, the placeholder string (not None) is
# what gets passed as the auth token — confirm this is intended.
HF_TOKEN = os.getenv("HF_TOKEN", "hf_placeholder_token")
# ── SpeechBrain (secondary) ───────────────────────────────────────────────────
# Lazy cache for the SpeechBrain encoder, with the same "tried" flag pattern.
_speechbrain_model = None
_speechbrain_tried = False
# Directory handed to SpeechBrain as `savedir`; overridable via environment.
SPEECHBRAIN_CACHE = os.getenv("SPEECHBRAIN_CACHE", "/tmp/speechbrain_ecapa")
def _try_load_pyannote():
    """
    Load the pyannote speaker-diarization pipeline once and cache the result.

    The load is attempted at most once per process: the outcome (a pipeline
    or None) is stored in module globals and returned on every later call,
    so a failed load is never retried.

    Returns:
        The loaded pipeline (moved to CUDA when available), or None when
        pyannote/torch cannot be imported or the model cannot be loaded.
    """
    global _pyannote_pipeline, _pyannote_tried
    if _pyannote_tried:
        return _pyannote_pipeline
    _pyannote_tried = True
    try:
        import torch
        from pyannote.audio import Pipeline

        logger.info("Loading pyannote speaker-diarization-3.1 ...")
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=HF_TOKEN,
        )
        # from_pretrained can hand back None instead of raising when the
        # gated model cannot be downloaded; treat that as a load failure so
        # we never log "loaded successfully" for a missing pipeline.
        if pipeline is None:
            raise RuntimeError(
                "pipeline could not be downloaded (check HF_TOKEN and model access)"
            )
        if torch.cuda.is_available():
            pipeline = pipeline.to(torch.device("cuda"))
        _pyannote_pipeline = pipeline
        logger.info("pyannote pipeline loaded successfully.")
    except Exception as e:
        # Best-effort: any failure just disables this backend.
        logger.warning(f"pyannote unavailable ({e}), will try SpeechBrain.")
        _pyannote_pipeline = None
    return _pyannote_pipeline
# ── Embedding model (for enrollment) ─────────────────────────────────────────
# Lazy cache filled by _try_load_embedding_model(): holds either a pyannote
# Inference object, a ("speechbrain", model) tuple, or None when neither
# backend could be loaded. `_embedding_tried` prevents repeated load attempts.
_embedding_model = None
_embedding_tried = False
def _try_load_speechbrain():
    """
    Return the cached SpeechBrain ECAPA-TDNN encoder, loading it on first use.

    Only the first call attempts the load; later calls return whatever that
    attempt produced (the encoder, or None if the import or load failed).
    """
    global _speechbrain_model, _speechbrain_tried
    if not _speechbrain_tried:
        _speechbrain_tried = True
        try:
            import torch
            from speechbrain.inference.speaker import EncoderClassifier

            logger.info("Loading SpeechBrain ECAPA-TDNN model...")
            if torch.cuda.is_available():
                device = "cuda"
            else:
                device = "cpu"
            _speechbrain_model = EncoderClassifier.from_hparams(
                source="speechbrain/spkrec-ecapa-voxceleb",
                savedir=SPEECHBRAIN_CACHE,
                run_opts={"device": device},
            )
            logger.info(f"SpeechBrain ECAPA-TDNN loaded on {device}.")
        except Exception as exc:
            # Any failure simply disables this backend.
            logger.warning(f"SpeechBrain unavailable ({exc}), will use MFCC fallback.")
            _speechbrain_model = None
    return _speechbrain_model
def _try_load_embedding_model():
    """
    Load a speaker-embedding backend once and cache the result.

    Tries the pyannote embedding model first; if that fails for any reason,
    falls back to the SpeechBrain encoder. The load is attempted only on the
    first call.

    Returns:
        - a pyannote ``Inference`` object, or
        - the tuple ``("speechbrain", model)``, or
        - None when neither backend is available (callers then use MFCC).
    """
    global _embedding_model, _embedding_tried
    if _embedding_tried:
        return _embedding_model
    _embedding_tried = True
    try:
        from pyannote.audio import Inference
        from pyannote.audio import Model

        logger.info("Loading pyannote embedding model...")
        model = Model.from_pretrained("pyannote/embedding", use_auth_token=HF_TOKEN)
        # from_pretrained may return None rather than raise when the gated
        # model cannot be fetched; fail explicitly so the fallback runs with
        # a meaningful reason in the log.
        if model is None:
            raise RuntimeError(
                "model could not be downloaded (check HF_TOKEN and model access)"
            )
        _embedding_model = Inference(model, window="whole")
        logger.info("pyannote embedding model loaded.")
    except Exception as e:
        # Record why pyannote was skipped instead of failing silently.
        logger.warning(f"pyannote embedding unavailable ({e}), trying SpeechBrain.")
        sb = _try_load_speechbrain()
        if sb is not None:
            _embedding_model = ("speechbrain", sb)
            logger.info("Using SpeechBrain for embeddings.")
        else:
            logger.warning("No embedding model available. Using MFCC fallback.")
            _embedding_model = None
    return _embedding_model
# ── Pyannote diarization ──────────────────────────────────────────────────────
def _diarize_pyannote(audio_path: str, num_speakers: Optional[int] = None) -> List[Dict]:
    """
    Run pyannote diarization on an audio file.

    Args:
        audio_path: Path of the audio file handed to the pipeline.
        num_speakers: Optional known speaker count. When given (and non-zero)
            it is forwarded to the pipeline; otherwise the pipeline is called
            with the path only, exactly as before this parameter existed.

    Returns:
        List of ``{"start", "end", "speaker"}`` dicts, times rounded to 3
        decimals, in the order yielded by the pipeline.

    Raises:
        RuntimeError: if the pyannote pipeline is not available.
    """
    pipeline = _try_load_pyannote()
    if pipeline is None:
        raise RuntimeError("pyannote not available")
    if num_speakers:
        diarization = pipeline(audio_path, num_speakers=num_speakers)
    else:
        diarization = pipeline(audio_path)
    return [
        {"start": round(turn.start, 3), "end": round(turn.end, 3), "speaker": speaker}
        for turn, _, speaker in diarization.itertracks(yield_label=True)
    ]
# ── Shared VAD: detects speech segments using energy thresholding ─────────────
def _get_speech_segments(audio_path: str, min_silence_sec: float = 0.5) -> Tuple[np.ndarray, int, List[Dict]]:
    """
    Energy-based voice activity detection.

    The file is loaded as 16 kHz mono, RMS energy is computed on 25 ms frames
    with a 10 ms hop, and frames whose energy exceeds the 20th percentile of
    the non-zero energies are treated as speech. A segment is closed once
    ``min_silence_sec`` of consecutive non-speech frames has been seen.

    Args:
        audio_path: Path of the audio file to analyse.
        min_silence_sec: Silence duration that terminates a segment.

    Returns:
        ``(y, sr, segments)`` where ``segments`` is a list of
        ``{"start", "end"}`` dicts (plain floats, seconds, 3 decimals).
        If no segment is found, a single segment spanning the whole file is
        returned.
    """
    import librosa

    y, sr = librosa.load(audio_path, sr=16000, mono=True)
    whole_file = [{"start": 0.0, "end": round(float(len(y) / sr), 3)}]

    frame_len = int(0.025 * sr)  # 25 ms
    hop_len = int(0.010 * sr)  # 10 ms
    energy = librosa.feature.rms(y=y, frame_length=frame_len, hop_length=hop_len)[0]

    nonzero_energy = energy[energy > 0]
    if len(nonzero_energy) == 0:
        # Completely silent input: nothing to threshold against.
        return y, sr, whole_file

    threshold = np.percentile(nonzero_energy, 20)
    is_speech = energy > threshold
    times = librosa.frames_to_time(np.arange(len(is_speech)), sr=sr, hop_length=hop_len)
    min_silence_frames = int(min_silence_sec / (hop_len / sr))

    segments = []
    seg_start_idx = None  # frame index where the open segment began
    last_speech_idx = 0  # most recent speech frame of the open segment
    for i, speaking in enumerate(is_speech):
        if speaking:
            if seg_start_idx is None:
                seg_start_idx = i
            last_speech_idx = i
        elif seg_start_idx is not None and i - last_speech_idx >= min_silence_frames:
            seg_start = float(times[seg_start_idx])
            seg_end = float(times[last_speech_idx])
            # Interior segments of 0.3 s or less are dropped as noise.
            if seg_end - seg_start > 0.3:
                segments.append({"start": round(seg_start, 3), "end": round(seg_end, 3)})
            seg_start_idx = None

    if seg_start_idx is not None:
        # Close the segment still open at end of file on its last speech
        # frame (not the last frame overall), so trailing silence shorter
        # than min_silence_sec is not counted as speech.
        segments.append({
            "start": round(float(times[seg_start_idx]), 3),
            "end": round(float(times[last_speech_idx]), 3),
        })

    return y, sr, segments or whole_file
def _estimate_n_speakers(embeddings: np.ndarray, max_speakers: int = 8) -> int:
"""
Estimate optimal number of speakers using the elbow method on
intra-cluster variance. Tries k=1..max_speakers and picks the k where
adding one more cluster yields < 15% improvement of total variance range.
"""
n = len(embeddings)
if n <= 1:
return 1
max_k = min(max_speakers, n)
if max_k <= 2:
return max_k
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
norms[norms == 0] = 1e-10
normed = embeddings / norms
variances = []
for k in range(1, max_k + 1):
labels = _cosine_cluster(normed, n_clusters=k)
total_var = 0.0
for c in range(k):
mask = labels == c
if mask.sum() > 1:
centroid = normed[mask].mean(axis=0)
total_var += float(np.mean(1 - normed[mask] @ centroid))
variances.append(total_var)
if len(variances) < 2:
return 1
total_range = variances[0] - variances[-1]
if total_range < 1e-6:
return 1
for k_idx in range(1, len(variances)):
improvement = variances[k_idx - 1] - variances[k_idx]
if improvement < 0.15 * total_range:
return k_idx # k = k_idx (1-indexed)
return max_k
# ── SpeechBrain ECAPA-TDNN diarization (secondary) ───────────────────────────
def _diarize_speechbrain(audio_path: str, num_speakers: Optional[int] = None) -> List[Dict]:
    """
    Diarize with SpeechBrain ECAPA-TDNN embeddings.

    Speech segments come from the energy-based VAD; each segment is embedded
    with the SpeechBrain encoder and the embeddings are grouped with the
    cosine clustering helper. Segments shorter than 100 ms, or whose
    embedding fails, get an all-zero 192-dim placeholder.

    Returns the VAD segments with a "speaker" label ("SPEAKER_01", ...)
    added. Raises RuntimeError when the SpeechBrain model is unavailable.
    """
    import torch

    encoder = _try_load_speechbrain()
    if encoder is None:
        raise RuntimeError("SpeechBrain not available")

    y, sr, segments = _get_speech_segments(audio_path, min_silence_sec=0.5)
    if not segments:
        return []

    shortest = int(0.1 * sr)  # 100ms minimum

    def embed(segment: Dict) -> np.ndarray:
        samples = y[int(segment["start"] * sr):int(segment["end"] * sr)]
        if len(samples) < shortest:
            return np.zeros(192)
        batch = torch.tensor(samples).unsqueeze(0).float()
        try:
            with torch.no_grad():
                vector = encoder.encode_batch(batch)
            return vector.squeeze().cpu().numpy()
        except Exception as exc:
            logger.warning(f"ECAPA-TDNN embedding failed for segment: {exc}")
            return np.zeros(192)

    vectors = np.array([embed(segment) for segment in segments])

    speaker_count = num_speakers or _estimate_n_speakers(vectors, max_speakers=8)
    speaker_count = max(1, min(speaker_count, len(segments)))

    assignment = _cosine_cluster(vectors, n_clusters=speaker_count)
    for segment, label in zip(segments, assignment):
        segment["speaker"] = f"SPEAKER_{(label + 1):02d}"
    return segments
# ── Energy + MFCC fallback diarization ───────────────────────────────────────
def _diarize_energy(audio_path: str, min_silence_sec: float = 0.5, num_speakers: Optional[int] = None) -> List[Dict]:
    """
    Last-resort diarization: energy VAD plus clustering of mean MFCC vectors.

    Each VAD segment is summarised by the per-coefficient mean of 20 MFCCs
    (an all-zero vector when the segment is shorter than one 25 ms frame),
    then segments are grouped with the cosine clustering helper.

    Returns the VAD segments with a "speaker" label ("SPEAKER_01", ...) added.
    """
    import librosa

    y, sr, segments = _get_speech_segments(audio_path, min_silence_sec=min_silence_sec)
    if not segments:
        return []

    min_len = int(0.025 * sr)
    rows = []
    for segment in segments:
        samples = y[int(segment["start"] * sr):int(segment["end"] * sr)]
        if len(samples) < min_len:
            rows.append(np.zeros(20))
        else:
            rows.append(librosa.feature.mfcc(y=samples, sr=sr, n_mfcc=20).mean(axis=1))
    features = np.array(rows)

    speaker_count = num_speakers or _estimate_n_speakers(features, max_speakers=6)
    speaker_count = max(1, min(speaker_count, len(segments)))

    assignment = _cosine_cluster(features, n_clusters=speaker_count)
    for segment, label in zip(segments, assignment):
        segment["speaker"] = f"SPEAKER_{(label + 1):02d}"
    return segments
def _cosine_cluster(features: np.ndarray, n_clusters: int) -> np.ndarray:
"""Greedy cosine-distance clustering (no sklearn dependency)."""
if n_clusters <= 1 or len(features) <= 1:
return np.zeros(len(features), dtype=int)
norms = np.linalg.norm(features, axis=1, keepdims=True)
norms[norms == 0] = 1e-10
normed = features / norms
# Initialize centroids as first n_clusters samples
centroids = normed[:n_clusters].copy()
labels = np.zeros(len(features), dtype=int)
for _ in range(20): # max iterations
# Assign to nearest centroid (highest cosine similarity)
sims = normed @ centroids.T # (N, K)
new_labels = np.argmax(sims, axis=1)
if np.all(new_labels == labels):
break
labels = new_labels
# Update centroids
for k in range(n_clusters):
mask = labels == k
if mask.any():
centroids[k] = normed[mask].mean(axis=0)
norm = np.linalg.norm(centroids[k])
if norm > 0:
centroids[k] /= norm
return labels
# ── Public diarization entry point ───────────────────────────────────────────
def diarize(audio_path: str, num_speakers: Optional[int] = None) -> Tuple[List[Dict], str]:
    """
    Diarize an audio file with the best backend that works.

    Backends are tried in order — pyannote, SpeechBrain ECAPA-TDNN, then the
    energy + MFCC fallback. A backend that raises is logged and skipped; an
    error from the final fallback propagates to the caller.

    Returns:
        ``(segments, method)`` where ``segments`` is a list of
        ``{start, end, speaker}`` dicts and ``method`` is one of
        "pyannote", "speechbrain" or "energy_fallback".
    """
    try:
        return _diarize_pyannote(audio_path), "pyannote"
    except Exception as err:
        logger.warning(f"pyannote diarization failed ({err}), trying SpeechBrain")

    try:
        return _diarize_speechbrain(audio_path, num_speakers=num_speakers), "speechbrain"
    except Exception as err:
        logger.warning(f"SpeechBrain diarization failed ({err}), using MFCC fallback")

    return _diarize_energy(audio_path, num_speakers=num_speakers), "energy_fallback"
# ── Transcription + diarization alignment ────────────────────────────────────
def _speaker_for_span(t_start: float, t_end: float, diarization_segments: List[Dict]) -> str:
    """Return the speaker with the largest overlap, else the nearest one, else "SPEAKER_01"."""
    overlap_speaker = None
    best_overlap = 0.0
    nearest_speaker = None
    nearest_dist = float("inf")
    for dseg in diarization_segments:
        d_start = dseg["start"]
        d_end = dseg["end"]
        speaker = dseg["speaker"]

        # Strict ">" keeps the earliest segment on ties.
        overlap = min(t_end, d_end) - max(t_start, d_start)
        if overlap > best_overlap:
            best_overlap = overlap
            overlap_speaker = speaker

        # Boundary-to-boundary distance, used only when nothing overlaps.
        dist = min(abs(t_start - d_start), abs(t_start - d_end),
                   abs(t_end - d_start), abs(t_end - d_end))
        if dist < nearest_dist:
            nearest_dist = dist
            nearest_speaker = speaker

    if overlap_speaker is not None:
        return overlap_speaker
    if nearest_speaker is not None:
        return nearest_speaker
    return "SPEAKER_01"


def align_transcription_with_diarization(
    transcription_segments: List[Dict],
    diarization_segments: List[Dict],
) -> List[Dict]:
    """
    Attach a speaker label to each transcription segment.

    For each transcription segment ``{start, end, text}`` the speaker with the
    most time overlap in ``diarization_segments`` is chosen; with no overlap
    the temporally closest diarization segment is used, and with no
    diarization segments at all the label defaults to "SPEAKER_01".
    Segments whose text is missing, None or blank are skipped.

    Consecutive results with the same speaker and a gap under 1 second are
    merged: texts are joined with a space and the merged end is the later of
    the two ends, so a nested or out-of-order segment can never shorten it.

    Returns:
        List of ``{"start", "end", "speaker", "text"}`` dicts with times
        rounded to 3 decimals.
    """
    merged: List[Dict] = []
    for tseg in transcription_segments:
        # `or ""` tolerates an explicit None text value.
        text = (tseg.get("text") or "").strip()
        if not text:
            continue
        t_start = tseg.get("start", 0.0)
        t_end = tseg.get("end", t_start + 0.1)

        entry = {
            "start": round(t_start, 3),
            "end": round(t_end, 3),
            "speaker": _speaker_for_span(t_start, t_end, diarization_segments),
            "text": text,
        }

        prev = merged[-1] if merged else None
        if (
            prev is not None
            and prev["speaker"] == entry["speaker"]
            and entry["start"] - prev["end"] < 1.0
        ):
            prev["end"] = max(prev["end"], entry["end"])
            prev["text"] = prev["text"] + " " + entry["text"]
        else:
            merged.append(entry)
    return merged
# ── Embedding extraction ──────────────────────────────────────────────────────
def extract_embedding(audio_path: str) -> np.ndarray:
    """
    Extract a speaker embedding from an audio file.

    Uses whichever backend the embedding loader provides (pyannote
    ``Inference`` or SpeechBrain). If no backend is loaded, or the chosen
    backend raises, the MFCC-based embedding is returned instead.

    Args:
        audio_path: Path of the audio file to embed.

    Returns:
        1-D numpy array. Its length depends on the backend that produced it.
    """
    model = _try_load_embedding_model()

    if isinstance(model, tuple) and model[0] == "speechbrain":
        _, classifier = model
        try:
            import torch
            import torchaudio

            waveform, sr = torchaudio.load(audio_path)
            # torchaudio returns (channels, samples). encode_batch treats the
            # first axis as the batch, so a stereo file would produce one
            # embedding per channel; downmix to mono to keep the result 1-D.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            if sr != 16000:
                resampler = torchaudio.transforms.Resample(sr, 16000)
                waveform = resampler(waveform)
            # Inference only: no autograd graph needed.
            with torch.no_grad():
                embedding = classifier.encode_batch(waveform)
            return embedding.squeeze().cpu().numpy()
        except Exception as e:
            logger.warning(f"SpeechBrain embedding failed: {e}")
    elif model is not None:
        # pyannote Inference
        try:
            embedding = model(audio_path)
            if hasattr(embedding, 'data'):
                return np.array(embedding.data).flatten()
            return np.array(embedding).flatten()
        except Exception as e:
            logger.warning(f"pyannote embedding failed: {e}")

    # MFCC fallback
    return _mfcc_embedding(audio_path)
def _mfcc_embedding(audio_path: str) -> np.ndarray:
    """
    Build a simple MFCC-statistics speaker embedding.

    The file is loaded as 16 kHz mono and 40 MFCCs are computed. The vector
    concatenates four per-coefficient statistics — MFCC mean, MFCC std,
    delta mean and delta-delta mean (4 x 40 = 160 values) — and is
    L2-normalised unless its norm is zero. Returned as float32.
    """
    import librosa

    samples, rate = librosa.load(audio_path, sr=16000, mono=True)
    coeffs = librosa.feature.mfcc(y=samples, sr=rate, n_mfcc=40)
    first_order = librosa.feature.delta(coeffs)
    second_order = librosa.feature.delta(coeffs, order=2)

    stats = [
        coeffs.mean(axis=1),
        coeffs.std(axis=1),
        first_order.mean(axis=1),
        second_order.mean(axis=1),
    ]
    vector = np.concatenate(stats)

    magnitude = np.linalg.norm(vector)
    if magnitude > 0:
        vector = vector / magnitude
    return vector.astype(np.float32)
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine of the angle between *a* and *b* (0.0 if either has zero norm)."""
    len_a = np.linalg.norm(a)
    len_b = np.linalg.norm(b)
    if not (len_a and len_b):
        # Undefined for a zero vector; report "no similarity".
        return 0.0
    return float(np.dot(a, b) / (len_a * len_b))