import logging
import subprocess
from threading import Thread

import modal
import openai_harmony as oh
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from unpredictable_lord.tokenstreamer import TokenStreamer

logger = logging.getLogger(__name__)
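
# Persistent Modal volume, mounted at MOUNT_DIR inside the GPU container and used as the
# Hugging Face cache (see HF_HOME in the image definition below) so downloaded model
# weights survive container restarts.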

APP_NAME = "unpredictable-lord"
VOLUME_NAME = APP_NAME + "-volume"
MOUNT_VOLUME = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
MOUNT_DIR = "/data"

MODEL_IDENTIFIER = "openai/gpt-oss-20b"
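
# Token budget enforced in generate_stream(): the prompt plus MAX_OUTPUT_TOKENS
# completion tokens must fit within MAX_MODEL_TOKENS.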
MAX_MODEL_TOKENS = 64 * 1024
MAX_OUTPUT_TOKENS = 512
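
# Modal GPU spec in "<type>:<count>" form, e.g. "L4:1" requests a single NVIDIA L4.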
GPU_NAME = "L4"
GPU_NUM = 1
GPU = f"{GPU_NAME}:{GPU_NUM}"
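
# Container image: CUDA 12.8 devel base with Python 3.12, pinned inference dependencies,
# the Hugging Face cache redirected onto the mounted volume via HF_HOME, and the local
# `unpredictable_lord` package added so TokenStreamer is importable remotely.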
image = (
    modal.Image.from_registry("nvidia/cuda:12.8.1-devel-ubuntu24.04", add_python="3.12")
    .pip_install(
        [
            "accelerate>=1.12.0",
            "kernels>=0.11.1",
            "openai-harmony>=0.0.8",
            "torch>=2.9.0",
            "transformers>=4.57.1",
        ]
    )
    .env(
        {
            "HF_HOME": MOUNT_DIR + "/huggingface",
        }
    )
    .add_local_python_source("unpredictable_lord")
)

app = modal.App(APP_NAME, image=image)
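
# Module-level cache: each Modal container loads the model once and reuses it across calls.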
model = None
tokenizer = None
stop_token_ids = None
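

# Volume semantics: reload() makes files committed by other containers visible before the
# Hugging Face download starts, and commit() persists any newly downloaded weights so
# later cold starts can skip the download.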
def load_model():
    """Load the model, tokenizer, and Harmony stop token IDs into module globals."""
    global model, tokenizer, stop_token_ids

    if model is not None:
        return

    MOUNT_VOLUME.reload()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_IDENTIFIER)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_IDENTIFIER,
        torch_dtype="auto",
        device_map="auto",
    )

    MOUNT_VOLUME.commit()

    _encoding = oh.load_harmony_encoding(oh.HarmonyEncodingName.HARMONY_GPT_OSS)
    stop_token_ids = _encoding.stop_tokens_for_assistant_actions()

    # Log GPU status (visible in the Modal container logs).
    subprocess.run(["nvidia-smi"])
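

# The TokenStreamer imported above is assumed to follow transformers' BaseStreamer
# protocol and to expose a `token_queue` plus a `stop_signal` sentinel, since that is
# how generate_stream() drains it below. A minimal sketch of such a class (hypothetical,
# for illustration only; the real implementation lives in unpredictable_lord/tokenstreamer.py):
#
#     from queue import Queue
#     from transformers.generation.streamers import BaseStreamer
#
#     class TokenStreamer(BaseStreamer):
#         stop_signal = None
#
#         def __init__(self):
#             self.token_queue = Queue()
#             self._prompt_skipped = False
#
#         def put(self, value):
#             # generate() first pushes the prompt tokens; skip them so only newly
#             # generated token IDs reach the queue.
#             if not self._prompt_skipped:
#                 self._prompt_skipped = True
#                 return
#             for token_id in value.tolist():
#                 self.token_queue.put(token_id)
#
#         def end(self):
#             self.token_queue.put(self.stop_signal)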


@app.function(
    gpu=GPU,
    volumes={MOUNT_DIR: MOUNT_VOLUME},
)
def generate_stream(input_tokens):
    """
    Generate a streaming response

    Args:
        input_tokens (list[int]): Input token IDs

    Yields:
        int: Generated token IDs
    """
    load_model()

    if len(input_tokens) + MAX_OUTPUT_TOKENS > MAX_MODEL_TOKENS:
        raise ValueError(
            f"Prompt of {len(input_tokens)} tokens plus up to {MAX_OUTPUT_TOKENS} output "
            f"tokens would exceed the {MAX_MODEL_TOKENS}-token limit."
        )

    input_ids = torch.tensor([input_tokens], dtype=torch.long).to(model.device)

    streamer = TokenStreamer()
    generation_kwargs = {
        "input_ids": input_ids,
        "max_new_tokens": MAX_OUTPUT_TOKENS,
        "eos_token_id": stop_token_ids,
        "streamer": streamer,
    }
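
    # model.generate() blocks until generation finishes, so it runs in a background
    # thread while this generator drains token IDs from the streamer's queue; the
    # streamer signals completion by enqueueing its stop_signal from end().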
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    while True:
        token_id = streamer.token_queue.get()
        if token_id == streamer.stop_signal:
            break
        yield token_id

    thread.join()
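

# Local entrypoint: runs on the developer's machine via `modal run`, while the decorated
# generate_stream() above executes remotely on the GPU container and streams token IDs back.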
@app.local_entrypoint()
def main():
    """Render a Harmony conversation, stream the completion, and print it."""
    convo = oh.Conversation.from_messages(
        [
            oh.Message.from_role_and_content(oh.Role.SYSTEM, oh.SystemContent.new()),
            oh.Message.from_role_and_content(
                oh.Role.DEVELOPER,
                oh.DeveloperContent.new().with_instructions(
                    "Always respond in the same language as the user."
                ),
            ),
            oh.Message.from_role_and_content(
                oh.Role.USER, "Hi. How is the weather today?"
            ),
        ]
    )

    encoding = oh.load_harmony_encoding(oh.HarmonyEncodingName.HARMONY_GPT_OSS)
    input_tokens = encoding.render_conversation_for_completion(convo, oh.Role.ASSISTANT)

    print("AI: ", end="", flush=True)

    parser = oh.StreamableParser(encoding, role=oh.Role.ASSISTANT)

    for token in generate_stream.remote_gen(input_tokens):
        parser.process(token)
        delta = parser.last_content_delta
        if delta:
            print(delta, end="", flush=True)
    print()