import os

import gradio as gr
import spaces
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Authenticate so the gated Llama 2 weights can be downloaded
login(token=os.environ["HUGGINGFACE_TOKEN"])
# Model and tokenizer (the "-hf" repo holds the transformers-format weights)
model_name = "meta-llama/Llama-2-7b-hf"  # swap in the 13B or 70B checkpoint if needed
tokenizer = AutoTokenizer.from_pretrained(model_name, token=True)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # 4-bit weights; use load_in_8bit=True instead for 8-bit
    bnb_4bit_compute_dtype=torch.float16,  # run the matmuls in fp16
    bnb_4bit_use_double_quant=True         # also quantize the quantization constants to save memory
)
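
# A possible variant (an assumption, not part of the original app): NF4 often
# preserves quality better than the default fp4 data type at the same memory cost.
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
#     bnb_4bit_use_double_quant=True,
# )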
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
device_map="auto" # Automatically place model on GPU
)
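
# Optional sanity check (a sketch, not in the original app): report how much
# memory the quantized model actually occupies.
# print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")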
# Inference function
@spaces.GPU
def generate_text(prompt):
    # Move the inputs to wherever device_map="auto" placed the model
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=100)
    return tokenizer.decode(output[0], skip_special_tokens=True)
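
# Sketch (an assumed extension, not in the original app): generation above is
# greedy by default; sampling parameters like these give more varied replies.
# output = model.generate(**inputs, max_new_tokens=100,
#                         do_sample=True, temperature=0.7, top_p=0.9)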
# Gradio handler; generate_text already requests the GPU, so no second decorator is needed
def chat_with_llama(prompt):
    return generate_text(prompt)

gr.Interface(fn=chat_with_llama, inputs="text", outputs="text", title="LLaMA 2 Chatbot").launch()