LPX55 committed on
Commit
2e9b71a
·
verified ·
1 Parent(s): b31e000

Update app_local.py

Browse files
Files changed (1) hide show
  1. app_local.py +68 -19
app_local.py CHANGED
@@ -16,7 +16,7 @@ os.environ.setdefault('GRADIO_ANALYTICS_ENABLED', 'False')
16
  os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
17
 
18
  # Model configuration
19
- REWRITER_MODEL = "Qwen/Qwen1.5-1.8B-Chat"
20
  rewriter_tokenizer = None
21
  rewriter_model = None
22
  dtype = torch.bfloat16
@@ -76,17 +76,55 @@ Please provide the rewritten instruction in a clean `json` format as:
76
  }
77
  '''
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  def polish_prompt(original_prompt: str) -> str:
80
- """Enhanced prompt rewriting using Qwen1.5-1.8B"""
81
  load_rewriter()
82
 
83
- # Format as Qwen chat with system prompt
84
  messages = [
85
  {"role": "system", "content": SYSTEM_PROMPT_EDIT},
86
  {"role": "user", "content": original_prompt}
87
  ]
88
 
89
- # Generate enhanced prompt
90
  text = rewriter_tokenizer.apply_chat_template(
91
  messages,
92
  tokenize=False,
@@ -98,29 +136,40 @@ def polish_prompt(original_prompt: str) -> str:
98
  with torch.no_grad():
99
  generated_ids = rewriter_model.generate(
100
  **model_inputs,
101
- max_new_tokens=120,
102
  do_sample=True,
103
- temperature=0.7,
104
- top_p=0.95,
105
- no_repeat_ngram_size=2
 
106
  )
107
 
108
  # Extract and clean response
109
  enhanced = rewriter_tokenizer.decode(
110
  generated_ids[0][model_inputs.input_ids.shape[1]:],
111
  skip_special_tokens=True
112
- )
113
 
114
- # Clean possible artifacts
115
- enhanced = enhanced.strip()
116
- if enhanced.lower().startswith(("rewritten instruction:", "enhanced:", "output:")):
117
- enhanced = re.split(r':', enhanced, 1)[-1].strip()
118
 
119
- # Remove any quotes around the prompt if present
120
- if enhanced.startswith('"') and enhanced.endswith('"'):
121
- enhanced = enhanced[1:-1]
 
 
122
 
123
- return enhanced
 
 
 
 
 
 
 
 
 
 
124
 
125
  # Load main image editing pipeline
126
  pipe = QwenImageEditPipeline.from_pretrained(
@@ -140,6 +189,7 @@ if is_xformers_available():
140
  else:
141
  print("xformers not available")
142
 
 
143
  def unload_rewriter():
144
  """Clear enhancement model from memory"""
145
  global rewriter_tokenizer, rewriter_model
@@ -266,8 +316,7 @@ with gr.Blocks(title="Qwen Image Editor Fast") as demo:
266
 
267
  rewrite_toggle = gr.Checkbox(
268
  label="Enable AI Prompt Enhancement",
269
- value=True,
270
- info="Uses local Qwen1.5-1.8B model to improve your instructions"
271
  )
272
 
273
  run_button = gr.Button("Generate Edits", variant="primary")
 
16
  os.environ.setdefault('HF_HUB_DISABLE_TELEMETRY', '1')
17
 
18
  # Model configuration
19
+ REWRITER_MODEL = "Qwen/Qwen1.5-7B-Chat" # Upgraded to 7B for better JSON handling
20
  rewriter_tokenizer = None
21
  rewriter_model = None
22
  dtype = torch.bfloat16
 
76
  }
77
  '''
78
 
79
+ def extract_json_response(model_output: str) -> str:
80
+ """Extract rewritten instruction from potentially messy JSON output"""
81
+ try:
82
+ # Try to find the JSON portion in the output
83
+ start_idx = model_output.find('{')
84
+ end_idx = model_output.rfind('}') + 1
85
+ if start_idx == -1 or end_idx == 0:
86
+ return None
87
+
88
+ json_str = model_output[start_idx:end_idx]
89
+ # Clean up common formatting issues
90
+ json_str = re.sub(r'(?<!")\b(\w+)\b(?=":)', r'"\1"', json_str) # Add quotes to keys
91
+ json_str = re.sub(r':\s*([^"{\[]|true|false|null)', r': "\1"', json_str) # Add quotes to values
92
+
93
+ # Parse JSON
94
+ data = json.loads(json_str)
95
+
96
+ # Extract rewritten prompt from possible key variations
97
+ possible_keys = [
98
+ "Rewritten", "rewritten", "Rewrited", "rewrited",
99
+ "Output", "output", "Enhanced", "enhanced"
100
+ ]
101
+ for key in possible_keys:
102
+ if key in data:
103
+ return data[key].strip()
104
+
105
+ # Try nested path
106
+ if "Response" in data and "Rewritten" in data["Response"]:
107
+ return data["Response"]["Rewritten"].strip()
108
+
109
+ # Fallback to direct extraction
110
+ for value in data.values():
111
+ if isinstance(value, str) and 10 < len(value) < 500:
112
+ return value.strip()
113
+
114
+ except Exception:
115
+ pass
116
+ return None
117
+
118
  def polish_prompt(original_prompt: str) -> str:
119
+ """Enhanced prompt rewriting using original system prompt with JSON handling"""
120
  load_rewriter()
121
 
122
+ # Format as Qwen chat
123
  messages = [
124
  {"role": "system", "content": SYSTEM_PROMPT_EDIT},
125
  {"role": "user", "content": original_prompt}
126
  ]
127
 
 
128
  text = rewriter_tokenizer.apply_chat_template(
129
  messages,
130
  tokenize=False,
 
136
  with torch.no_grad():
137
  generated_ids = rewriter_model.generate(
138
  **model_inputs,
139
+ max_new_tokens=256, # Maintain token count for good JSON generation
140
  do_sample=True,
141
+ temperature=0.6,
142
+ top_p=0.9,
143
+ no_repeat_ngram_size=2,
144
+ pad_token_id=rewriter_tokenizer.eos_token_id
145
  )
146
 
147
  # Extract and clean response
148
  enhanced = rewriter_tokenizer.decode(
149
  generated_ids[0][model_inputs.input_ids.shape[1]:],
150
  skip_special_tokens=True
151
+ ).strip()
152
 
153
+ # Try to extract JSON content
154
+ rewritten_prompt = extract_json_response(enhanced)
 
 
155
 
156
+ if rewritten_prompt:
157
+ # Clean up substitutions from the JSON output
158
+ rewritten_prompt = re.sub(r'(Replace|Change|Add) "([^"]*)"', r'\1 \2', rewritten_prompt)
159
+ rewritten_prompt = rewritten_prompt.replace('\\"', '"')
160
+ return rewritten_prompt
161
 
162
+ # Fallback cleanup if JSON extraction fails
163
+ print(f"⚠️ JSON extraction failed, using raw output: {enhanced}")
164
+ fallback = re.sub(r'```.*?```', '', enhanced, flags=re.DOTALL) # Remove code blocks
165
+ fallback = re.sub(r'[\{\}\[\]"]', '', fallback) # Remove JSON artifacts
166
+ fallback = fallback.split('\n')[0] # Take first line
167
+
168
+ # Try to extract before colon separator
169
+ if ': ' in fallback:
170
+ return fallback.split(': ')[1].strip()
171
+
172
+ return fallback.strip()
173
 
174
  # Load main image editing pipeline
175
  pipe = QwenImageEditPipeline.from_pretrained(
 
189
  else:
190
  print("xformers not available")
191
 
192
+
193
  def unload_rewriter():
194
  """Clear enhancement model from memory"""
195
  global rewriter_tokenizer, rewriter_model
 
316
 
317
  rewrite_toggle = gr.Checkbox(
318
  label="Enable AI Prompt Enhancement",
319
+ value=True
 
320
  )
321
 
322
  run_button = gr.Button("Generate Edits", variant="primary")