ryomo committed
Commit 0ddba75 · Parent: f0226f8

refactor: switch llm_modal generate_stream implementation from class-based to function-based to align with llm_zerogpu.py

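In short, the caller side moves from looking up a deployed Modal class to looking up a deployed Modal function. A minimal sketch of the two remote-call patterns, using only names that appear in the diffs below (it is not the project's full code):

# Before: resolve the deployed class, instantiate it, call its @modal.method generator
LLMModel = modal.Cls.from_name(APP_NAME, "LLMModel")
model = LLMModel()
stream = model.generate_stream.remote_gen(input_tokens)

# After: resolve the deployed @app.function generator and call it directly
generate_stream = modal.Function.from_name(APP_NAME, "generate_stream")
stream = generate_stream.remote_gen(input_tokens)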
pyproject.toml CHANGED
@@ -31,3 +31,5 @@ build-backend = "hatchling.build"
 
 [tool.poe.tasks]
 gradio = "gradio app.py"
+modal-deploy = "uv run modal deploy src/unpredictable_lord/llm_modal.py"
+modal-run = "uv run modal run src/unpredictable_lord/llm_modal.py"
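These entries extend the existing poethepoet task table ([tool.poe.tasks]); assuming poe is available in the project environment (as the existing gradio task implies), the new commands would presumably be invoked as poe modal-deploy and poe modal-run.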
src/unpredictable_lord/chat.py CHANGED
@@ -21,18 +21,15 @@ if USE_MODAL:
     import modal
 
     APP_NAME = "unpredictable-lord"
-    LLMModel = modal.Cls.from_name(APP_NAME, "LLMModel")
-    model = LLMModel()
+    _generate_stream = modal.Function.from_name(APP_NAME, "generate_stream")
 
-    def _generate_stream(input_tokens):
-        return model.generate_stream.remote_gen(input_tokens)
+    def generate_stream(input_tokens):
+        return _generate_stream.remote_gen(input_tokens)
 else:
-    from unpredictable_lord.llm_zerogpu import (
-        generate_stream as generate_stream_zerogpu,
-    )
+    from unpredictable_lord.llm_zerogpu import generate_stream as _generate_stream
 
-    def _generate_stream(input_tokens):
-        return generate_stream_zerogpu(input_tokens)
+    def generate_stream(input_tokens):
+        return _generate_stream(input_tokens)
 
 
 def chat_with_llm_stream(
@@ -100,7 +97,7 @@ def chat_with_llm_stream(
     ]
 
     # Streaming generation
-    generater = _generate_stream(input_tokens)
+    generater = generate_stream(input_tokens)
 
     response_text = ""
     for token in generater:
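After this change, both backends are exposed to chat_with_llm_stream through the same generate_stream(input_tokens) wrapper, mirroring the function-based interface of llm_zerogpu.py. A minimal consumption sketch, modeled on the harmony-parser loop in llm_modal.py's local entrypoint (the variable names are illustrative, and encoding/input_tokens are assumed to be prepared as in that entrypoint):

parser = oh.StreamableParser(encoding, role=oh.Role.ASSISTANT)

response_text = ""
for token in generate_stream(input_tokens):
    parser.process(token)
    delta = parser.last_content_delta
    if delta:
        response_text += delta  # accumulate streamed text as it arrives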
src/unpredictable_lord/llm_modal.py CHANGED
@@ -61,81 +61,92 @@ image = (
             "HF_HOME": MOUNT_DIR + "/huggingface",
         }
     )
+    .add_local_python_source("unpredictable_lord")  # Include local package
 )
 
 app = modal.App(APP_NAME, image=image)
 
-# NOTE: `@app.cls`, `@modal.enter()`, and `@modal.method()` are used like `@app.function()`
-# https://modal.com/docs/guide/lifecycle-functions
 
+# Global model and tokenizer (loaded once per container)
+model = None
+tokenizer = None
+stop_token_ids = None
 
-@app.cls(
+
+def load_model():
+    """Load model and tokenizer into global variables."""
+    global model, tokenizer, stop_token_ids
+
+    if model is not None:
+        return
+
+    # Ensure the cache volume is the latest
+    MOUNT_VOLUME.reload()
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_IDENTIFIER)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_IDENTIFIER,
+        torch_dtype="auto",
+        device_map="auto",
+    )
+
+    # Commit the volume to ensure the model is saved
+    MOUNT_VOLUME.commit()
+
+    # Load stop token IDs
+    _encoding = oh.load_harmony_encoding(oh.HarmonyEncodingName.HARMONY_GPT_OSS)
+    stop_token_ids = _encoding.stop_tokens_for_assistant_actions()
+
+    # Show GPU information
+    subprocess.run(["nvidia-smi"])
+
+
+@app.function(
     gpu=GPU,
-    image=image,
     volumes={MOUNT_DIR: MOUNT_VOLUME},
     # secrets=[modal.Secret.from_name("huggingface-secret")],
    # scaledown_window=15 * 60,
     # timeout=30 * 60,
 )
-class LLMModel:
-    @modal.enter()
-    def setup(self):
-        # Ensure the cache volume is the latest
-        MOUNT_VOLUME.reload()
-
-        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_IDENTIFIER)
-        self.model = AutoModelForCausalLM.from_pretrained(
-            MODEL_IDENTIFIER,
-            dtype="auto",
-            device_map="auto",
-        )
-
-        # Commit the volume to ensure the model is saved
-        MOUNT_VOLUME.commit()
-
-        self.encoding = oh.load_harmony_encoding(oh.HarmonyEncodingName.HARMONY_GPT_OSS)
-        self.stop_token_ids = self.encoding.stop_tokens_for_assistant_actions()
-
-        # Show GPU information
-        subprocess.run(["nvidia-smi"])
-
-    @modal.method()
-    def generate_stream(self, input_tokens, _=None):
-        """
-        Generate a streaming response
-
-        Args:
-            input_tokens (list[int]): Input token IDs
-            _ : Dummy parameter for compatibility
-        """
-
-        if len(input_tokens) + MAX_OUTPUT_TOKENS > MAX_MODEL_TOKENS:
-            raise ValueError(
-                f"Input length exceeds the maximum allowed tokens: {MAX_MODEL_TOKENS}. "
-                f"Current input length: {len(input_tokens)} tokens."
-            )
+def generate_stream(input_tokens):
+    """
+    Generate a streaming response
+
+    Args:
+        input_tokens (list[int]): Input token IDs
+
+    Yields:
+        int: Generated token IDs
+    """
+    load_model()
+
+    if len(input_tokens) + MAX_OUTPUT_TOKENS > MAX_MODEL_TOKENS:
+        raise ValueError(
+            f"Input length exceeds the maximum allowed tokens: {MAX_MODEL_TOKENS}. "
+            f"Current input length: {len(input_tokens)} tokens."
+        )
 
-        input_ids = torch.tensor([input_tokens], dtype=torch.long).to(self.model.device)
-
-        streamer = TokenStreamer()
-        generation_kwargs = {
-            "input_ids": input_ids,
-            "max_new_tokens": MAX_OUTPUT_TOKENS,
-            "eos_token_id": self.stop_token_ids,
-            "streamer": streamer,
-        }
-
-        # Start generation in a separate thread
-        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
-        thread.start()
-
-        while True:
-            token_id = streamer.token_queue.get()
-            if token_id == streamer.stop_signal:
-                break
-            yield token_id
-
-        thread.join()
+    input_ids = torch.tensor([input_tokens], dtype=torch.long).to(model.device)
+
+    streamer = TokenStreamer()
+    generation_kwargs = {
+        "input_ids": input_ids,
+        "max_new_tokens": MAX_OUTPUT_TOKENS,
+        "eos_token_id": stop_token_ids,
+        "streamer": streamer,
+    }
+
+    # Start generation in a separate thread
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    while True:
+        token_id = streamer.token_queue.get()
+        if token_id == streamer.stop_signal:
+            break
+        yield token_id
+
+    thread.join()
 
 
 @app.local_entrypoint()
@@ -156,8 +167,6 @@ def main():
         ]
     )
 
-    model = LLMModel()
-
     encoding = oh.load_harmony_encoding(oh.HarmonyEncodingName.HARMONY_GPT_OSS)
     input_tokens = encoding.render_conversation_for_completion(convo, oh.Role.ASSISTANT)
 
@@ -165,7 +174,7 @@ def main():
 
     parser = oh.StreamableParser(encoding, role=oh.Role.ASSISTANT)
 
-    for token in model.generate_stream.remote_gen(input_tokens):
+    for token in generate_stream.remote_gen(input_tokens):
         parser.process(token)
         delta = parser.last_content_delta
         if delta:
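generate_stream depends on a TokenStreamer helper that is defined elsewhere in the repo and is not part of this diff. Based only on how it is used here (token_queue.get(), a stop_signal sentinel, and being passed as the streamer to model.generate), a hypothetical sketch of the assumed contract might look like the following, built on transformers' BaseStreamer interface; the real class may differ:

from queue import Queue

from transformers.generation.streamers import BaseStreamer


class TokenStreamer(BaseStreamer):
    """Queue-backed streamer: generate() pushes token ids, the caller drains them."""

    def __init__(self):
        self.token_queue = Queue()
        self.stop_signal = None   # sentinel enqueued when generation finishes
        self._skip_prompt = True  # generate() first calls put() with the prompt ids

    def put(self, value):
        # Called by model.generate() with a tensor of token ids.
        if self._skip_prompt:
            self._skip_prompt = False
            return
        for token_id in value.reshape(-1).tolist():
            self.token_queue.put(token_id)

    def end(self):
        # Called once generation is done; unblocks the consumer loop.
        self.token_queue.put(self.stop_signal)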