import gradio as gr
from gradio import wasm_utils
from fastrtc import (
    ReplyOnPause,
    AlgoOptions,
    SileroVadOptions,
    AdditionalOutputs,
    WebRTC,
    get_cloudflare_turn_credentials_async,
    get_cloudflare_turn_credentials,
    # get_hf_turn_credentials,
)
import os
from dotenv import load_dotenv
import time
import numpy as np
import sys
import uuid
import asyncio

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from threading import Lock, Thread
from backend.tts import synthesize_text
from backend.asr import transcribe_audio, transcribe_typhoon
from backend.utils import (
    preprocess_audio,
    is_valid_turn,
    preprocess_audio_simplified,
    calculate_be_together_premium,
    PersonaState,
)
from backend.main import stream_chat_response
from backend.models import LLMFinanceAnalyzer
from pydub import AudioSegment
from backend.utils import get_device

if get_device() == "mps":
    load_dotenv(override=True)

llm_analyzer = LLMFinanceAnalyzer()

PERSONA_MODES = ("BT smart retirement", "BT Unit linked")
DEFAULT_PERSONA_MODE = PERSONA_MODES[0]

initial_persona_state = PersonaState()
initial_persona_dict = initial_persona_state.get_persona()
initial_persona_string = initial_persona_state.get_persona_string()

# First second of the ringing tone played while the call is connecting.
phone_waiting_sound = AudioSegment.from_mp3("frontend/phone-ringing-382734.mp3")[:1000]
sound_samples = np.array(phone_waiting_sound.get_array_of_samples(), dtype=np.int16)
if phone_waiting_sound.channels > 1:
    sound_samples = sound_samples.reshape((-1, phone_waiting_sound.channels)).mean(axis=1)
sound_samples = sound_samples.astype(np.float32) / 32768.0  # Normalize int16 samples to [-1, 1]


def startup(*args):
    """
    ReplyOnPause startup hook.

    Accepts the same inputs as the main response fn, but we only care about the
    persona state, so tolerate extra arguments.
    """
    persona_dict = None
    for arg in args:
        if isinstance(arg, dict) and {"Gender", "Age"}.issubset(arg.keys()):
            persona_dict = arg
            break
    if persona_dict is None:
        persona_dict = initial_persona_state.get_persona()

    print(f"persona:{persona_dict}")
    current_gender = persona_dict.get("Gender", "Female")
    current_age = persona_dict.get("Age")

    yield (phone_waiting_sound.frame_rate, sound_samples)

    STARTUP_MESSAGE = "ฮัลโหล"  # "Hello" greeting spoken when the call connects
    yield from synthesize_text(STARTUP_MESSAGE, gender=current_gender, age=current_age)
    time.sleep(2)
    yield AdditionalOutputs([{"role": "assistant", "content": STARTUP_MESSAGE}])


custom_css = """
/* Overall Gradio page styling */
body {
    /* background-color: #ff69b4; */ /* Hot pink background (disabled) */
    margin: 0;
    padding: 0;
    font-family: sans-serif;
}

/* Title styling */
h1 {
    color: #fff;
    text-shadow: 1px 1px 2px #ff85a2;
    font-size: 2.5em;
    margin-bottom: 20px;
    text-align: center;
}

/* Style the column holding the telephone interface */
.phone-column {
    max-width: 350px !important; /* Limit the width of the phone column */
    margin: 0 auto; /* Center the column */
    border-radius: 20px;
    background-color: #f9cb9c; /* Lighter shade for the telephone interface */
    box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
    padding: 20px;
}

/* Conversation history box styling */
#conversation-history-chatbot {
    background-color: #f9cb9c; /* Lighter shade for the conversation history */
    border: 1px solid #ccc;
    border-radius: 10px;
    padding: 10px;
    box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
}
"""

_user_info_loop = None
_user_info_thread = None
_user_info_loop_lock = Lock()


def _user_info_loop_worker(loop):
    asyncio.set_event_loop(loop)
    loop.run_forever()


def _ensure_user_info_loop():
    global _user_info_loop, _user_info_thread
    with _user_info_loop_lock:
        if _user_info_loop is None or _user_info_loop.is_closed():
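            # Lazily create a dedicated event loop on a daemon thread so the
            # synchronous Gradio callbacks can schedule the analyzer's async LLM calls.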
            _user_info_loop = asyncio.new_event_loop()
            _user_info_thread = Thread(
                target=_user_info_loop_worker, args=(_user_info_loop,), daemon=True
            )
            _user_info_thread.start()
    return _user_info_loop


def _generate_unit_linked_user_info(persona_dict):
    """Run the async LLM call for BT Unit linked persona info on a background loop."""
    loop = _ensure_user_info_loop()
    persona_text = PersonaState.format_persona(persona_dict)
    future = asyncio.run_coroutine_threadsafe(
        llm_analyzer.generate_user_info(persona_text),
        loop,
    )
    return future.result()


def snapshot_history(history):
    """Return a shallow copy of the current chatbot history."""
    return [dict(turn) for turn in history] if history else []


def response(
    audio: tuple[int, np.ndarray] | None,
    conversation_history,
    persona_state_dict,
    session_id: str | None,
    user_info: dict | str | None = None,
    mode: str | None = None,
):
    """
    Handles user audio input, transcribes it, streams LLM text via backend.main,
    and synthesizes chunks to audio while updating the conversation history.
    The persona mode is forwarded to the backend for prompt conditioning.
    """
    if audio:
        print(f"WebRTC input SR: {audio[0]}")
    print("--- Latency Breakdown ---")

    persona_dict = dict(persona_state_dict or initial_persona_state.get_persona())
    selected_mode = mode if mode in PERSONA_MODES else DEFAULT_PERSONA_MODE
    persona_metadata = dict(persona_dict)
    persona_metadata["Mode"] = selected_mode
    current_gender = persona_dict.get("Gender", "Female")
    current_age = persona_dict.get("Age")
    persona_string = PersonaState.format_persona(persona_dict)
    conversation_history = conversation_history or []
    start_time = time.time()

    session_identifier = session_id or ""
    if not session_identifier:
        session_identifier = str(uuid.uuid4())
        print(f"[WARN] Missing session_id; generated temporary session {session_identifier}")

    if not audio or audio[1] is None or not np.any(audio[1]):
        print("No audio input detected; skipping response generation.")
        print("------------------------")
        return

    import soundfile as sf

    sample_rate, audio_array = audio
    try:
        sr, processed_audio = preprocess_audio_simplified((sample_rate, audio_array), target_sr=16000)
        print(sr, processed_audio.dtype, processed_audio.min(), processed_audio.max(), processed_audio.shape)
    except Exception as audio_err:
        print(f"Audio preprocessing failed: {audio_err}")
        print("------------------------")
        return

    silence_duration_s = 0.2
    # Number of samples corresponding to the silence duration at 16 kHz.
    silence_samples = int(16000 * silence_duration_s)
    # Silent segment (zeros); dtype matches the processed audio for compatibility.
    leading_silence = np.zeros(silence_samples, dtype=np.float32)
    # Prepend the silence to the beginning of the processed audio.
    audio_with_padding = np.concatenate([leading_silence, processed_audio])
    print(f"Added {silence_duration_s}s of silence. New shape: {audio_with_padding.shape}")
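    # Persist the padded audio to a temporary WAV file; the ASR call below reads it from disk.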
    file_name = "temp.wav"
    sf.write(file_name, audio_with_padding, sr)

    t0 = time.time()
    transcription = transcribe_typhoon(file_name)
    # transcription = transcribe_audio("debug_processed.wav")
    t_asr = time.time() - t0
    print(f"ASR: {t_asr:.4f}s")

    if not transcription.strip():
        print("No valid transcription; skipping response generation.")
        print("------------------------")
        return

    user_turn = {"role": "user", "content": transcription}
    print(f"User: {transcription}")
    if is_valid_turn(user_turn):
        conversation_history.append(user_turn)
        yield AdditionalOutputs(snapshot_history(conversation_history))

    assistant_turn = {"role": "assistant", "content": ""}
    conversation_history.append(assistant_turn)

    text_buffer = ""
    full_response = ""
    delimiter_count = 0
    n_threshold = 1
    max_n_threshold = 1
    lang = "th"
    chunk_count = 0
    first_chunk_sent = False
    start_llm_stream = time.time()

    try:
        for chunk in stream_chat_response(
            session_identifier,
            transcription,
            persona_string,
            persona_metadata,
            user_info,
            mode,
        ):
            if isinstance(chunk, str):
                text_chunk = chunk
                i = 0
                while i < len(text_chunk):
                    char = text_chunk[i]
                    text_buffer += char
                    full_response += char
                    assistant_turn["content"] = full_response.strip()

                    is_delimiter = False
                    if char == "|":
                        is_delimiter = True
                        delimiter_count += 1
                        # Pull a trailing Thai repetition mark (ๆ) into the current chunk.
                        if i + 1 < len(text_chunk) and text_chunk[i + 1] == 'ๆ':
                            text_buffer += text_chunk[i + 1]
                            full_response += text_chunk[i + 1]
                            i += 1

                    send_now = False
                    if not first_chunk_sent:
                        if is_delimiter and text_buffer.strip():
                            send_now = True
                    else:
                        if delimiter_count >= n_threshold and text_buffer.strip():
                            send_now = True
                            if n_threshold < max_n_threshold:
                                n_threshold += 1

                    if send_now:
                        buffer_to_send = text_buffer.strip()
                        try:
                            # Drop a trailing 'วันที่' and the polite particle 'ค่ะ' before synthesis.
                            if buffer_to_send and buffer_to_send.endswith('วันที่'):
                                buffer_to_send = buffer_to_send[:-len('วันที่')]
                            if buffer_to_send and first_chunk_sent and buffer_to_send.endswith('ค่ะ'):
                                buffer_to_send = buffer_to_send[:-len('ค่ะ')]
                        except Exception:
                            buffer_to_send = buffer_to_send.replace('ค่ะ', '')

                        if buffer_to_send:
                            chunk_count += 1
                            if chunk_count == 1:
                                first_llm_chunk_time = time.time()
                                t_llm_first_token = first_llm_chunk_time - start_llm_stream
                                print(f"LLM TTFC: {t_llm_first_token:.4f}s (Time To First Chunk)")
                            yield from synthesize_text(buffer_to_send, lang=lang, gender=current_gender, age=current_age)
                            first_chunk_sent = True

                        text_buffer = ""
                        delimiter_count = 0
                        yield AdditionalOutputs(snapshot_history(conversation_history))

                    i += 1

        # Flush whatever is left in the buffer after the stream ends (mirrors the in-loop flush).
        if text_buffer.strip():
            buffer_to_send = text_buffer.strip()
            try:
                if buffer_to_send and buffer_to_send.endswith('วันที่'):
                    buffer_to_send = buffer_to_send[:-len('วันที่')]
                if buffer_to_send and first_chunk_sent and buffer_to_send.endswith('ค่ะ'):
                    buffer_to_send = buffer_to_send[:-len('ค่ะ')]
            except Exception:
                buffer_to_send = buffer_to_send.replace('ค่ะ', '')

            if buffer_to_send:
                chunk_count += 1
                if chunk_count == 1:
                    first_llm_chunk_time = time.time()
                    t_llm_first_token = first_llm_chunk_time - start_llm_stream
                    print(f"LLM TTFC: {t_llm_first_token:.4f}s (Time To First Chunk)")
                yield from synthesize_text(buffer_to_send, lang=lang, gender=current_gender, age=current_age)
                first_chunk_sent = True

            text_buffer = ""
            delimiter_count = 0
            yield AdditionalOutputs(snapshot_history(conversation_history))

    except Exception as e:
        print(f"An error occurred during response generation or synthesis: {e}")
        error_message = "ขออภัยค่ะ เกิดข้อผิดพลาดบางอย่าง"  # "Sorry, something went wrong."
        try:
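            # Speak a short Thai apology so the caller hears feedback instead of silence.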
            yield from synthesize_text(error_message, lang=lang, gender=current_gender, age=current_age)
        except Exception as synth_error:
            print(f"Could not synthesize error message: {synth_error}")
        assistant_turn["content"] = (assistant_turn.get("content", "") + f" [Error: {e}]").strip()
        yield AdditionalOutputs(snapshot_history(conversation_history))

    total_latency = time.time() - start_time
    print(f"Total: {total_latency:.4f}s")
    print("------------------------")


# Keep persona-derived mock financial info aligned with the current persona
import random


def get_user_info(current_persona, mode: str = DEFAULT_PERSONA_MODE):
    if not current_persona:
        return "ยังไม่มีข้อมูลผู้ใช้"  # "No user information yet."

    income = current_persona.get("เงินได้สุทธิ")  # net income
    age = current_persona.get("Age")
    gender = current_persona.get("Gender")
    mode_value = mode if mode in PERSONA_MODES else DEFAULT_PERSONA_MODE

    tax_deduct = min(0.15 * income, 300000)
    info_lines = [
        f"Mode: {mode_value}",
        f"เงินได้สุทธิ : {income:,.0f} บาท",
        f"ลดหย่อนภาษีได้ {tax_deduct:,.0f} บาท",
    ]

    if mode_value == "BT Unit linked":
        try:
            generated_info = _generate_unit_linked_user_info(current_persona)
            if generated_info:
                return generated_info
        except Exception as info_err:
            print(f"Failed to generate unit-linked user info: {info_err}")
            return "None"

    has_policy = random.choice(["yes", "no"])
    if has_policy == "yes" and isinstance(age, int):
        # sum_assured = income * round(random.uniform(1.5, 5), 1)
        plan = random.choice(["จ่าย 8 ปี", "จ่ายถึง 60"])  # "pay 8 years" / "pay until age 60"
        final_premium = random.randrange(20, int(tax_deduct * 1.5 / 1000)) * 1000
        # final_premium = calculate_be_together_premium(age, sum_assured, plan, gender)
        premium_text = final_premium if isinstance(final_premium, str) else f"{final_premium:,.2f} บาท"
        if age > 35:
            purchase_year = random.randint(1, 11)
        elif age > 25:
            purchase_year = random.randint(1, 2)
        else:
            purchase_year = 1
        info_lines.extend(
            [
                "ประวัติการถือประกัน :",  # insurance policy history
                "BT smart retirement:",
                f" - แผน: {plan}",
                # f" - เงินเอาประกัน: {sum_assured:,.0f} บาท",
                f" - ถือมาแล้ว: {purchase_year} ปี",
                f" - เบี้ยประกันต่อปี: {premium_text}",
            ]
        )
        return "\n".join(info_lines)

    info_lines.extend(
        [
            "ประวัติการถือประกัน:",
            " - ไม่เคยถือ",  # never held a policy
        ]
    )
    return "\n".join(info_lines)


def randomize_persona(selected_mode: str | None = None):
    """Generate a fresh persona for the selected mode and reset the visible chat history."""
    mode = selected_mode if selected_mode in PERSONA_MODES else DEFAULT_PERSONA_MODE
    new_persona_state = PersonaState()
    persona_dict = new_persona_state.get_persona()
    persona_string = PersonaState.format_persona(persona_dict)
    user_info_text = get_user_info(persona_dict, mode)
    new_session_id = str(uuid.uuid4())
    return (
        persona_dict,
        persona_string,
        [],
        new_session_id,
        user_info_text,
        user_info_text,
        new_session_id,
    )


def update_user_info_for_mode(selected_mode: str | None, persona_state_dict):
    """Update user info when switching modes without changing the persona."""
    mode = selected_mode if selected_mode in PERSONA_MODES else DEFAULT_PERSONA_MODE
    persona_dict = dict(persona_state_dict or initial_persona_state.get_persona())
    user_info_text = get_user_info(persona_dict, mode)
    return user_info_text, user_info_text


def initialize_session_id():
    """Create a new session identifier for syncing backend history."""
    session_id = str(uuid.uuid4())
    return session_id, session_id


async def get_credentials():
    return await get_cloudflare_turn_credentials_async(hf_token=os.getenv('HF_TOKEN'))


with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="orange", secondary_hue="orange")) as demo:
    current_persona = gr.State(value=initial_persona_dict)
    session_state = gr.State(value=None)
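    # Persona-derived mock financial info, shared by the response inputs and the display textbox.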
    user_info = gr.State(value=get_user_info(initial_persona_dict, DEFAULT_PERSONA_MODE))

    gr.HTML("""
        <h1>AIA Voicebot Demo</h1>
    """)

    with gr.Row():
        with gr.Column(scale=1, elem_classes=["phone-column"]):
            audio = WebRTC(
                mode="send-receive",
                modality="audio",
                track_constraints={
                    "echoCancellation": True,
                    "noiseSuppression": {"exact": True},
                    "autoGainControl": {"exact": True},
                },
                rtc_configuration=get_credentials,
                server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000),
                icon="https://i.pinimg.com/originals/0c/67/5a/0c675a8e1061478d2b7b21b330093444.gif",
                icon_button_color="#17dbaa",
                pulse_color="#b0f83b",
                button_labels={"start": "Call", "stop": "Hang up", "waiting": "Connecting…"},
                icon_radius=45,
                height="650px",
                width="100%",
                container=False,
                elem_id="phone-call-webrtc",
            )
        with gr.Column(scale=1):
            persona_display = gr.Textbox(
                label="User persona",
                lines=10,
                interactive=False,
                value=initial_persona_string,
                elem_classes=["persona-box"],
            )
            mode_selector = gr.Dropdown(
                label="Persona mode",
                choices=PERSONA_MODES,
                value=DEFAULT_PERSONA_MODE,
                interactive=True,
            )
            random_persona_btn = gr.Button(
                "Random persona",
                variant="secondary",
            )
            user_info_display = gr.Textbox(
                label="User INFO",
                lines=10,
                interactive=False,
                value=user_info.value,
                elem_classes=["persona-box"],
            )
            session_display = gr.Textbox(
                label="Session ID",
                interactive=False,
                value="Initializing…",
                elem_classes=["persona-box"],
            )
        with gr.Column(scale=2):
            conversation_history = gr.Chatbot(
                label="Conversation History",
                type="messages",
                value=[],
                height="675px",
                resizable=True,
                avatar_images=(None, "https://i.pinimg.com/originals/0c/67/5a/0c675a8e1061478d2b7b21b330093444.gif"),
            )

    random_persona_btn.click(
        fn=randomize_persona,
        inputs=[mode_selector],
        outputs=[current_persona, persona_display, conversation_history, session_state, user_info, user_info_display, session_display],
        show_progress="hidden",
        queue=False,
    )

    mode_selector.change(
        fn=update_user_info_for_mode,
        inputs=[mode_selector, current_persona],
        outputs=[user_info, user_info_display],
        show_progress="hidden",
        queue=False,
    )

    demo.load(
        fn=initialize_session_id,
        inputs=None,
        outputs=[session_state, session_display],
        queue=False,
    )

    gr.DeepLinkButton()

    audio.stream(
        fn=ReplyOnPause(
            response,
            algo_options=AlgoOptions(
                audio_chunk_duration=1.35,
                started_talking_threshold=0.35,
                speech_threshold=0.2,
            ),
            model_options=SileroVadOptions(
                threshold=0.65,
                min_speech_duration_ms=200,
                max_speech_duration_s=float("inf"),
                min_silence_duration_ms=1200,
                speech_pad_ms=300,
            ),
            can_interrupt=False,
            startup_fn=startup,
        ),
        inputs=[audio, conversation_history, current_persona, session_state, user_info, mode_selector],
        outputs=[audio],
        concurrency_limit=1000,
        time_limit=8192,
    )

    audio.on_additional_outputs(
        lambda history: history,
        outputs=[conversation_history],
        queue=True,
        show_progress="hidden",
    )

demo.queue(default_concurrency_limit=1000)
demo.launch(debug=True, show_error=True, share=True)