Maia/diarization/main.py

68 lines
2.3 KiB
Python

"""
Whisper Diarization API — Enhancement Server
Adds /api/v2 endpoints (diarization + voice enrollment) to the existing Whisper server.
Transparently proxies existing endpoints to the original Whisper server.
Original Whisper server: http://10.100.16.13:5003 (default; override with the WHISPER_ORIGIN environment variable)
This server runs on: port 5060
"""
import logging
import os
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from routers.proxy import router as proxy_router
from routers.v2_transcribe import router as v2_transcribe_router
from routers.v2_speakers import router as v2_speakers_router
from db.database import init_db
# Root logging: INFO and above, each record tagged with time, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Base URL of the original Whisper server. Defaults to the internal host and
# can be overridden with the WHISPER_ORIGIN environment variable.
WHISPER_ORIGIN = os.environ.get("WHISPER_ORIGIN", "http://10.100.16.13:5003")
# The ASGI application object. Interactive API docs are exposed at /docs
# (Swagger UI) and /redoc (ReDoc); the description below is what they show.
app = FastAPI(
    title="Whisper Diarization Enhancement Server",
    version="2.0.0",
    description=(
        "Adds Speaker Diarization + Voice Enrollment on top of the existing Whisper.cpp GPU server. "
        "Existing endpoints (/transcribe, /transcribe-segments, /whisper, /transcribe-text, /health) "
        "are transparently proxied to the original server."
    ),
    redoc_url="/redoc",
    docs_url="/docs",
)
# Browser CORS policy.
# Allowed origins come from CORS_ALLOW_ORIGINS as a comma-separated list
# (e.g. "https://app.example.com,https://admin.example.com"). When the variable
# is unset or contains no non-blank entries, the previous behaviour is kept and
# every origin ("*") is allowed.
# NOTE: FastAPI/Starlette document that allow_origins=["*"] must not be combined
# with allow_credentials=True. Starlette answers such requests by echoing the
# caller's Origin, so any website can make credentialed requests to this API.
# Configure an explicit origin list in any deployment reachable from browsers.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.on_event("startup")
async def startup():
    """Initialise the database and announce the server configuration.

    Registered as a FastAPI "startup" event handler. It calls ``init_db()``
    once, then logs the proxy target and where the v2 endpoints live.
    """
    logger.info("Whisper Diarization Enhancement Server starting...")
    init_db()
    # Remaining status lines are emitted in order only after the DB is ready.
    for message in (
        f"Proxying existing endpoints to: {WHISPER_ORIGIN}",
        "New endpoints available under /api/v2/",
        "Server ready!",
    ):
        logger.info(message)
# Proxy existing Whisper endpoints (NEVER MODIFY THESE)
# These routes forward the original Whisper API (the legacy endpoints listed
# in the app description) and are registered before the v2 routers.
# NOTE(review): Starlette matches routes in registration order, so if
# proxy_router declares a catch-all path it would shadow the /api/v2 routes
# below — confirm it only defines the specific legacy endpoints.
app.include_router(proxy_router, tags=["proxy — original whisper endpoints"])
# New v2 endpoints
# Both v2 routers are mounted under the /api/v2 prefix; the tags only group
# their operations in the generated OpenAPI docs (/docs, /redoc).
app.include_router(v2_transcribe_router, prefix="/api/v2", tags=["v2 — diarization"])
app.include_router(v2_speakers_router, prefix="/api/v2", tags=["v2 — voice enrollment"])
@app.exception_handler(Exception)
async def global_exception_handler(request, exc: Exception):
    """Log any unhandled exception and return a generic 500 response.

    Exception details are deliberately withheld from the client; the full
    traceback is written to the server log instead.

    Args:
        request: The request whose handling raised ``exc``.
        exc: The exception that escaped the route handler.

    Returns:
        A ``JSONResponse`` with status 500 and body
        ``{"detail": "Internal server error"}``.
    """
    # Pass the exception object itself as exc_info: ``exc_info=True`` only
    # works while ``sys.exc_info()`` is populated, whereas the instance always
    # carries its own traceback. Lazy %-args avoid formatting when filtered.
    logger.error(
        "Unhandled exception on %s %s: %s",
        request.method,
        request.url.path,
        exc,
        exc_info=exc,
    )
    # NOTE(review): Starlette runs handlers for Exception/500 in its outermost
    # ServerErrorMiddleware, i.e. outside CORSMiddleware, so this response may
    # lack CORS headers for browser clients — confirm whether that matters.
    return JSONResponse(status_code=500, content={"detail": "Internal server error"})