# llama-321b / app.py
# (Hugging Face Space file header: uploaded by ewssbd, commit 8fd3cf3, verified)
import asyncio

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import pipeline, set_seed
# ---------------------------
# Step 1: Model Setup
# ---------------------------
MODEL_NAME = "amusktweewt/tiny-model-500M-chat-v2"

# NOTE: the original log strings were mojibake ("πŸ”„", "βœ…" — UTF-8 emoji
# mis-decoded as Latin-1); restored to the intended characters.
print("🔄 Downloading and loading model...")

# device=0 selects the first CUDA GPU; -1 falls back to CPU.
chatbot = pipeline(
    "text-generation",
    model=MODEL_NAME,
    device=0 if torch.cuda.is_available() else -1,
)

# Fixed seed so sampled generations are reproducible across restarts.
set_seed(42)
print("✅ Model loaded and ready!")
# ---------------------------
# System instruction: persona/scope text sent as the "system" role on every chat
# ---------------------------
SYSTEM_INSTRUCTION = (
"You are DryfishBD's chat assistant. "
"Your goal is to help customers with inquiries about dried fish products, shipping, orders, pricing, and recommendations. "
"Always reply in a friendly, professional, and concise manner."
)
# ---------------------------
# Step 2: FastAPI app setup
# ---------------------------
# Single application instance; routes below attach to it via decorators.
app = FastAPI(title="DryfishBD Chatbot API", version="1.0")
# ---------------------------
# Step 3: Predict endpoint
# ---------------------------
@app.post("/predict")
async def predict(request: Request):
    """Generate a chat reply for a JSON payload of the form {"message": "<text>"}.

    Returns:
        JSONResponse {"reply": "<assistant text>"} on success,
        {"error": ...} with status 400 on bad input, 500 on generation failure.
    """
    # Malformed JSON (or a non-object body) is a client error, not a 500.
    try:
        data = await request.json()
    except Exception:
        return JSONResponse({"error": "Request body must be valid JSON."}, status_code=400)

    raw = data.get("message") if isinstance(data, dict) else None
    user_input = raw.strip() if isinstance(raw, str) else ""
    if not user_input:
        return JSONResponse({"error": "Missing 'message' in request."}, status_code=400)

    try:
        # Trailing empty assistant turn cues the template to open a reply slot.
        messages = [
            {"role": "system", "content": SYSTEM_INSTRUCTION},
            {"role": "user", "content": user_input},
            {"role": "assistant", "content": ""},
        ]
        prompt = chatbot.tokenizer.apply_chat_template(messages, tokenize=False)

        # The pipeline call is CPU/GPU-blocking; run it in a worker thread so
        # the event loop can keep serving other requests during generation.
        response = await asyncio.to_thread(
            chatbot,
            prompt,
            do_sample=True,
            max_new_tokens=256,
            top_k=50,
            temperature=0.2,
            num_return_sequences=1,
            repetition_penalty=1.1,
            pad_token_id=chatbot.tokenizer.eos_token_id,
            min_new_tokens=0,
        )

        # generated_text echoes the prompt; keep only the new continuation.
        full_text = response[0]["generated_text"]
        bot_response = full_text[len(prompt):].strip()
        return JSONResponse({"reply": bot_response})
    except Exception as e:
        # Boundary handler: report unexpected failures as a 500 payload
        # instead of letting the server return an unhandled-error page.
        return JSONResponse({"error": str(e)}, status_code=500)
# ---------------------------
# Step 4: Root route
# ---------------------------
@app.get("/")
def home():
    """Health-check endpoint: confirms the API process is up."""
    status = {"message": "DryfishBD Chatbot API is running!"}
    return status
# ---------------------------
# Step 5: Run locally (ignored on Hugging Face)
# ---------------------------
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)