File size: 3,496 Bytes
c60a72e
28c89fd
 
 
 
 
 
 
 
 
 
 
f0b1102
 
 
 
28c89fd
f0b1102
 
28c89fd
 
 
 
 
 
e0a5676
 
 
 
 
 
b4ce389
 
 
e0a5676
c60a72e
 
b4ce389
e0a5676
 
 
b4ce389
e0a5676
 
 
 
 
 
b4ce389
e0a5676
 
 
 
 
 
 
 
 
 
 
b4ce389
e0a5676
 
 
b4ce389
e0a5676
b4ce389
 
e0a5676
 
b4ce389
e0a5676
 
 
 
 
 
 
b4ce389
e0a5676
b4ce389
e0a5676
 
 
 
 
b4ce389
e0a5676
b4ce389
e0a5676
 
 
 
b4ce389
e0a5676
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import traceback

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model repo to serve; override with the HF_MODEL_ID environment variable.
MODEL_ID = os.environ.get("HF_MODEL_ID", "teamaMohamed115/smollm-360m-code-lora")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and model are loaded once at import time and shared by every request.
print(f"Loading tokenizer and model from {MODEL_ID} on {DEVICE}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# First try with custom model code enabled (for repos that ship their own
# modelling files), then fall back to the stock transformers classes.
# NOTE(security): trust_remote_code=True executes Python code downloaded from
# the model repo, and MODEL_ID is overridable via HF_MODEL_ID - only point it
# at repositories you trust.
try:
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
except Exception as err:
    # Report why the first attempt failed instead of silently discarding it.
    print(f"Loading with trust_remote_code=True failed ({err!r}); retrying without it")
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)


model.to(DEVICE)
# Inference only: switch off training-mode behaviour such as dropout.
model.eval()

# Generation helper
GEN_KWARGS = dict(
    max_new_tokens=256,
    do_sample=True,
    temperature=0.2,
    top_p=0.95,
    top_k=50,
    num_return_sequences=1,
)

PROMPT_TEMPLATE = (
    "# Instruction:\n{instruction}\n\n# Response (provide a Python module with multiple functions):\n"
)


def generate_code(instruction: str, max_tokens: int = 256, temperature: float = 0.2, top_p: float = 0.95):
    """Generate a Python module for *instruction* using the loaded causal LM.

    Args:
        instruction: Natural-language request; surrounding whitespace is
            stripped. ``None`` or blank input returns a hint message without
            calling the model.
        max_tokens: Maximum number of newly generated tokens.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding, because sampling requires a strictly positive temperature.
        top_p: Nucleus-sampling cutoff; unused under greedy decoding.

    Returns:
        Only the model's continuation (prompt removed), whitespace-stripped.
    """
    if instruction is None or not instruction.strip():
        return "Please provide an instruction or problem statement."

    prompt = PROMPT_TEMPLATE.format(instruction=instruction.strip())
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(DEVICE)
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(DEVICE)

    gen_kwargs = GEN_KWARGS.copy()
    gen_kwargs["max_new_tokens"] = int(max_tokens)
    temperature = float(temperature)
    if temperature > 0:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = float(top_p)
    else:
        # The UI slider allows 0.0, but generate() rejects a non-positive
        # temperature when do_sample=True. Treat it as a request for greedy
        # decoding and drop the sampling-only settings.
        gen_kwargs["do_sample"] = False
        for key in ("temperature", "top_p", "top_k"):
            gen_kwargs.pop(key, None)

    with torch.no_grad():
        outputs = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)

    # For a causal LM the output sequence starts with the prompt tokens. Slice
    # them off by position instead of comparing decoded strings, because
    # decoding does not always reproduce `prompt` character-for-character.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)

    return decoded.strip()


with gr.Blocks(title="SmolLM Python Code Assistant") as demo:
    gr.Markdown("# SmolLM — Python Code Generation\nEnter an instruction and get a multi-function Python module.")

    # Instruction box on the left; generation controls and the trigger button
    # stacked in a column beside it.
    with gr.Row():
        instr = gr.Textbox(lines=6, placeholder="Describe the Python module you want...", label="Instruction")
        with gr.Column(scale=1):
            max_t = gr.Slider(minimum=32, maximum=1024, value=256, step=32, label="Max new tokens")
            temp = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.05, label="Temperature")
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.01, label="Top-p")
            run_btn = gr.Button("Generate")

    output = gr.Code(label="Generated Python module", language="python")

    def run(instruction, max_tokens, temperature, top_p):
        """Button handler: forward the UI values to generate_code.

        Any exception is returned as text for the output pane instead of
        propagating to Gradio. The full traceback is also printed to stderr
        so the cause can still be diagnosed from the server log.
        """
        try:
            return generate_code(instruction, max_tokens, temperature, top_p)
        except Exception as e:
            traceback.print_exc()
            return f"Error during generation: {e}"

    run_btn.click(run, inputs=[instr, max_t, temp, top_p], outputs=[output])

    # Sample instructions that populate the instruction textbox when clicked.
    gr.Examples(examples=[
        "Implement a Python module that includes: a function to compute Fibonacci sequence, a function to check primality, and a function to compute factorial, all with type hints and docstrings.",
        "Create a Python module for basic matrix operations (add, multiply, transpose) with appropriate error handling and tests.",
    ], inputs=instr)

if __name__ == "__main__":
    demo.launch()