import gradio as gr
from fastrtc import ReplyOnPause, AlgoOptions, SileroVadOptions, AdditionalOutputs, WebRTC, get_cloudflare_turn_credentials_async, get_cloudflare_turn_credentials #get_hf_turn_credentials,
import os
from dotenv import load_dotenv
import time
import numpy as np
import sys
import uuid
import random
import asyncio
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from threading import Lock, Thread
from backend.tts import synthesize_text
from backend.asr import transcribe_audio, transcribe_typhoon
from backend.utils import preprocess_audio, is_valid_turn, preprocess_audio_simplified, calculate_be_together_premium, PersonaState
from backend.main import stream_chat_response
from backend.models import LLMFinanceAnalyzer
from pydub import AudioSegment
from backend.utils import get_device
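# Load a local .env only on an Apple Silicon dev machine (get_device() == "mps");
# deployed environments are expected to provide the env vars directly.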
if get_device() == "mps":
load_dotenv(override=True)
llm_analyzer = LLMFinanceAnalyzer()
PERSONA_MODES = ("BT smart retirement", "BT Unit linked")
DEFAULT_PERSONA_MODE = PERSONA_MODES[0]
initial_persona_state = PersonaState()
initial_persona_dict = initial_persona_state.get_persona()
initial_persona_string = initial_persona_state.get_persona_string()
phone_waiting_sound = AudioSegment.from_mp3("frontend/phone-ringing-382734.mp3")[:1000]
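# Downmix the ringing tone to mono and convert to float32 so fastrtc can stream it.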
sound_samples = np.array(phone_waiting_sound.get_array_of_samples(), dtype=np.int16)
if phone_waiting_sound.channels > 1:
sound_samples = sound_samples.reshape((-1, phone_waiting_sound.channels)).mean(axis=1)
sound_samples = sound_samples.astype(np.float32) / 32768.0  # Normalize int16 samples to [-1, 1]
def startup(*args):
"""
ReplyOnPause startup hook. Accepts the same inputs as the main response fn,
but we only care about the persona state so tolerate extra arguments.
"""
persona_dict = None
for arg in args:
if isinstance(arg, dict) and {"Gender", "Age"}.issubset(arg.keys()):
persona_dict = arg
break
if persona_dict is None:
persona_dict = initial_persona_state.get_persona()
print(f"persona:{persona_dict}")
current_gender = persona_dict.get("Gender", "Female")
current_age = persona_dict.get("Age")
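    # Play ~1 second of ringing, speak the greeting, then push it into the chat history.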
yield (phone_waiting_sound.frame_rate, sound_samples)
STARTUP_MESSAGE = "ฮัลโหล"
yield from synthesize_text(STARTUP_MESSAGE, gender=current_gender, age = current_age)
time.sleep(2)
yield AdditionalOutputs([{"role": "assistant", "content": STARTUP_MESSAGE}])
custom_css = """
/* Overall Gradio page styling */
body {
  /* background-color: #ff69b4; */ /* hot pink background (disabled) */
margin: 0;
padding: 0;
font-family: sans-serif;}
/* Title styling */
h1 {
color: #fff;
text-shadow: 1px 1px 2px #ff85a2;
font-size: 2.5em;
margin-bottom: 20px;
text-align: center;
}
/* Style the column holding the telephone interface */
.phone-column {
max-width: 350px !important; /* Limit the width of the phone column */
margin: 0 auto; /* Center the column */
border-radius: 20px;
    background-color: #f9cb9c; /* Light peach for the telephone interface */
box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
padding: 20px;
}
/* Conversation history box styling */
#conversation-history-chatbot {
    background-color: #f9cb9c; /* Light peach for the conversation history */
border: 1px solid #ccc;
border-radius: 10px;
padding: 10px;
box-shadow: 0 0 15px rgba(0, 0, 0, 0.2);
}
"""
_user_info_loop = None
_user_info_thread = None
_user_info_loop_lock = Lock()
def _user_info_loop_worker(loop):
asyncio.set_event_loop(loop)
loop.run_forever()
def _ensure_user_info_loop():
    """Create the shared background event loop on first use and reuse it afterwards."""
    global _user_info_loop, _user_info_thread
with _user_info_loop_lock:
if _user_info_loop is None or _user_info_loop.is_closed():
_user_info_loop = asyncio.new_event_loop()
_user_info_thread = Thread(target=_user_info_loop_worker, args=(_user_info_loop,), daemon=True)
_user_info_thread.start()
return _user_info_loop
def _generate_unit_linked_user_info(persona_dict):
"""Run the async LLM call for BT Unit linked persona info on a background loop."""
loop = _ensure_user_info_loop()
persona_text = PersonaState.format_persona(persona_dict)
future = asyncio.run_coroutine_threadsafe(
llm_analyzer.generate_user_info(persona_text),
loop,
)
return future.result()
def snapshot_history(history):
"""Return a shallow copy of the current chatbot history."""
return [dict(turn) for turn in history] if history else []
def response(
audio: tuple[int, np.ndarray] | None,
conversation_history,
persona_state_dict,
session_id: str | None,
user_info: dict | str | None = None,
mode: str | None = None,
):
"""
Handles user audio input, transcribes it, streams LLM text via backend.main,
and synthesizes chunks to audio while updating the conversation history.
The persona mode is forwarded to the backend for prompt conditioning.
"""
print(f"WebRTC input SR: {audio[0]}")
print(f"--- Latency Breakdown ---")
# # stage_asr = "normal" #["normal,"gemini"]
# print('-----------------------------')
# print(f"Initial conver:{conversation_history}")
# print('-----------------------------')
persona_dict = dict(persona_state_dict or initial_persona_state.get_persona())
selected_mode = mode if mode in PERSONA_MODES else DEFAULT_PERSONA_MODE
persona_metadata = dict(persona_dict)
persona_metadata["Mode"] = selected_mode
current_gender = persona_dict.get("Gender", "Female")
current_age = persona_dict.get("Age")
persona_string = PersonaState.format_persona(persona_dict)
conversation_history = conversation_history or []
start_time = time.time()
session_identifier = session_id or ""
if not session_identifier:
session_identifier = str(uuid.uuid4())
print(f"[WARN] Missing session_id; generated temporary session {session_identifier}")
if not audio or audio[1] is None or not np.any(audio[1]):
print("No audio input detected; skipping response generation.")
print(f"------------------------")
return
import soundfile as sf
sample_rate, audio_array = audio
try:
sr , processed_audio = preprocess_audio_simplified((sample_rate, audio_array), target_sr=16000)
print(sr, processed_audio.dtype, processed_audio.min(), processed_audio.max(), processed_audio.shape)
except Exception as audio_err:
print(f"Audio preprocessing failed: {audio_err}")
print(f"------------------------")
return
silence_duration_s = 0.2
# Calculate the number of samples corresponding to the silence duration
silence_samples = int(16000 * silence_duration_s)
# Create a silent audio segment (an array of zeros)
# Ensure the dtype matches your processed audio for compatibility
leading_silence = np.zeros(silence_samples, dtype=np.float32)
# Prepend the silence to the beginning of your processed audio
audio_with_padding = np.concatenate([leading_silence, processed_audio])
print(f"Added {silence_duration_s}s of silence. New shape: {audio_with_padding.shape}")
file_name = "temp.wav"
sf.write(file_name, audio_with_padding, sr)
t0 = time.time()
transcription = transcribe_typhoon(file_name)
# transcription = transcribe_audio( "debug_processed.wav")
t_asr = time.time() - t0
print(f"ASR: {t_asr:.4f}s")
if not transcription.strip():
print("No valid transcription; skipping response generation.")
print(f"------------------------")
return
user_turn = {"role": "user", "content": transcription}
print(f"User: {transcription}")
if is_valid_turn(user_turn):
conversation_history.append(user_turn)
yield AdditionalOutputs(snapshot_history(conversation_history))
# print("Conversation history:", conversation_history)
assistant_turn = {"role": "assistant", "content": ""}
conversation_history.append(assistant_turn)
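    # Buffer streamed LLM text until a '|' delimiter arrives, then synthesize each buffered
    # chunk to speech while the visible chat history is updated incrementally.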
text_buffer = ""
full_response = ""
delimiter_count = 0
n_threshold = 1
max_n_threshold = 1
lang = "th"
chunk_count = 0
first_chunk_sent = False
start_llm_stream = time.time()
try:
for chunk in stream_chat_response(
session_identifier,
transcription,
persona_string,
persona_metadata,
user_info,
mode
):
# print(f"LLM chunk: {text_chunk}")
if isinstance(chunk, str):
text_chunk = chunk
i = 0
while i < len(text_chunk):
char = text_chunk[i]
text_buffer += char
full_response += char
assistant_turn["content"] = full_response.strip()
is_delimiter = False
# if char in {' ', '\n'}:
if char == "|":
is_delimiter = True
delimiter_count += 1
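                        # Keep the Thai repetition mark 'ๆ' attached to the text it repeats
                        # rather than letting it start the next buffer.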
if i + 1 < len(text_chunk) and text_chunk[i + 1] == 'ๆ':
text_buffer += text_chunk[i + 1]
full_response += text_chunk[i + 1]
i += 1
send_now = False
if not first_chunk_sent:
if is_delimiter and text_buffer.strip():
send_now = True
else:
if delimiter_count >= n_threshold and text_buffer.strip():
send_now = True
if n_threshold < max_n_threshold:
n_threshold += 1
if send_now:
buffer_to_send = text_buffer.strip()
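                        # Trim a trailing 'วันที่' (or, after the first chunk, 'ค่ะ') before synthesis.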
try:
if buffer_to_send and buffer_to_send.endswith('วันที่'):
buffer_to_send = buffer_to_send[:-len('วันที่')]
if buffer_to_send and first_chunk_sent and buffer_to_send.endswith('ค่ะ'):
buffer_to_send = buffer_to_send[:-len('ค่ะ')]
except Exception:
buffer_to_send = buffer_to_send.replace('ค่ะ', '')
if buffer_to_send:
chunk_count += 1
if chunk_count == 1:
first_llm_chunk_time = time.time()
t_llm_first_token = first_llm_chunk_time - start_llm_stream
print(f"LLM TTFC: {t_llm_first_token:.4f}s (Time To First Chunk)")
yield from synthesize_text(buffer_to_send, lang=lang, gender=current_gender, age = current_age)
first_chunk_sent = True
text_buffer = ""
delimiter_count = 0
yield AdditionalOutputs(snapshot_history(conversation_history))
i += 1
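        # Synthesize any text still left in the buffer.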
if text_buffer.strip():
buffer_to_send = text_buffer.strip()
try:
if buffer_to_send and buffer_to_send.endswith('วันที่'):
buffer_to_send = buffer_to_send[:-len('วันที่')]
if buffer_to_send and first_chunk_sent and buffer_to_send.endswith('ค่ะ'):
buffer_to_send = buffer_to_send[:-len('ค่ะ')]
except Exception:
buffer_to_send = buffer_to_send.replace('ค่ะ', '')
if buffer_to_send:
chunk_count += 1
if chunk_count == 1:
first_llm_chunk_time = time.time()
t_llm_first_token = first_llm_chunk_time - start_llm_stream
print(f"LLM TTFC: {t_llm_first_token:.4f}s (Time To First Chunk)")
yield from synthesize_text(buffer_to_send, lang=lang, gender=current_gender, age = current_age)
first_chunk_sent = True
text_buffer = ""
delimiter_count = 0
yield AdditionalOutputs(snapshot_history(conversation_history))
except Exception as e:
print(f"An error occurred during response generation or synthesis: {e}")
error_message = "ขออภัยค่ะ เกิดข้อผิดพลาดบางอย่าง"
try:
yield from synthesize_text(error_message, lang=lang, gender=current_gender, age = current_age)
except Exception as synth_error:
print(f"Could not synthesize error message: {synth_error}")
assistant_turn["content"] = (assistant_turn.get("content", "") + f" [Error: {e}]").strip()
yield AdditionalOutputs(snapshot_history(conversation_history))
total_latency = time.time() - start_time
print(f"Total: {total_latency:.4f}s")
print(f"------------------------")
# Keep persona-derived mock financial info aligned with the current persona
def get_user_info(current_persona, mode: str = DEFAULT_PERSONA_MODE):
if not current_persona:
return "ยังไม่มีข้อมูลผู้ใช้"
    income = current_persona.get("เงินได้สุทธิ") or 0  # guard against a missing income field
    age = current_persona.get("Age")
    gender = current_persona.get("Gender")
    mode_value = mode if mode in PERSONA_MODES else DEFAULT_PERSONA_MODE
    # Tax-deductible amount: 15% of net income, capped at 300,000 THB.
    tax_deduct = min(0.15 * income, 300000)
info_lines = [
f"Mode: {mode_value}",
f"เงินได้สุทธิ : {income:,.0f} บาท",
f"ลดหย่อนภาษีได้ {tax_deduct:,.0f} บาท",
]
    if mode_value == "BT Unit linked":
        try:
            generated_info = _generate_unit_linked_user_info(current_persona)
            if generated_info:
                return generated_info
        except Exception as exc:
            # Fall back to the locally generated mock info below if the LLM call fails.
            print(f"Unit-linked user info generation failed: {exc}")
    has_policy = random.choice(["yes", "no"])
    if has_policy == "yes" and isinstance(age, int):
        # sum_assured = income * round(random.uniform(1.5, 5), 1)
        plan = random.choice(["จ่าย 8 ปี", "จ่ายถึง 60"])
        # Random annual premium between 20,000 THB and ~1.5x the deductible amount,
        # keeping the upper bound above the lower bound when the deductible is small.
        premium_upper_k = max(21, int(tax_deduct * 1.5 / 1000))
        final_premium = random.randrange(20, premium_upper_k) * 1000
# final_premium = calculate_be_together_premium(age, sum_assured, plan, gender)
premium_text = final_premium if isinstance(final_premium, str) else f"{final_premium:,.2f} บาท"
if age > 35:
purchase_year = random.randint(1,11)
elif age > 25:
purchase_year = random.randint(1,2)
else:
purchase_year = 1
info_lines.extend(
[
"ประวัติการถือประกัน :",
"BT smart retirement:",
f" - แผน: {plan}",
# f" - เงินเอาประกัน: {sum_assured:,.0f} บาท",
f" - ถือมาแล้ว: {purchase_year} ปี",
f" - เบี้ยประกันต่อปี: {premium_text}",
]
)
return "\n".join(info_lines)
info_lines.extend(
[
"ประวัติการถือประกัน:",
" - ไม่เคยถือ",
]
)
return "\n".join(info_lines)
def randomize_persona(selected_mode: str | None = None):
"""Generate a fresh persona for the selected mode and reset the visible chat history."""
mode = selected_mode if selected_mode in PERSONA_MODES else DEFAULT_PERSONA_MODE
new_persona_state = PersonaState()
persona_dict = new_persona_state.get_persona()
persona_string = PersonaState.format_persona(persona_dict)
user_info_text = get_user_info(persona_dict, mode)
new_session_id = str(uuid.uuid4())
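    # Order matches the Gradio outputs wired to the Random persona button:
    # persona state, persona textbox, cleared chat, session state, user-info state,
    # user-info textbox, session textbox.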
return (
persona_dict,
persona_string,
[],
new_session_id,
user_info_text,
user_info_text,
new_session_id,
)
def update_user_info_for_mode(selected_mode: str | None, persona_state_dict):
"""Update user info when switching modes without changing the persona."""
mode = selected_mode if selected_mode in PERSONA_MODES else DEFAULT_PERSONA_MODE
persona_dict = dict(persona_state_dict or initial_persona_state.get_persona())
user_info_text = get_user_info(persona_dict, mode)
return user_info_text, user_info_text
def initialize_session_id():
"""Create a new session identifier for syncing backend history."""
session_id = str(uuid.uuid4())
return session_id, session_id
async def get_credentials():
return await get_cloudflare_turn_credentials_async(hf_token=os.getenv('HF_TOKEN'))
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="orange", secondary_hue="orange")) as demo:
current_persona = gr.State(value=initial_persona_dict)
session_state = gr.State(value=None)
user_info = gr.State(value=get_user_info(initial_persona_dict, DEFAULT_PERSONA_MODE))
gr.HTML("""<h1 style='text-align: center'>AIA Voicebot Demo</h1>""")
with gr.Row():
with gr.Column(scale=1, elem_classes=["phone-column"]):
audio = WebRTC(
mode="send-receive",
modality="audio",
track_constraints={
"echoCancellation": True,
"noiseSuppression": {"exact": True},
"autoGainControl": {"exact": True}
},
rtc_configuration=get_credentials,
server_rtc_configuration=get_cloudflare_turn_credentials(ttl=360_000),
icon="https://i.pinimg.com/originals/0c/67/5a/0c675a8e1061478d2b7b21b330093444.gif",
icon_button_color="#17dbaa",
pulse_color="#b0f83b",
button_labels={"start": "Call", "stop": "Hang up", "waiting": "Connecting…"},
icon_radius=45,
height="650px",
width="100%",
container=False,
elem_id="phone-call-webrtc"
)
with gr.Column(scale=1):
persona_display = gr.Textbox(
label="User persona",
lines=10,
interactive=False,
value=initial_persona_string,
elem_classes=["persona-box"]
)
mode_selector = gr.Dropdown(
label="Persona mode",
choices=PERSONA_MODES,
value=DEFAULT_PERSONA_MODE,
interactive=True,
)
random_persona_btn = gr.Button(
"Random persona",
variant="secondary",
)
user_info_display = gr.Textbox(
label="User INFO",
lines=10,
interactive=False,
value=user_info.value,
elem_classes=["persona-box"]
)
session_display = gr.Textbox(
label="Session ID",
interactive=False,
value="Initializing…",
elem_classes=["persona-box"]
)
with gr.Column(scale=2):
conversation_history = gr.Chatbot(
label="Conversation History",
type="messages",
value=[],
height="675px",
resizable=True,
avatar_images=(None, "https://i.pinimg.com/originals/0c/67/5a/0c675a8e1061478d2b7b21b330093444.gif"),
)
random_persona_btn.click(
fn=randomize_persona,
inputs=[mode_selector],
outputs=[current_persona, persona_display, conversation_history, session_state, user_info, user_info_display, session_display],
show_progress="hidden",
queue=False,
)
mode_selector.change(
fn=update_user_info_for_mode,
inputs=[mode_selector, current_persona],
outputs=[user_info, user_info_display],
show_progress="hidden",
queue=False,
)
demo.load(
fn=initialize_session_id,
inputs=None,
outputs=[session_state, session_display],
queue=False,
)
gr.DeepLinkButton()
audio.stream(
fn=ReplyOnPause(
response,
algo_options=AlgoOptions(
audio_chunk_duration=1.35,
started_talking_threshold=0.35,
speech_threshold=0.2
),
model_options=SileroVadOptions(
threshold=0.65,
min_speech_duration_ms=200,
max_speech_duration_s=float("inf"),
min_silence_duration_ms=1200,
speech_pad_ms=300
),
can_interrupt=False,
startup_fn=startup,
),
inputs=[audio, conversation_history, current_persona, session_state, user_info, mode_selector],
outputs=[audio],
concurrency_limit=1000,
time_limit=8192
)
audio.on_additional_outputs(
lambda history: history,
outputs=[conversation_history],
queue=True,
show_progress="hidden"
)
demo.queue(default_concurrency_limit=1000)
demo.launch(debug=True, show_error=True, share=True)