Commit 2809642 · Parent(s): f541218
Check point 4
app.py CHANGED
@@ -8,6 +8,7 @@ import os
import urllib.request
import torchaudio
from scipy.spatial.distance import cosine
+from scipy.signal import resample
from RealtimeSTT import AudioToTextRecorder
from fastapi import FastAPI, APIRouter
from fastrtc import Stream, AsyncStreamHandler
@@ -419,7 +420,7 @@ class RealtimeSpeakerDiarization:
        # Setup recorder configuration
        recorder_config = {
            'spinner': False,
-           'use_microphone': False, #
+           'use_microphone': False,  # Using FastRTC for audio input
            'model': FINAL_TRANSCRIPTION_MODEL,
            'language': TRANSCRIPTION_LANGUAGE,
            'silero_sensitivity': SILERO_SENSITIVITY,
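For context, a dict like recorder_config is normally unpacked straight into the RealtimeSTT constructor. A minimal sketch of that step, with placeholder values standing in for the module-level constants (the values below are not taken from app.py):

from RealtimeSTT import AudioToTextRecorder

recorder_config = {
    'spinner': False,
    'use_microphone': False,    # audio is pushed in from FastRTC via feed_audio()
    'model': 'base.en',         # placeholder for FINAL_TRANSCRIPTION_MODEL
    'language': 'en',           # placeholder for TRANSCRIPTION_LANGUAGE
    'silero_sensitivity': 0.4,  # placeholder for SILERO_SENSITIVITY
}
recorder = AudioToTextRecorder(**recorder_config)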
@@ -456,8 +457,11 @@ class RealtimeSpeakerDiarization:
    def run_transcription(self):
        """Run the transcription loop"""
        try:
+           logger.info("Starting transcription thread")
            while self.is_running:
-
+               # Just check for final text from recorder, audio is fed externally via FastRTC
+               text = self.recorder.text(self.process_final_text)
+               time.sleep(0.01)  # Small sleep to prevent CPU hogging
        except Exception as e:
            logger.error(f"Transcription error: {e}")
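The loop added here follows a common RealtimeSTT consumption pattern when the microphone is disabled: audio arrives through feed_audio() elsewhere, and text() hands each finalized utterance to a callback. A standalone sketch of that pattern (the names below are illustrative, not from app.py):

import time

def on_final_text(text):
    print(f"Final transcript: {text}")

running = True
while running:
    recorder.text(on_final_text)  # deliver the next finalized utterance to the callback
    time.sleep(0.01)              # brief pause so the loop does not spin the CPU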
@@ -559,14 +563,30 @@ class RealtimeSpeakerDiarization:
            if embedding is not None:
                self.speaker_detector.add_embedding(embedding)

-           # Feed audio to
-           if self.recorder:
-               # Convert float32
-
-
+           # Feed audio to RealtimeSTT recorder
+           if self.recorder and self.is_running:
+               # Convert float32 [-1.0, 1.0] to int16 for RealtimeSTT
+               int16_data = (audio_data * 32768.0).astype(np.int16).tobytes()
+               if sample_rate != 16000:
+                   int16_data = self.resample_audio(int16_data, sample_rate, 16000)
+               self.recorder.feed_audio(int16_data)

        except Exception as e:
            logger.error(f"Error processing audio chunk: {e}")
+
+   def resample_audio(self, audio_bytes, from_rate, to_rate):
+       """Resample audio to target sample rate"""
+       try:
+           audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
+           num_samples = len(audio_np)
+           num_target_samples = int(num_samples * to_rate / from_rate)
+
+           resampled = resample(audio_np, num_target_samples)
+
+           return resampled.astype(np.int16).tobytes()
+       except Exception as e:
+           logger.error(f"Error resampling audio: {e}")
+           return audio_bytes


# FastRTC Audio Handler
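As a standalone illustration of the same conversion path (float32 WebRTC samples down to 16 kHz int16 bytes), assuming a 48 kHz input rate; the helper name is hypothetical and not part of app.py:

import numpy as np
from scipy.signal import resample

def float32_chunk_to_16k_bytes(chunk, in_rate=48000):
    """Hypothetical helper: float32 samples in [-1.0, 1.0] -> 16 kHz int16 bytes."""
    target_len = int(len(chunk) * 16000 / in_rate)
    resampled = resample(chunk, target_len)   # FFT-based resampling from scipy.signal
    clipped = np.clip(resampled, -1.0, 1.0)   # guard against overshoot at the int16 limits
    return (clipped * 32767).astype(np.int16).tobytes()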
@@ -598,7 +618,9 @@ class DiarizationHandler(AsyncStreamHandler):
            if isinstance(audio_data, bytes):
                audio_array = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
            elif isinstance(audio_data, (list, tuple)):
-               audio_array =
+               sample_rate, audio_array = audio_data
+               if isinstance(audio_array, (list, tuple)):
+                   audio_array = np.array(audio_array, dtype=np.float32)
            else:
                audio_array = np.array(audio_data, dtype=np.float32)
@@ -636,18 +658,7 @@ class DiarizationHandler(AsyncStreamHandler):

# Global instances
diarization_system = RealtimeSpeakerDiarization()
-
-# FastAPI setup for FastRTC integration
-app = FastAPI()
-
-# Initialize an empty handler (will be set properly in initialize_system function)
-audio_handler = DiarizationHandler(diarization_system)
-
-# Create FastRTC stream
-stream = Stream(handler=audio_handler)
-
-# Include FastRTC router in FastAPI app
-app.include_router(stream.router, prefix="/stream")
+audio_handler = None

def initialize_system():
    """Initialize the diarization system"""
@@ -656,8 +667,6 @@ def initialize_system():
        success = diarization_system.initialize_models()
        if success:
            audio_handler = DiarizationHandler(diarization_system)
-           # Update the stream's handler
-           stream.handler = audio_handler
            return "✅ System initialized successfully!"
        else:
            return "❌ Failed to initialize system. Check logs for details."
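Because audio_handler is now a module-level name set to None and rebound inside initialize_system(), the function needs a global declaration for the assignment to be visible outside it. A minimal sketch of that pattern (assuming that is the intent; the rest of the function body is not shown in this diff):

audio_handler = None

def initialize_system():
    """Sketch only: rebind the module-level handler once the models are ready."""
    global audio_handler  # without this, the assignment below only creates a local variable
    if diarization_system.initialize_models():
        audio_handler = DiarizationHandler(diarization_system)
        return "✅ System initialized successfully!"
    return "❌ Failed to initialize system. Check logs for details."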
@@ -665,13 +674,6 @@ def initialize_system():
        logger.error(f"Initialization error: {e}")
        return f"❌ Initialization error: {str(e)}"

-# Add startup event to initialize the system
-@app.on_event("startup")
-async def startup_event():
-    logger.info("Initializing diarization system on startup...")
-    result = initialize_system()
-    logger.info(f"Initialization result: {result}")
-
def start_recording():
    """Start recording and transcription"""
    try:
@@ -857,6 +859,9 @@ def create_interface():
    return interface


+# FastAPI setup for FastRTC integration
+app = FastAPI()
+
@app.get("/")
async def root():
    return {"message": "Real-time Speaker Diarization API"}
@@ -898,6 +903,12 @@ async def api_update_settings(threshold: float, max_speakers: int):
    result = update_settings(threshold, max_speakers)
    return {"result": result}

+# FastRTC Stream setup
+if audio_handler:
+    stream = Stream(handler=audio_handler)
+    app.include_router(stream.router, prefix="/stream")
+
+
# Main execution
if __name__ == "__main__":
    import argparse
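Putting the last two hunks together, one way the pieces could be wired up and served; this is a sketch only: uvicorn and port 7860 are assumptions, and the actual __main__ block uses argparse:

import uvicorn

if __name__ == "__main__":
    initialize_system()                 # builds audio_handler if the models load
    if audio_handler:
        stream = Stream(handler=audio_handler)
        app.include_router(stream.router, prefix="/stream")
    uvicorn.run(app, host="0.0.0.0", port=7860)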