ehristoforu committed
Commit 1199506 · verified · 1 Parent(s): 9451baa

Create app.py

Files changed (1)
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
+ #!/usr/bin/env python
+
+ import os
+ from collections.abc import Iterator
+ from threading import Thread
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+ DESCRIPTION = "# FluentlyLM Prinum"
+
+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+ model_id = "fluently-lm/FluentlyLM-Prinum"
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+ )
+ model.eval()
+
+
+ @spaces.GPU(duration=120)
+ def generate(
+     message: str,
+     chat_history: list[dict],
+     system_prompt: str = "",
+     max_new_tokens: int = 1024,
+     temperature: float = 0.7,
+     top_p: float = 0.8,
+     top_k: int = 20,
+     repetition_penalty: float = 1.05,
+ ) -> Iterator[str]:
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     messages.extend(chat_history.copy())
+     messages.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
+
+
+ demo = gr.ChatInterface(
+     fn=generate,
+     additional_inputs=[
+         gr.Textbox(label="System Prompt", value="You are FluentlyLM, created by Project Fluently. You are a helpful assistant."),
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         gr.Slider(
+             label="Temperature",
+             minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.65,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
+             maximum=1.0,
+             step=0.05,
+             value=0.8,
+         ),
+         gr.Slider(
+             label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=20,
+         ),
+         gr.Slider(
+             label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.05,
+         ),
+     ],
+     stop_btn=None,
+     examples=[
+         ["Hi! How are you?"],
+     ],
+     cache_examples=False,
+     type="messages",
+     description=DESCRIPTION,
+     css_paths="style.css",
+     fill_height=True,
+ )
+
+ if __name__ == "__main__":
+     demo.launch()