# Hugging Face Space, running on ZeroGPU ("Running on Zero").
import time
from pathlib import Path
from threading import Thread

import cv2
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
# UI strings are Hebrew; English translations are given in the comments.
# TITLE: "A model based on Gemma 3 for generating silly songs in Hebrew"
TITLE = "诪讜讚诇 诪讘讜住住 讙诪讛 3 诇讬爪讬专转 砖讬专讬诐 诪讟讜驻砖讬诐 讘注讘专讬转"
# DESCRIPTION: "You can request a song based on text, an image, or video.
# [The model is available for download](https://huggingface.co/Norod78/gemma-3_4b_hebrew-lyrics-finetune)
# The model was fine-tuned by [Doron Adler](https://linktr.ee/Norod78)"
DESCRIPTION = """
谞讬转谉 诇讘拽砖 砖讬专 注诇 讘住讬住 讟拽住讟, 转诪讜谞讛 讜讜讬讚讗讜
[讛诪讜讚诇 讝诪讬谉 诇讛讜专讚讛](https://huggingface.co/Norod78/gemma-3_4b_hebrew-lyrics-finetune)
讛诪讜讚诇 讻讜讬诇 注状讬 [讚讜专讜谉 讗讚诇专](https://linktr.ee/Norod78)
"""
# Model config: load the fine-tuned Gemma-3 4B model and its processor once at startup.
model_4b_name = "Norod78/gemma-3_4b_hebrew-lyrics-finetune"
model_4b = Gemma3ForConditionalGeneration.from_pretrained(
    model_4b_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
).eval()
processor_4b = AutoProcessor.from_pretrained(model_4b_name)
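# A note on the loading choices above (an inference from the Space's
# "Running on Zero" badge, not stated in the code): bfloat16 matches Gemma-3's
# native precision at half the memory of fp32, and device_map="auto" lets
# accelerate place the weights. On ZeroGPU a GPU is only attached while a
# @spaces.GPU-decorated handler runs, which is what the decorator below is for.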
# TODO: add timestamp support later
def extract_video_frames(video_path, num_frames=8):
    """Sample up to `num_frames` evenly spaced RGB frames from a video file."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step = max(total_frames // num_frames, 1)
    for i in range(num_frames):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i * step)
        ret, frame = cap.read()
        if ret:
            # OpenCV decodes to BGR; PIL expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame))
    cap.release()
    return frames
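# Worked example of the sampling arithmetic above (hypothetical clip, for
# illustration only): a 100-frame video with num_frames=8 gives
# step = max(100 // 8, 1) = 12, so frames 0, 12, 24, ..., 84 are decoded.
# For clips shorter than num_frames, step clamps to 1 and any read past the
# last frame fails (ret is False), so fewer than num_frames images come back.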
def format_message(content, files):
    """Build the multimodal content list for a single user turn.

    `<image>` placeholders in the text are replaced, in order, with the
    uploaded files; any remaining files are appended afterwards (images
    directly, videos as sampled frames).
    """
    message_content = []
    if content:
        parts = content.split('<image>')
        for i, part in enumerate(parts):
            if part.strip():
                message_content.append({"type": "text", "text": part.strip()})
            # Consume one file per <image> placeholder, in order.
            if i < len(parts) - 1 and files:
                img = Image.open(files.pop(0))
                message_content.append({"type": "image", "image": img})
    for file in files:
        file_path = file if isinstance(file, str) else file.name
        if Path(file_path).suffix.lower() in ['.jpg', '.jpeg', '.png']:
            img = Image.open(file_path)
            message_content.append({"type": "image", "image": img})
        elif Path(file_path).suffix.lower() in ['.mp4', '.mov']:
            frames = extract_video_frames(file_path)
            for frame in frames:
                message_content.append({"type": "image", "image": frame})
    return message_content
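# For illustration, with a hypothetical prompt and file (not from the app):
#   format_message("Write a song about <image> this picture", ["cat.jpg"])
# returns
#   [{"type": "text", "text": "Write a song about"},
#    {"type": "image", "image": <PIL.Image for cat.jpg>},
#    {"type": "text", "text": "this picture"}]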
def format_conversation_history(chat_history):
    """Convert Gradio's flat message history into the nested chat format the
    processor expects, merging consecutive user items into one turn."""
    messages = []
    current_user_content = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if role == "user":
            if isinstance(content, str):
                current_user_content.append({"type": "text", "text": content})
            elif isinstance(content, list):
                current_user_content.extend(content)
            else:
                current_user_content.append({"type": "text", "text": str(content)})
        elif role == "assistant":
            # Flush any buffered user content before the assistant reply.
            if current_user_content:
                messages.append({"role": "user", "content": current_user_content})
                current_user_content = []
            messages.append({"role": "assistant", "content": [{"type": "text", "text": str(content)}]})
    if current_user_content:
        messages.append({"role": "user", "content": current_user_content})
    return messages
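# Sketch of the conversion (hypothetical two-turn history): Gradio's flat
#   [{"role": "user", "content": "a song about a cat"},
#    {"role": "assistant", "content": "Verse 1 ..."}]
# becomes the nested form the processor expects:
#   [{"role": "user", "content": [{"type": "text", "text": "a song about a cat"}]},
#    {"role": "assistant", "content": [{"type": "text", "text": "Verse 1 ..."}]}]
# Consecutive user entries (text and an uploaded image arrive as separate
# history items) are merged into a single user turn by the buffering above.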
@spaces.GPU  # Required on ZeroGPU Spaces: attaches a GPU for the duration of the call.
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    # MultimodalTextbox submits a dict with "text" and "files"; fall back to plain text.
    if isinstance(input_data, dict) and "text" in input_data:
        text = input_data["text"]
        files = input_data.get("files", [])
    else:
        text = str(input_data)
        files = []
    new_message_content = format_message(text, files)
    new_message = {"role": "user", "content": new_message_content}
    system_message = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history
    # If the history already ends with an open user turn, merge into it.
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"].extend(new_message["content"])
    else:
        messages.append(new_message)
    model = model_4b
    processor = processor_4b
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so tokens can be streamed as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    outputs = []
    for chunk in streamer:
        outputs.append(chunk)
        yield "".join(outputs)
chat_interface = gr.ChatInterface(
    fn=generate_response,
    chatbot=gr.Chatbot(rtl=True, show_copy_button=True, type="messages"),
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=100, maximum=2000, step=1, value=512),
        gr.Textbox(
            label="System Prompt",
            # "You are an Israeli poet, writing songs in Hebrew"
            value="讗转讛 诪砖讜专专 讬砖专讗诇讬, 讻讜转讘 砖讬专讬诐 讘注讘专讬转",
            lines=4,
            # Placeholder: "Change the model's settings"
            placeholder="砖谞讛 讗转 讛讛讙讚专讜转 砖诇 讛诪讜讚诇",
            text_align="right",
            rtl=True,
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.92),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=70),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1),
    ],
    examples=[
        # "Please write me a song describing the picture"
        [{"text": "讻转讜讘 诇讬 讘讘拽砖讛 砖讬专 讛诪转讗专 讗转 讛转诪讜谞讛", "files": ["examples/image1.jpg"]}],
        # "A potato with social anxiety"
        [{"text": "转驻讜讞 讗讚诪讛 注诐 讞专讚讛 讞讘专转讬转"}],
    ],
    textbox=gr.MultimodalTextbox(
        rtl=True,
        label="拽诇讟",  # "Input"
        file_types=["image", "video"],
        file_count="multiple",
        # "Request a song and/or upload an image"
        placeholder="讘拽砖讜 砖讬专 讜/讗讜 讛注诇讜 转诪讜谞讛",
    ),
    cache_examples=False,
    type="messages",
    fill_height=True,
    stop_btn="讛驻住拽",  # "Stop"
    css_paths=["style.css"],
    multimodal=True,
    title=TITLE,
    description=DESCRIPTION,
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    chat_interface.queue(max_size=20).launch()