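"""Gradio chat demo for meta-llama/Llama-2-7b-chat-hf with streamed responses.

The conversation history is kept in memory and written to chat_history.json
when the process exits.
"""
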
import atexit
import json
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import requests
import spaces
import torch
from huggingface_hub import AsyncInferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from system_prompt_config import construct_input_prompt

system_message = "\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."

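# In-memory log of every exchange handled by generate().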
global_chat_history = []

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

DESCRIPTION = """\
# Llama-2 7B Chat

This is your personal space to chat.
You can ask anything, from strategic questions about the game to casual conversation.
"""

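# The license notice and 13B model setup from the original demo are kept below,
# but disabled.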
'''LICENSE = """
<p/>

---
As a derivative work of [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/USE_POLICY.md).
"""

if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-13b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
'''

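# Load the 7B chat model in float16; `model` and `tokenizer` are only defined
# when a CUDA GPU is available.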
if torch.cuda.is_available():
    model_id = "meta-llama/Llama-2-7b-chat-hf"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False


def save_chat_history():
    """Save the chat history to a JSON file."""
    with open("chat_history.json", "w") as json_file:
        json.dump(global_chat_history, json_file)


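# On ZeroGPU Spaces, the @spaces.GPU decorator allocates a GPU for the duration
# of each call to the decorated function.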
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    global global_chat_history

    # Build the full prompt (system prompt, previous turns, new user message)
    # with the project's prompt builder, then tokenize it.
    input_prompt = construct_input_prompt(chat_history, message)
    input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

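    # The streamer yields decoded text as soon as new tokens are generated.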
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )

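    # Run generation in a background thread so the partial output can be
    # streamed back to the UI while the model is still generating.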
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

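    # Record the completed exchange; save_chat_history() writes this list to disk.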
    global_chat_history.append({
        "message": message,
        "chat_history": chat_history,
        "system_prompt": system_message,
        "output": "".join(outputs),
    })


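# Chat UI with preset example questions about the game.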
chat_interface = gr.ChatInterface(
    fn=generate,
    theme="soft",
    retry_btn=None,
    clear_btn=None,
    undo_btn=None,
    chatbot=gr.Chatbot(avatar_images=("user.png", "bot.png"), bubble_full_width=False),
    examples=[
        ["How much should I invest in order to win?"],
        ["What happened in the last round?"],
        ["What is my probability to win if I do not invest anything?"],
        ["What is my probability to win if I do not share anything?"],
        ["Can you explain the rules very briefly again?"],
    ],
)

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()


if __name__ == "__main__":
    # Register the save hook before launching so the chat history is written
    # even if the app is interrupted.
    atexit.register(save_chat_history)
    demo.queue(max_size=20)
    demo.launch(share=True, debug=True)
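
# Note: construct_input_prompt(chat_history, message) comes from the local
# system_prompt_config module and is expected to return a single prompt string.
# A minimal sketch, assuming the standard Llama-2 chat format and the
# module-level system_message (the real implementation may differ):
#
#     def construct_input_prompt(chat_history, message):
#         prompt = f"<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n"
#         for user_msg, assistant_msg in chat_history:
#             prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
#         return prompt + f"{message} [/INST]"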