# Hugging Face Space: Janus-Pro-7B text-to-image demo (running on ZeroGPU)
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
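# The janus package comes from DeepSeek's Janus repository
# (https://github.com/deepseek-ai/Janus), typically pinned in the Space's requirements.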
from PIL import Image
import numpy as np
import spaces
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_WIDTH = 384
DEFAULT_HEIGHT = 384
PARALLEL_SIZE = 5
PATCH_SIZE = 16
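# 384 / 16 = 24 patches per side, so each generated image is represented by
# 24 * 24 = 576 discrete image tokens (the 576-step loop in generate() below).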


# Load model and processor with error handling
def load_model():
    try:
        model_path = "deepseek-ai/Janus-Pro-7B"
        config = AutoConfig.from_pretrained(model_path)
        language_config = config.language_config
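        # 'eager' selects the standard PyTorch attention implementation, so loading does
        # not depend on optional fused-attention backends such as flash-attention.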
        language_config._attn_implementation = 'eager'
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Loading model on device: {device}")
        vl_gpt = AutoModelForCausalLM.from_pretrained(
            model_path,
            language_config=language_config,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32
        ).to(device)
        vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
        return vl_gpt, vl_chat_processor, device
    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        raise RuntimeError("Failed to load model. Please check the model path and dependencies.")


# Load the model once at import time so it is shared across requests.
vl_gpt, vl_chat_processor, device = load_model()
tokenizer = vl_chat_processor.tokenizer


# Helper functions with improved memory management
def generate(input_ids, width, height, cfg_weight=5, temperature=1.0, parallel_size=5, progress=None):
    try:
        torch.cuda.empty_cache()
        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int, device=device)
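        # Build conditional/unconditional sequence pairs for classifier-free guidance:
        # even rows keep the full prompt, odd rows are padded out to serve as the
        # unconditional branch.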
        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                tokens[i, 1:-1] = vl_chat_processor.pad_id
        with torch.no_grad():
            inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
            generated_tokens = torch.zeros((parallel_size, 576), dtype=torch.int, device=device)
            pkv = None
            total_steps = 576
            for i in range(total_steps):
                if progress is not None:
                    progress((i + 1) / total_steps, desc="Generating image tokens")
                outputs = vl_gpt.language_model.model(
                    inputs_embeds=inputs_embeds,
                    use_cache=True,
                    past_key_values=pkv
                )
                pkv = outputs.past_key_values
                hidden_states = outputs.last_hidden_state
                logits = vl_gpt.gen_head(hidden_states[:, -1, :])
                logit_cond = logits[0::2, :]
                logit_uncond = logits[1::2, :]
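                # Classifier-free guidance: push the prediction toward the prompt,
                # logits = uncond + cfg_weight * (cond - uncond).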
                logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated_tokens[:, i] = next_token.squeeze(dim=-1)
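                # Feed the sampled token back to both the conditional and unconditional
                # streams, converted to embeddings via the image-token embedding table.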
                next_token = torch.cat([next_token.unsqueeze(dim=1)] * 2, dim=1).view(-1)
                img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
                inputs_embeds = img_embeds.unsqueeze(dim=1)
        return generated_tokens
    except RuntimeError as e:
        logger.error(f"Generation error: {str(e)}")
        raise RuntimeError("Generation failed, usually due to GPU memory constraints. Try reducing the parallel size.")
    finally:
        torch.cuda.empty_cache()


def unpack(patches, width, height, parallel_size=5):
    try:
        patches = patches.detach().to(device='cpu', dtype=torch.float32).numpy()
        patches = patches.transpose(0, 2, 3, 1)
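        # The VQ decoder outputs pixel values in [-1, 1]; rescale to [0, 255] for PIL.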
        patches = np.clip((patches + 1) / 2 * 255, 0, 255)
        return [Image.fromarray(patch.astype(np.uint8)) for patch in patches]
    except Exception as e:
        logger.error(f"Unpacking error: {str(e)}")
        raise RuntimeError("Failed to process generated image data.")


# ZeroGPU: request a GPU for the duration of each generation call.
@spaces.GPU
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0, progress=gr.Progress()):
    try:
        if not prompt or not prompt.strip():
            raise gr.Error("Please enter a valid prompt.")
        if progress is not None:
            progress(0, desc="Initializing...")
        torch.cuda.empty_cache()
        # Seed management
        if seed is None:
            seed = torch.seed()
        else:
            seed = int(seed)
        torch.manual_seed(seed)
        if device.type == "cuda":
            torch.cuda.manual_seed(seed)
        messages = [{'role': '<|User|>', 'content': prompt}, {'role': '<|Assistant|>', 'content': ''}]
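        # Wrap the prompt in the Janus chat template and append the image-start tag so
        # the model begins emitting image tokens rather than text.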
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt=''
        ) + vl_chat_processor.image_start_tag
        input_ids = torch.tensor(tokenizer.encode(text), dtype=torch.long, device=device)
        if progress is not None:
            progress(0.1, desc="Generating image tokens...")
        generated_tokens = generate(
            input_ids,
            DEFAULT_WIDTH,
            DEFAULT_HEIGHT,
            cfg_weight=guidance,
            temperature=t2i_temperature,
            parallel_size=PARALLEL_SIZE,
            progress=progress
        )
        if progress is not None:
            progress(0.9, desc="Processing images...")
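        # Decode the 576 discrete codes per image back to pixels with the VQ vision
        # decoder; the latent grid is 24x24 (384 / 16) and 8 is the latent channel
        # dimension the decoder expects.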
        patches = vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[PARALLEL_SIZE, 8, DEFAULT_WIDTH // PATCH_SIZE, DEFAULT_HEIGHT // PATCH_SIZE]
        )
        images = unpack(patches, DEFAULT_WIDTH, DEFAULT_HEIGHT, PARALLEL_SIZE)
        return images
    except gr.Error:
        # Re-raise user-facing validation errors (e.g. an empty prompt) unchanged.
        raise
    except Exception as e:
        logger.error(f"Generation failed: {str(e)}", exc_info=True)
        if "index out of range" in str(e).lower():
            raise gr.Error("Image generation failed due to an internal error. Please try again with different parameters.")
        else:
            raise gr.Error(f"Image generation failed: {str(e)}")


def create_interface():
    with gr.Blocks(title="Janus-Pro-7B Image Generator", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # Text-to-Image Generation with Janus-Pro-7B
        **Generate high-quality images from text prompts using DeepSeek's advanced multimodal AI model.**
        """)
        with gr.Row():
            with gr.Column(scale=3):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=3
                )
                generate_btn = gr.Button("Generate Images", variant="primary")
                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group():
                        seed_input = gr.Number(
                            label="Seed",
                            value=None,
                            precision=0,
                            info="Leave empty for a random seed"
                        )
                        guidance_slider = gr.Slider(
                            label="CFG Guidance Weight",
                            minimum=3,
                            maximum=10,
                            value=5,
                            step=0.5,
                            info="Higher values follow the prompt more closely; lower values allow more creativity"
                        )
                        temp_slider = gr.Slider(
                            label="Temperature",
                            minimum=0.1,
                            maximum=1.0,
                            value=1.0,
                            step=0.1,
                            info="Higher values add randomness; lower values are more deterministic"
                        )
            with gr.Column(scale=2):
                output_gallery = gr.Gallery(
                    label="Generated Images",
                    columns=2,
                    height=600,
                    preview=True
                )
                status = gr.Textbox(
                    label="Status",
                    interactive=False
                )
        gr.Examples(
            examples=[
                ["A futuristic cityscape at sunset with flying cars and holographic advertisements"],
                ["An astronaut riding a horse in photorealistic style"],
                ["A cute robotic cat sitting on a stack of ancient books, digital art"]
            ],
            inputs=prompt_input
        )
        gr.Markdown("""
        ## Model Information
        - **Model:** [Janus-Pro-7B](https://huggingface.co/deepseek-ai/Janus-Pro-7B)
        - **Output Resolution:** 384x384 pixels
        - **Parallel Generation:** 5 images per request
        """)
        # Footer Section
        gr.Markdown("""
        <hr style="margin-top: 2em; margin-bottom: 1em;">
        <div style="text-align: center; color: #666; font-size: 0.9em;">
        Created with ❤️ by <a href="https://bilsimaging.com" target="_blank" style="color: #2563eb; text-decoration: none;">bilsimaging.com</a>
        </div>
        """)
        # Visitor Badge
        gr.HTML("""
        <div style="text-align: center; margin-top: 1em;">
        <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F">
        <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2FDeepseekJanusPro%2F&countColor=%23263759"
        alt="Visitor Badge"
        style="display: inline-block; margin: 0 auto;">
        </a>
        </div>
        """)
        generate_btn.click(
            generate_image,
            inputs=[prompt_input, seed_input, guidance_slider, temp_slider],
            outputs=output_gallery,
            api_name="generate"
        )
        # Show whether the app is running on GPU or CPU once the page loads.
        demo.load(
            fn=lambda: f"Device Status: {'GPU ✅' if device.type == 'cuda' else 'CPU ⚠️'}",
            outputs=status,
            queue=False
        )
    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=True)
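
# A minimal client-side sketch (assumes the Space id from the visitor-badge URL above,
# Bils/DeepseekJanusPro, and the "generate" api_name registered on the click handler):
#
#     from gradio_client import Client
#     client = Client("Bils/DeepseekJanusPro")
#     images = client.predict("a watercolor fox in a forest", None, 5, 1.0, api_name="/generate")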