Spaces:
Sleeping
Sleeping
fix: update tax deduction limit and adjust premium calculation logic in get_user_info function
3b0c9f6
| import gradio as gr | |
| from gradio import wasm_utils | |
| from fastrtc import ReplyOnPause, AlgoOptions, SileroVadOptions, AdditionalOutputs, WebRTC, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials #get_hf_turn_credentials, | |
| import os | |
| from dotenv import load_dotenv | |
| import time | |
| import numpy as np | |
| import sys | |
| import uuid | |
| import asyncio | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| from threading import Lock, Thread | |
| from backend.tts import synthesize_text | |
| from backend.asr import transcribe_audio, transcribe_typhoon | |
| from backend.utils import preprocess_audio, is_valid_turn, preprocess_audio_simplified,calculate_be_together_premium, PersonaState | |
| from backend.main import stream_chat_response | |
| from backend.models import LLMFinanceAnalyzer | |
| from pydub import AudioSegment | |
| from backend.utils import get_device | |
# Load variables from a .env file (overriding already-set environment
# variables) only when the detected device is "mps".
# NOTE(review): presumably this targets local Apple-silicon development, with
# other deployments getting their environment elsewhere -- confirm.
if get_device() == "mps":
    load_dotenv(override=True)

# Shared analyzer instance; used below to generate "BT Unit linked" user info.
llm_analyzer = LLMFinanceAnalyzer()

# Supported persona modes; the first entry is the default.
PERSONA_MODES = ("BT smart retirement", "BT Unit linked")
DEFAULT_PERSONA_MODE = PERSONA_MODES[0]

# Persona created once at import time and used as the initial UI state and as
# the fallback whenever a handler receives no persona.
initial_persona_state = PersonaState()
initial_persona_dict = initial_persona_state.get_persona()
initial_persona_string = initial_persona_state.get_persona_string()

# Ring tone played when a call starts: the first 1000 ms of the MP3
# (pydub slices are expressed in milliseconds).
phone_waiting_sound = AudioSegment.from_mp3("frontend/phone-ringing-382734.mp3")[:1000]
sound_samples = np.array(phone_waiting_sound.get_array_of_samples(), dtype=np.int16)
if phone_waiting_sound.channels > 1:
    # Interleaved multi-channel samples -> one row per frame, averaged to mono.
    sound_samples = sound_samples.reshape((-1, phone_waiting_sound.channels)).mean(axis=1)
sound_samples = sound_samples.astype(np.float32) / 32768.0  # Normalize int16 samples to [-1, 1)
def startup(*args):
    """
    Startup hook handed to ReplyOnPause: play a ring tone, then a spoken greeting.

    The hook receives the same positional inputs as the main response handler.
    The first dict argument that contains both "Gender" and "Age" keys is used
    as the persona; if none is found, the module-level initial persona is used.

    Yields the ring-tone audio as ``(frame_rate, samples)``, then everything
    ``synthesize_text`` produces for the greeting, and finally (after a
    two-second pause) an ``AdditionalOutputs`` carrying the greeting as the
    first assistant message of the chat history.
    """
    greeting = "ฮัลโหล"

    # Pick the first argument that looks like a persona dict, if any.
    persona = next(
        (
            candidate
            for candidate in args
            if isinstance(candidate, dict) and {"Gender", "Age"}.issubset(candidate.keys())
        ),
        None,
    )
    if persona is None:
        persona = initial_persona_state.get_persona()
    print(f"persona:{persona}")

    voice_gender = persona.get("Gender", "Female")
    voice_age = persona.get("Age")

    yield (phone_waiting_sound.frame_rate, sound_samples)
    yield from synthesize_text(greeting, gender=voice_gender, age=voice_age)
    time.sleep(2)
    yield AdditionalOutputs([{"role": "assistant", "content": greeting}])
# Stylesheet passed to gr.Blocks(css=...): page and title styling, the
# `.phone-column` class applied to the WebRTC column, and a rule for
# `#conversation-history-chatbot`.
# NOTE(review): no component visible in this file sets
# elem_id="conversation-history-chatbot", and the `persona-box` class used by
# the textboxes has no rule here -- confirm whether either is intentional.
# NOTE(review): the colour descriptions inside the CSS comments ("hot pink",
# "lighter pink") may be out of date: the body background is commented out and
# the panels use #f9cb9c.
custom_css = """
/* Overall Gradio page styling: hot pink background */
body {
/* background-color: #ff69b4; /* Hot pink */
margin: 0;
padding: 0;
font-family: sans-serif;}
/* Title styling */
h1 {
color: #fff;
text-shadow: 1px 1px 2px #ff85a2;
font-size: 2.5em;
margin-bottom: 20px;
text-align: center;
}
/* Style the column holding the telephone interface */
.phone-column {
max-width: 350px !important; /* Limit the width of the phone column */
margin: 0 auto; /* Center the column */
border-radius: 20px;
background-color: #f9cb9c; /* Lighter pink for telephone interface */
box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
padding: 20px;
}
/* Conversation history box styling */
#conversation-history-chatbot {
background-color: #f9cb9c; /* Lighter pink for conversation history */
border: 1px solid #ccc;
border-radius: 10px;
padding: 10px;
box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
}
"""
# Background asyncio machinery for the user-info LLM call: the lazily created
# event loop, the daemon thread that runs it, and the lock guarding their
# creation. All three are managed by _ensure_user_info_loop().
_user_info_loop = None
_user_info_thread = None
_user_info_loop_lock = Lock()
| def _user_info_loop_worker(loop): | |
| asyncio.set_event_loop(loop) | |
| loop.run_forever() | |
def _ensure_user_info_loop():
    """Return the shared background event loop, starting it on first use.

    Under the module lock, a new loop and a daemon thread running it are
    created whenever no loop exists yet or the previous one has been closed.
    """
    global _user_info_loop, _user_info_thread
    with _user_info_loop_lock:
        loop_is_usable = _user_info_loop is not None and not _user_info_loop.is_closed()
        if not loop_is_usable:
            _user_info_loop = asyncio.new_event_loop()
            worker = Thread(
                target=_user_info_loop_worker,
                args=(_user_info_loop,),
                daemon=True,
            )
            _user_info_thread = worker
            worker.start()
        return _user_info_loop
def _generate_unit_linked_user_info(persona_dict):
    """Generate "BT Unit linked" user info via the LLM and return its result.

    The persona is formatted to text, the analyzer's async ``generate_user_info``
    coroutine is submitted to the shared background event loop, and the calling
    thread blocks (no timeout) until that coroutine finishes. Exceptions raised
    by the coroutine propagate to the caller.
    """
    background_loop = _ensure_user_info_loop()
    persona_text = PersonaState.format_persona(persona_dict)
    pending_call = llm_analyzer.generate_user_info(persona_text)
    submitted = asyncio.run_coroutine_threadsafe(pending_call, background_loop)
    return submitted.result()
def snapshot_history(history):
    """Return a new list holding a shallow ``dict`` copy of every turn.

    An empty or missing history yields an empty list.
    """
    if not history:
        return []
    return list(map(dict, history))
| def response( | |
| audio: tuple[int, np.ndarray] | None, | |
| conversation_history, | |
| persona_state_dict, | |
| session_id: str | None, | |
| user_info: dict | str | None = None, | |
| mode: str | None = None, | |
| ): | |
| """ | |
| Handles user audio input, transcribes it, streams LLM text via backend.main, | |
| and synthesizes chunks to audio while updating the conversation history. | |
| The persona mode is forwarded to the backend for prompt conditioning. | |
| """ | |
| print(f"WebRTC input SR: {audio[0]}") | |
| print(f"--- Latency Breakdown ---") | |
| # # stage_asr = "normal" #["normal,"gemini"] | |
| # print('-----------------------------') | |
| # print(f"Initial conver:{conversation_history}") | |
| # print('-----------------------------') | |
| persona_dict = dict(persona_state_dict or initial_persona_state.get_persona()) | |
| selected_mode = mode if mode in PERSONA_MODES else DEFAULT_PERSONA_MODE | |
| persona_metadata = dict(persona_dict) | |
| persona_metadata["Mode"] = selected_mode | |
| current_gender = persona_dict.get("Gender", "Female") | |
| current_age = persona_dict.get("Age") | |
| persona_string = PersonaState.format_persona(persona_dict) | |
| conversation_history = conversation_history or [] | |
| start_time = time.time() | |
| session_identifier = session_id or "" | |
| if not session_identifier: | |
| session_identifier = str(uuid.uuid4()) | |
| print(f"[WARN] Missing session_id; generated temporary session {session_identifier}") | |
| if conversation_history is None: | |
| conversation_history = [] | |
| if not audio or audio[1] is None or not np.any(audio[1]): | |
| print("No audio input detected; skipping response generation.") | |
| print(f"------------------------") | |
| return | |
| import soundfile as sf | |
| sample_rate, audio_array = audio | |
| try: | |
| sr , processed_audio = preprocess_audio_simplified((sample_rate, audio_array), target_sr=16000) | |
| print(sr, processed_audio.dtype, processed_audio.min(), processed_audio.max(), processed_audio.shape) | |
| except Exception as audio_err: | |
| print(f"Audio preprocessing failed: {audio_err}") | |
| print(f"------------------------") | |
| return | |
| silence_duration_s = 0.2 | |
| # Calculate the number of samples corresponding to the silence duration | |
| silence_samples = int(16000 * silence_duration_s) | |
| # Create a silent audio segment (an array of zeros) | |
| # Ensure the dtype matches your processed audio for compatibility | |
| leading_silence = np.zeros(silence_samples, dtype=np.float32) | |
| # Prepend the silence to the beginning of your processed audio | |
| audio_with_padding = np.concatenate([leading_silence, processed_audio]) | |
| print(f"Added {silence_duration_s}s of silence. New shape: {audio_with_padding.shape}") | |
| file_name = "temp.wav" | |
| sf.write(file_name, audio_with_padding, sr) | |
| t0 = time.time() | |
| transcription = transcribe_typhoon(file_name) | |
| # transcription = transcribe_audio( "debug_processed.wav") | |
| t_asr = time.time() - t0 | |
| print(f"ASR: {t_asr:.4f}s") | |
| if not transcription.strip(): | |
| print("No valid transcription; skipping response generation.") | |
| print(f"------------------------") | |
| return | |
| user_turn = {"role": "user", "content": transcription} | |
| print(f"User: {transcription}") | |
| if is_valid_turn(user_turn): | |
| conversation_history.append(user_turn) | |
| yield AdditionalOutputs(snapshot_history(conversation_history)) | |
| # print("Conversation history:", conversation_history) | |
| assistant_turn = {"role": "assistant", "content": ""} | |
| conversation_history.append(assistant_turn) | |
| text_buffer = "" | |
| full_response = "" | |
| delimiter_count = 0 | |
| n_threshold = 1 | |
| max_n_threshold = 1 | |
| lang = "th" | |
| chunk_count = 0 | |
| first_chunk_sent = False | |
| start_llm_stream = time.time() | |
| try: | |
| for chunk in stream_chat_response( | |
| session_identifier, | |
| transcription, | |
| persona_string, | |
| persona_metadata, | |
| user_info, | |
| mode | |
| ): | |
| # print(f"LLM chunk: {text_chunk}") | |
| if isinstance(chunk, str): | |
| text_chunk = chunk | |
| i = 0 | |
| while i < len(text_chunk): | |
| char = text_chunk[i] | |
| text_buffer += char | |
| full_response += char | |
| assistant_turn["content"] = full_response.strip() | |
| is_delimiter = False | |
| # if char in {' ', '\n'}: | |
| if char == "|": | |
| is_delimiter = True | |
| delimiter_count += 1 | |
| if i + 1 < len(text_chunk) and text_chunk[i + 1] == 'ๆ': | |
| text_buffer += text_chunk[i + 1] | |
| full_response += text_chunk[i + 1] | |
| i += 1 | |
| send_now = False | |
| if not first_chunk_sent: | |
| if is_delimiter and text_buffer.strip(): | |
| send_now = True | |
| else: | |
| if delimiter_count >= n_threshold and text_buffer.strip(): | |
| send_now = True | |
| if n_threshold < max_n_threshold: | |
| n_threshold += 1 | |
| if send_now: | |
| buffer_to_send = text_buffer.strip() | |
| try: | |
| if buffer_to_send and buffer_to_send.endswith('วันที่'): | |
| buffer_to_send = buffer_to_send[:-len('วันที่')] | |
| if buffer_to_send and first_chunk_sent and buffer_to_send.endswith('ค่ะ'): | |
| buffer_to_send = buffer_to_send[:-len('ค่ะ')] | |
| except Exception: | |
| buffer_to_send = buffer_to_send.replace('ค่ะ', '') | |
| if buffer_to_send: | |
| chunk_count += 1 | |
| if chunk_count == 1: | |
| first_llm_chunk_time = time.time() | |
| t_llm_first_token = first_llm_chunk_time - start_llm_stream | |
| print(f"LLM TTFC: {t_llm_first_token:.4f}s (Time To First Chunk)") | |
| yield from synthesize_text(buffer_to_send, lang=lang, gender=current_gender, age = current_age) | |
| first_chunk_sent = True | |
| text_buffer = "" | |
| delimiter_count = 0 | |
| yield AdditionalOutputs(snapshot_history(conversation_history)) | |
| i += 1 | |
| if text_buffer.strip(): | |
| buffer_to_send = text_buffer.strip() | |
| try: | |
| if buffer_to_send and buffer_to_send.endswith('วันที่'): | |
| buffer_to_send = buffer_to_send[:-len('วันที่')] | |
| if buffer_to_send and first_chunk_sent and buffer_to_send.endswith('ค่ะ'): | |
| buffer_to_send = buffer_to_send[:-len('ค่ะ')] | |
| except Exception: | |
| buffer_to_send = buffer_to_send.replace('ค่ะ', '') | |
| if buffer_to_send: | |
| chunk_count += 1 | |
| if chunk_count == 1: | |
| first_llm_chunk_time = time.time() | |
| t_llm_first_token = first_llm_chunk_time - start_llm_stream | |
| print(f"LLM TTFC: {t_llm_first_token:.4f}s (Time To First Chunk)") | |
| yield from synthesize_text(buffer_to_send, lang=lang, gender=current_gender, age = current_age) | |
| first_chunk_sent = True | |
| text_buffer = "" | |
| delimiter_count = 0 | |
| yield AdditionalOutputs(snapshot_history(conversation_history)) | |
| except Exception as e: | |
| print(f"An error occurred during response generation or synthesis: {e}") | |
| error_message = "ขออภัยค่ะ เกิดข้อผิดพลาดบางอย่าง" | |
| try: | |
| yield from synthesize_text(error_message, lang=lang, gender=current_gender, age = current_age) | |
| except Exception as synth_error: | |
| print(f"Could not synthesize error message: {synth_error}") | |
| assistant_turn["content"] = (assistant_turn.get("content", "") + f" [Error: {e}]").strip() | |
| yield AdditionalOutputs(snapshot_history(conversation_history)) | |
| total_latency = time.time() - start_time | |
| print(f"Total: {total_latency:.4f}s") | |
| print(f"------------------------") | |
| # Keep persona-derived mock financial info aligned with the current persona | |
| import random | |
def get_user_info(current_persona, mode: str = DEFAULT_PERSONA_MODE):
    """Build the mock financial "user info" text for a persona.

    Args:
        current_persona: Persona dict; reads "เงินได้สุทธิ" (net income) and "Age".
        mode: Persona mode; anything outside ``PERSONA_MODES`` falls back to the
            default mode.

    Returns:
        A multi-line string with the mode, net income and tax deduction
        (15% of income, capped at 300,000), followed by either a randomly
        generated "BT smart retirement" policy history or a "never held"
        entry. In "BT Unit linked" mode the LLM-generated text is returned
        instead when available, and the literal string "None" is returned if
        that generation fails. A fixed placeholder message is returned when
        no persona is given.
    """
    if not current_persona:
        return "ยังไม่มีข้อมูลผู้ใช้"

    # A missing income is treated as 0 instead of crashing on ``0.15 * None``.
    income = current_persona.get("เงินได้สุทธิ") or 0
    age = current_persona.get("Age")
    mode_value = mode if mode in PERSONA_MODES else DEFAULT_PERSONA_MODE
    tax_deduct = min(0.15 * income, 300000)
    info_lines = [
        f"Mode: {mode_value}",
        f"เงินได้สุทธิ : {income:,.0f} บาท",
        f"ลดหย่อนภาษีได้ {tax_deduct:,.0f} บาท",
    ]

    if mode_value == "BT Unit linked":
        try:
            generated_info = _generate_unit_linked_user_info(current_persona)
        except Exception as llm_error:
            # Keep the established "None" fallback value, but surface the cause
            # and no longer swallow KeyboardInterrupt/SystemExit.
            print(f"Unit linked user info generation failed: {llm_error}")
            return "None"
        if generated_info:
            return generated_info
        # Empty LLM output falls through to the mock policy history below.

    has_policy = random.choice(["yes", "no"])
    if has_policy == "yes" and isinstance(age, int):
        plan = random.choice(["จ่าย 8 ปี", "จ่ายถึง 60"])
        # Premium in whole thousands, from 20,000 up to 1.5x the deduction.
        # The upper bound is clamped so low incomes cannot produce an empty
        # range, which makes random.randrange raise ValueError.
        premium_upper = max(int(tax_deduct * 1.5 / 1000), 21)
        final_premium = random.randrange(20, premium_upper) * 1000
        premium_text = f"{final_premium:,.2f} บาท"

        # Older personas may have held the policy for longer.
        if age > 35:
            purchase_year = random.randint(1, 11)
        elif age > 25:
            purchase_year = random.randint(1, 2)
        else:
            purchase_year = 1

        info_lines.extend(
            [
                "ประวัติการถือประกัน :",
                "BT smart retirement:",
                f" - แผน: {plan}",
                f" - ถือมาแล้ว: {purchase_year} ปี",
                f" - เบี้ยประกันต่อปี: {premium_text}",
            ]
        )
        return "\n".join(info_lines)

    info_lines.extend(
        [
            "ประวัติการถือประกัน:",
            " - ไม่เคยถือ",
        ]
    )
    return "\n".join(info_lines)
def randomize_persona(selected_mode: str | None = None):
    """Create a new random persona and session, clearing the visible chat.

    An unknown or missing mode falls back to the default mode.

    Returns a 7-tuple: persona dict, formatted persona text, an empty chat
    history, the new session id, the user-info text (twice: once for state,
    once for display) and the session id again for display.
    """
    if selected_mode in PERSONA_MODES:
        mode = selected_mode
    else:
        mode = DEFAULT_PERSONA_MODE

    fresh_persona = PersonaState().get_persona()
    persona_text = PersonaState.format_persona(fresh_persona)
    info_text = get_user_info(fresh_persona, mode)
    session_id = str(uuid.uuid4())
    cleared_history = []

    return (
        fresh_persona,
        persona_text,
        cleared_history,
        session_id,
        info_text,
        info_text,
        session_id,
    )
def update_user_info_for_mode(selected_mode: str | None, persona_state_dict):
    """Recompute the user-info text for a newly selected mode, keeping the persona.

    An unknown mode falls back to the default; a missing persona falls back to
    the module-level initial persona. The text is returned twice (once for the
    state value, once for the display textbox).
    """
    if selected_mode in PERSONA_MODES:
        mode = selected_mode
    else:
        mode = DEFAULT_PERSONA_MODE
    persona_copy = dict(persona_state_dict or initial_persona_state.get_persona())
    info_text = get_user_info(persona_copy, mode)
    return info_text, info_text
def initialize_session_id():
    """Generate a fresh UUID4 session id, returned twice (state value and display value)."""
    new_id = str(uuid.uuid4())
    return (new_id, new_id)
async def get_credentials():
    """Fetch Cloudflare TURN credentials, authenticating with the HF_TOKEN environment variable."""
    hf_token = os.getenv('HF_TOKEN')
    credentials = await get_cloudflare_turn_credentials_async(hf_token=hf_token)
    return credentials
# Gradio UI: a phone-style WebRTC call column, a persona/session column and a
# chat-history column, wired to the handlers defined above.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="orange", secondary_hue="orange")) as demo:
    # State: persona dict, backend session id (filled in on page load), and the
    # user-info text computed once at import time for the initial persona.
    current_persona = gr.State(value=initial_persona_dict)
    session_state = gr.State(value=None)
    user_info = gr.State(value=get_user_info(initial_persona_dict, DEFAULT_PERSONA_MODE))
    gr.HTML("""<h1 style='text-align: center'>AIA Voicebot Demo</h1>""")
    with gr.Row():
        # Left column: the call widget (bidirectional audio over WebRTC).
        with gr.Column(scale=1, elem_classes=["phone-column"]):
            audio = WebRTC(
                mode="send-receive",
                modality="audio",
                track_constraints={
                    "echoCancellation": True,
                    "noiseSuppression": {"exact": True},
                    "autoGainControl": {"exact": True}
                },
                # Client credentials are fetched through the async callable;
                # the server-side configuration is fetched once, here.
                rtc_configuration=get_credentials,
                server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000),
                icon="https://i.pinimg.com/originals/0c/67/5a/0c675a8e1061478d2b7b21b330093444.gif",
                icon_button_color="#17dbaa",
                pulse_color="#b0f83b",
                button_labels={"start": "Call", "stop": "Hang up", "waiting": "Connecting…"},
                icon_radius=45,
                height="650px",
                width="100%",
                container=False,
                elem_id="phone-call-webrtc"
            )
        # Middle column: persona text, mode selector, user info and session id.
        with gr.Column(scale=1):
            persona_display = gr.Textbox(
                label="User persona",
                lines=10,
                interactive=False,
                value=initial_persona_string,
                elem_classes=["persona-box"]
            )
            mode_selector = gr.Dropdown(
                label="Persona mode",
                choices=PERSONA_MODES,
                value=DEFAULT_PERSONA_MODE,
                interactive=True,
            )
            random_persona_btn = gr.Button(
                "Random persona",
                variant="secondary",
            )
            user_info_display = gr.Textbox(
                label="User INFO",
                lines=10,
                interactive=False,
                value=user_info.value,
                elem_classes=["persona-box"]
            )
            session_display = gr.Textbox(
                label="Session ID",
                interactive=False,
                value="Initializing…",
                elem_classes=["persona-box"]
            )
        # Right column: the running conversation.
        with gr.Column(scale=2):
            conversation_history = gr.Chatbot(
                label="Conversation History",
                type="messages",
                value=[],
                height="675px",
                resizable=True,
                avatar_images=(None, "https://i.pinimg.com/originals/0c/67/5a/0c675a8e1061478d2b7b21b330093444.gif"),
            )
    # New persona: replaces persona, chat history, session id and user info.
    # The outputs list must stay in the same order as randomize_persona's tuple.
    random_persona_btn.click(
        fn=randomize_persona,
        inputs=[mode_selector],
        outputs=[current_persona, persona_display, conversation_history, session_state, user_info, user_info_display, session_display],
        show_progress="hidden",
        queue=False,
    )
    # Mode switch: recompute user info only, keeping the current persona.
    mode_selector.change(
        fn=update_user_info_for_mode,
        inputs=[mode_selector, current_persona],
        outputs=[user_info, user_info_display],
        show_progress="hidden",
        queue=False,
    )
    # Assign a session id when the page loads.
    demo.load(
        fn=initialize_session_id,
        inputs=None,
        outputs=[session_state, session_display],
        queue=False,
    )
    gr.DeepLinkButton()
    # Audio stream: ReplyOnPause calls `response` with the inputs listed below
    # (in this order) and runs `startup` as its startup hook.
    audio.stream(
        fn=ReplyOnPause(
            response,
            algo_options=AlgoOptions(
                audio_chunk_duration=1.35,
                started_talking_threshold=0.35,
                speech_threshold=0.2
            ),
            model_options=SileroVadOptions(
                threshold=0.65,
                min_speech_duration_ms=200,
                max_speech_duration_s=float("inf"),
                min_silence_duration_ms=1200,
                speech_pad_ms=300
            ),
            can_interrupt=False,
            startup_fn=startup,
        ),
        inputs=[audio, conversation_history, current_persona, session_state, user_info, mode_selector],
        outputs=[audio],
        concurrency_limit=1000,
        time_limit=8192
    )
    # History snapshots emitted as AdditionalOutputs are written to the Chatbot.
    audio.on_additional_outputs(
        lambda history: history,
        outputs=[conversation_history],
        queue=True,
        show_progress="hidden"
    )
demo.queue(default_concurrency_limit=1000)
# NOTE(review): share=True requests a public share link; confirm it is wanted
# in the hosted deployment.
demo.launch(debug=True, show_error=True, share=True)