# promptlab / app.py
# Commit 00c908c (verified): "Update app.py" — author: bditto — 2.48 kB
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
import random
# Configuration 🛠️
model_name = "microsoft/phi-3-mini-4k-instruct"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the model with memory optimizations: half precision on GPU, full
# precision on CPU. Placement itself is delegated to device_map="auto";
# within this block `device` only selects the dtype.
_model_kwargs = {
    "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
    "device_map": "auto",
    "low_cpu_mem_usage": True,
}
model = AutoModelForCausalLM.from_pretrained(model_name, **_model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Safety tools 🛡️ (simplified)
# Terms that cause a user message to be refused. is_safe checks them as
# lowercase substrings, so "gun" also matches "guns".
BLOCKED_WORDS = [
    "violence",
    "hate",
    "gun",
    "personal",
]
# Alternative project suggestions offered when a message is refused.
SAFE_IDEAS = [
    "Design a robot to clean parks 🌳",
    "Code a recycling game ♻️",
]
def is_safe(text, blocked_words=None):
    """Return True if *text* contains none of the blocked words.

    Matching is a case-insensitive substring test. It therefore also catches
    inflections such as "guns" and "hateful". It can also produce false
    positives on words that merely contain a blocked term; for example,
    "whatever" contains "hate".

    Args:
        text: The user message to check.
        blocked_words: Terms to reject. Defaults to the module-level
            ``BLOCKED_WORDS`` list, looked up at call time.

    Returns:
        False if any blocked word occurs in *text*, otherwise True.
    """
    if blocked_words is None:
        blocked_words = BLOCKED_WORDS
    lowered = text.lower()
    # Lower-case the terms too, so a caller-supplied list with capitals still matches.
    return not any(word.lower() in lowered for word in blocked_words)
def _history_pairs(history):
    """Convert Gradio chat history into a list of ``(user, assistant)`` pairs.

    Accepts either history format ``gr.ChatInterface`` can pass:
    - ``[user, assistant]`` pairs ("tuples");
    - a flat list of ``{"role": ..., "content": ...}`` dicts ("messages").

    A missing reply (``None``) becomes an empty string, so it is not
    rendered as the text "None" in the prompt. Messages with other roles
    are ignored.
    """
    pairs = []
    for turn in history or ():
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content") or ""
            if role == "user":
                pairs.append((content, ""))
            elif role == "assistant" and pairs:
                user_text, bot_text = pairs[-1]
                # Treat consecutive assistant messages as one reply.
                joined = f"{bot_text}\n{content}" if bot_text else content
                pairs[-1] = (user_text, joined)
        else:
            user_text, bot_text = turn
            pairs.append((user_text or "", bot_text or ""))
    return pairs


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate the assistant's reply for one chat turn.

    Used as the ``fn`` of ``gr.ChatInterface``. Gradio passes the new
    ``message`` and the prior ``history``, followed by the values of the
    interface's ``additional_inputs``.

    Args:
        message: The user's latest message.
        history: Previous turns, in Gradio "tuples" or "messages" format.
        system_message: Guidelines text placed at the top of the prompt.
        max_tokens: Requested generation length; capped at 256 new tokens.
        temperature: Sampling temperature; capped at 0.7.
        top_p: Nucleus-sampling probability mass.

    Yields:
        A single string: either a refusal with a suggested safe project, or
        the model's decoded reply.
    """
    if not is_safe(message):
        # respond is a generator (it yields below), so a plain
        # ``return <value>`` here would end iteration without ever showing
        # this message to the user.
        yield f"🚫 Let's focus on positive projects! Try: {random.choice(SAFE_IDEAS)}"
        return

    # Plain-text prompt with limited history: only the last 2 exchanges.
    prompt = f"System: {system_message}\n"
    for user_text, bot_text in _history_pairs(history)[-2:]:
        prompt += f"User: {user_text}\nAssistant: {bot_text}\n"
    prompt += f"User: {message}\nAssistant:"

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs.input_ids.shape[-1]

    generation_kwargs = dict(
        input_ids=inputs.input_ids,
        # Pass the mask explicitly. pad_token_id is set to EOS, so generate
        # cannot infer the mask from the ids.
        attention_mask=inputs.attention_mask,
        # Slider values may arrive as floats; the token budget must be an int.
        max_new_tokens=min(int(max_tokens), 256),
        temperature=min(temperature, 0.7),
        top_p=top_p,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    outputs = model.generate(**generation_kwargs)
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    yield response
with gr.Blocks() as demo:
    # Title text: the robot emoji was previously stored mis-encoded ("πŸ€–").
    gr.Markdown("# 🤖 REACT Ethical AI Lab")
    # Chat UI wired to respond(). The additional_inputs values are passed to
    # it, in order, after (message, history).
    gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Textbox("Help students create ethical AI projects", label="Guidelines"),  # system_message
            gr.Slider(64, 256, value=128, label="Max Length"),  # max_tokens
            gr.Slider(0.1, 0.7, value=0.3, label="Creativity"),  # temperature
            gr.Slider(0.5, 1.0, value=0.9, label="Focus"),  # top_p
        ],
        examples=[
            ["How to make a solar-powered robot?"],
            ["Simple air quality sensor code"],
        ],
    )
if __name__ == "__main__":
    # Listen on all network interfaces rather than only localhost.
    bind_host = "0.0.0.0"
    demo.launch(server_name=bind_host)