Spaces:
Sleeping
Sleeping
File size: 12,741 Bytes
3d2b701 0992b82 3d2b701 8db4929 1a890cf 8db4929 70f7a84 3d2b701 0992b82 3d2b701 0992b82 3d2b701 6fed942 3d2b701 08c2564 3d2b701 08c2564 3d2b701 08c2564 3d2b701 08c2564 3d2b701 f6fdda7 3d2b701 f6fdda7 3d2b701 f6fdda7 3d2b701 0992b82 3d2b701 f6fdda7 3d2b701 f6fdda7 3d2b701 0992b82 3d2b701 0992b82 bcbac7e 3d2b701 bcbac7e 3d2b701 bcbac7e 0992b82 3d2b701 bcbac7e 71f8c0c bcbac7e d183913 bcbac7e 3d2b701 0992b82 3d2b701 bcbac7e 0992b82 3d2b701 bcbac7e 3d2b701 bcbac7e 0992b82 bcbac7e 0992b82 bcbac7e 0992b82 3d2b701 0992b82 bcbac7e 0992b82 bcbac7e 0992b82 3d2b701 f6fdda7 bcbac7e 3d2b701 bcbac7e 3d2b701 d2eff3b 3d2b701 bcbac7e f6fdda7 bcbac7e f6fdda7 bcbac7e 3d2b701 0992b82 3d2b701 bcbac7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
import os
from typing import Dict, Any
# System prompt (must match training): instructs the model to emit exactly
# three string fields as JSON, or all-empty strings for non-reservation text.
# NOTE: do not edit this string — it is part of the model's trained contract.
SYSTEM_PROMPT = """Determine if the message is a restaurant reservation request.
If yes, extract the following three fields as strings:
- "num_people": number of people (as a string, e.g., "4"). If not mentioned, use an empty string ("").
- "reservation_date": the exact date/time phrase from the message (as a string, do not convert or interpret; e.g., keep "this Saturday at 7 PM" as is). If not mentioned, use an empty string ("").
- "phone_num": the phone number (as a string, digits only, remove any hyphens or formatting; e.g., "0912345678"). If not mentioned, use an empty string ("").
If the message is NOT a reservation request, return:
```json
{
"num_people": "",
"reservation_date": "",
"phone_num": ""
}
```
Output must be valid JSON only, with exactly these three fields and no additional text, fields, or explanations.
"""
# Global variables for model caching (populated lazily by load_model()).
model = None
tokenizer = None
def load_model():
    """Load and cache the tokenizer and model as module-level singletons.

    Returns:
        (model, tokenizer) on success, or (None, None) if loading failed.
    """
    global model, tokenizer

    # Reuse already-loaded instances; loading is expensive.
    if model is not None and tokenizer is not None:
        return model, tokenizer

    model_name = "Luigi/gemma-3-270m-it-dinercall-ner"
    try:
        print("Loading model...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map="auto",
            trust_remote_code=True
        )
        # Fall back to EOS as the padding token when none is defined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        print("Model loaded successfully!")
    except Exception as e:
        # Signal failure to callers; they surface a user-facing error.
        print(f"Error loading model: {e}")
        return None, None
    return model, tokenizer
def validate_json(output: str) -> tuple:
    """Extract and parse the JSON object in a model response.

    Accepts both fenced ```json blocks (newer model versions) and bare
    {...} spans (older versions), repairing common formatting glitches
    before parsing.

    Returns:
        (is_valid, parsed_dict_or_None, bilingual_status_message)
    """
    try:
        # Prefer a fenced ```json ... ``` block when present.
        fenced = re.search(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```', output)
        if fenced:
            candidate = fenced.group(1)
        else:
            # Otherwise fall back to the first bare {...} span.
            bare = re.search(r'\{[\s\S]*\}', output)
            if bare is None:
                return False, None, "No JSON found / 未找到JSON"
            candidate = bare.group(0)

        # Repair frequent glitches before parsing:
        # 1. quote unquoted phone numbers (they often start with 0),
        candidate = re.sub(r'("phone_num":\s*)(\d[-\d]*)', r'\1"\2"', candidate)
        # 2. quote unquoted integer head counts,
        candidate = re.sub(r'("num_people":\s*)(\d+)', r'\1"\2"', candidate)
        # 3. drop a trailing comma before the closing brace.
        candidate = re.sub(r',\s*\}', '}', candidate)

        return True, json.loads(candidate), "Valid JSON / 有效的JSON"
    except json.JSONDecodeError:
        return False, None, "Invalid JSON format / 無效的JSON格式"
    except Exception:
        return False, None, "Error parsing JSON / 解析JSON時出錯"
def extract_reservation_info(text: str):
    """Run the NER model on `text` and parse its JSON answer.

    Args:
        text: the user's reservation message (Chinese or English).

    Returns:
        (result, raw_output) where result is the parsed dict on success or
        a {"error": ...} dict, and raw_output is the model's raw decoded text.
    """
    # Load (or reuse the cached) model and tokenizer.
    model, tokenizer = load_model()
    if model is None or tokenizer is None:
        return {"error": "Model not loaded, please refresh the page / 模型未加載成功,請刷新頁面重試"}, ""
    try:
        # Build the chat prompt with the same system prompt used in training.
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": text}
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Tokenize and move tensors to the model's device: device_map="auto"
        # may have placed the model on GPU, and generate() requires the
        # inputs to live on the same device.
        inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
        # Greedy decoding (do_sample=False); temperature would be ignored,
        # so it is deliberately not passed.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=64,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                do_sample=False,
            )
        # Keep only the newly generated tokens (strip the echoed prompt).
        prompt_length = len(inputs.input_ids[0])
        assistant_output = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        # Validate and parse the JSON payload.
        is_valid, parsed, message = validate_json(assistant_output)
        if is_valid:
            return parsed, assistant_output
        return {"error": message}, assistant_output
    except Exception as e:
        # Surface a bilingual error but never crash the UI callback.
        return {"error": f"Processing error / 處理時出錯: {str(e)}"}, ""
# Create Gradio interface
def create_interface():
    """Build the bilingual (English/Chinese) Gradio Blocks UI.

    Returns the assembled Blocks demo, ready for .launch(). A language
    radio re-labels every visible component in place via update_interface.
    """
    # Canned example inputs shown under the textbox.
    chinese_examples = [
        "你好,我想訂明天晚上7點的位子,四位成人,電話是0912-345-678",
        "週六下午三點,兩位,電話0987654321",
        "預約下週三中午12點半,5人用餐,聯絡電話0912345678",
        "我要訂位,3個人,今天下午6點"
    ]
    english_examples = [
        "Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
        "Saturday 3 PM, 2 people, phone 0987654321",
        "Reservation for next Wednesday at 12:30 PM, 5 people, contact number 0912345678",
        "I want to make a reservation, 3 people, today at 6 PM"
    ]
    # Language-specific text dictionaries; keys mirror the component list
    # returned by update_interface below.
    text_en = {
        "title": "🍽️ Restaurant Reservation Info Extractor",
        "description": "Use AI to automatically extract reservation information from messages",
        "input_label": "Input reservation message",
        "input_placeholder": "e.g., Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
        "button_text": "Extract Information",
        "json_label": "Extracted Result",
        "raw_label": "Raw Output",
        "instructions_title": "ℹ️ Instructions",
        "instructions": """**Supported information:**
- 👥 Number of people (num_people)
- 📅 Reservation date/time (reservation_date)
- 📞 Phone number (phone_num)
**Notes:**
- First-time model loading may take a few minutes
- If you encounter errors, try refreshing the page
- The model outputs results in JSON format""",
        "footer": "Powered by [Together AI](https://together.ai) | Model: Luigi/gemma-3-270m-it-dinercall-ner",
        "examples_title": "Examples",
        "chinese_examples": "Chinese Examples",
        "english_examples": "English Examples"
    }
    text_zh = {
        "title": "🍽️ 餐廳訂位資訊提取器",
        "description": "使用AI從中文訊息中自動提取訂位資訊",
        "input_label": "輸入訂位訊息",
        "input_placeholder": "例如: 你好,我想訂明天晚上7點的位子,四位成人,電話是0912-345-678",
        "button_text": "提取資訊",
        "json_label": "提取結果",
        "raw_label": "原始輸出",
        "instructions_title": "ℹ️ 使用說明",
        "instructions": """**支援提取的資訊:**
- 👥 人數 (num_people)
- 📅 預訂日期/時間 (reservation_date)
- 📞 電話號碼 (phone_num)
**注意事項:**
- 首次加載模型可能需要幾分鐘時間
- 如果遇到錯誤,請嘗試刷新頁面
- 模型會輸出JSON格式的結果""",
        "footer": "由 [Together AI](https://together.ai) 提供技術支持 | 模型: Luigi/gemma-3-270m-it-dinercall-ner",
        "examples_title": "示例",
        "chinese_examples": "中文示例",
        "english_examples": "英文示例"
    }
    with gr.Blocks(
        title="Restaurant Reservation Info Extractor",
        theme=gr.themes.Soft()
    ) as demo:
        # Language selector; drives update_interface below.
        language = gr.Radio(
            choices=["English", "中文"],
            value="English",
            label="Language / 語言",
            interactive=True
        )
        # Create components that will be updated based on language.
        # Initial values are the English defaults from text_en.
        title_md = gr.Markdown("# 🍽️ Restaurant Reservation Info Extractor")
        description_md = gr.Markdown("Use AI to automatically extract reservation information from messages")
        with gr.Row():
            with gr.Column():
                input_text = gr.Textbox(
                    label="Input reservation message",
                    placeholder="e.g., Hello, I'd like to reserve a table for 4 people tomorrow at 7 PM, phone number is 0912-345-678",
                    lines=3
                )
                submit_btn = gr.Button("Extract Information", variant="primary")
                examples_title_md = gr.Markdown("### Examples")
                chinese_examples_title_md = gr.Markdown("### Chinese Examples")
                gr.Examples(
                    examples=chinese_examples,
                    inputs=input_text,
                    label="Chinese Examples"
                )
                english_examples_title_md = gr.Markdown("### English Examples")
                gr.Examples(
                    examples=english_examples,
                    inputs=input_text,
                    label="English Examples"
                )
            with gr.Column():
                json_output = gr.JSON(label="Extracted Result")
                raw_output = gr.Textbox(
                    label="Raw Output",
                    interactive=False,
                    lines=3
                )
        # Info panel - Create the Accordion but don't use it as an output
        # (Accordion labels can't be updated via outputs; only the inner
        # Markdown is re-rendered on language change).
        with gr.Accordion("ℹ️ Instructions", open=False) as instructions_accordion:
            instructions_md = gr.Markdown("""**Supported information:**
- 👥 Number of people (num_people)
- 📅 Reservation date/time (reservation_date)
- 📞 Phone number (phone_num)
**Notes:**
- First-time model loading may take a few minutes
- If you encounter errors, try refreshing the page
- The model outputs results in JSON format""")
        # Footer
        footer_md = gr.Markdown("Powered by [Together AI](https://together.ai) | Model: Luigi/gemma-3-270m-it-dinercall-ner")
        # Function to update interface based on language selection.
        # The returned list must stay in the same order as the outputs=
        # list of language.change below.
        def update_interface(language):
            texts = text_en if language == "English" else text_zh
            return [
                f"# {texts['title']}", # title_md
                texts['description'], # description_md
                gr.update(label=texts['input_label'], placeholder=texts['input_placeholder']), # input_text
                texts['button_text'], # submit_btn
                gr.update(label=texts['json_label']), # json_output
                gr.update(label=texts['raw_label']), # raw_output
                texts['instructions'], # instructions_md
                texts['footer'], # footer_md
                f"### {texts['examples_title']}", # examples_title_md
                f"### {texts['chinese_examples']}", # chinese_examples_title_md
                f"### {texts['english_examples']}" # english_examples_title_md
            ]
        # Connect the function to the button
        submit_btn.click(
            fn=extract_reservation_info,
            inputs=input_text,
            outputs=[json_output, raw_output]
        )
        # Connect language selector to update interface - REMOVE ACCORDION FROM OUTPUTS
        language.change(
            fn=update_interface,
            inputs=language,
            outputs=[
                title_md,
                description_md,
                input_text,
                submit_btn,
                json_output,
                raw_output,
                instructions_md, # Update only the Markdown inside Accordion
                footer_md,
                examples_title_md,
                chinese_examples_title_md,
                english_examples_title_md
            ]
        )
    return demo
# Script entry point: build the UI and serve it.
if __name__ == "__main__":
    # Warm the model cache before serving so the first request
    # does not pay the full load cost.
    load_model()
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )