import sys
import os
from typing import Optional
from PIL import Image as PILImage

# Add the cloned nanoVLM repository to Python's module search path
NANOVLM_REPO_PATH = "/app/nanoVLM" # Path where the Dockerfile clones huggingface/nanoVLM
if NANOVLM_REPO_PATH not in sys.path:
    print(f"DEBUG: Adding {NANOVLM_REPO_PATH} to sys.path")
    sys.path.insert(0, NANOVLM_REPO_PATH)

import gradio as gr
import torch
from transformers import AutoProcessor # Use AutoProcessor, mirroring the reference generate.py script

# Import the custom VisionLanguageModel class
VisionLanguageModel = None
try:
    print("DEBUG: Attempting to import VisionLanguageModel from models.vision_language_model")
    from models.vision_language_model import VisionLanguageModel
    print("DEBUG: Successfully imported VisionLanguageModel.")
except ImportError as e:
    print(f"CRITICAL ERROR: Importing VisionLanguageModel failed: {e}")
except Exception as e:
    print(f"CRITICAL ERROR: An unexpected error occurred during VisionLanguageModel import: {e}")

# --- Device Setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"DEBUG: Using device: {device}")

# --- Configuration ---
model_repo_id = "lusxvr/nanoVLM-222M" # Used for both processor and model weights
print(f"DEBUG: Model Repository ID for processor and model: {model_repo_id}")

# --- Initialize ---
processor = None
model = None

if VisionLanguageModel: # Only proceed if custom model class was imported
    try:
        # Load processor using AutoProcessor, mirroring generate.py
        print(f"DEBUG: Loading processor using AutoProcessor.from_pretrained('{model_repo_id}')")
        # generate.py does not pass trust_remote_code=True for the processor, so try without it first;
        # add it back only if an "Unrecognized model" error reappears while loading the processor.
        processor = AutoProcessor.from_pretrained(model_repo_id)
        print(f"DEBUG: AutoProcessor loaded: {type(processor)}")

        # Ensure the tokenizer has a pad_token (GPT-2-based tokenizers ship without one); AutoProcessor exposes it as processor.tokenizer
        if hasattr(processor, 'tokenizer') and processor.tokenizer is not None:
            current_tokenizer = processor.tokenizer
            if getattr(current_tokenizer, 'pad_token', None) is None and hasattr(current_tokenizer, 'eos_token'):
                current_tokenizer.pad_token = current_tokenizer.eos_token
                print(f"DEBUG: Set processor.tokenizer.pad_token to eos_token (ID: {current_tokenizer.eos_token_id})")
        else:
            print("WARN: Processor does not have a 'tokenizer' attribute or it's None. Cannot set pad_token.")

        # Load model using VisionLanguageModel.from_pretrained, mirroring generate.py
        print(f"DEBUG: Loading model VisionLanguageModel.from_pretrained('{model_repo_id}')")
        # The custom VisionLanguageModel.from_pretrained does not accept trust_remote_code
        model = VisionLanguageModel.from_pretrained(model_repo_id).to(device)
        print(f"DEBUG: VisionLanguageModel loaded: {type(model)}")
        model.eval()
        print("DEBUG: Model set to eval() mode.")

    except Exception as e:
        print(f"CRITICAL ERROR loading model or processor: {e}")
        import traceback
        traceback.print_exc()
        # Ensure both are None if loading fails
        processor = None
        model = None
else:
    print("CRITICAL ERROR: VisionLanguageModel class not imported. Cannot load model.")


# --- Text Generation Function ---
def generate_text_for_image(image_input_pil: Optional[PILImage.Image], prompt_input_str: Optional[str]) -> str:
    print(f"DEBUG (generate_text_for_image): Received prompt: '{prompt_input_str}'")
    if model is None or processor is None:
        print("ERROR (generate_text_for_image): Model or processor not loaded.")
        return "Error: Model or processor not loaded. Please check the application logs."
    if image_input_pil is None:
        print("WARN (generate_text_for_image): No image uploaded.")
        return "Please upload an image."
    if not prompt_input_str: # Check for empty or None prompt
        print("WARN (generate_text_for_image): No prompt provided.")
        return "Please provide a prompt."

    try:
        current_pil_image = image_input_pil
        if not isinstance(current_pil_image, PILImage.Image): # Should already be PIL from Gradio's type="pil"
            print(f"WARN (generate_text_for_image): Input image not PIL, type: {type(current_pil_image)}. Converting.")
            current_pil_image = PILImage.fromarray(current_pil_image)
        if current_pil_image.mode != "RGB":
            print(f"DEBUG (generate_text_for_image): Converting image from {current_pil_image.mode} to RGB.")
            current_pil_image = current_pil_image.convert("RGB")
        print(f"DEBUG (generate_text_for_image): Image prepped - size: {current_pil_image.size}, mode: {current_pil_image.mode}")

        # Prepare inputs using the AutoProcessor, as in generate.py
        print("DEBUG (generate_text_for_image): Processing inputs with AutoProcessor...")
        inputs = processor(
            text=[prompt_input_str], images=current_pil_image, return_tensors="pt"
        ).to(device)
        print(f"DEBUG (generate_text_for_image): Inputs from AutoProcessor - keys: {inputs.keys()}")
        print(f"DEBUG (generate_text_for_image):   input_ids shape: {inputs['input_ids'].shape}, values: {inputs['input_ids']}")
        print(f"DEBUG (generate_text_for_image):   pixel_values shape: {inputs['pixel_values'].shape}")
        
        attention_mask = inputs.get('attention_mask')
        if attention_mask is None: # Should be provided by AutoProcessor
            print("WARN (generate_text_for_image): attention_mask not in processor output. Creating default.")
            attention_mask = torch.ones_like(inputs['input_ids']).to(device)
        print(f"DEBUG (generate_text_for_image):   attention_mask shape: {attention_mask.shape}")

        print("DEBUG (generate_text_for_image): Calling model.generate...")
        # Signature for nanoVLM's generate: (self, input_ids, image, attention_mask, max_new_tokens, ...)
        generated_ids_tensor = model.generate(
            inputs['input_ids'],
            inputs['pixel_values'], # This is the 'image' argument for the model's generate method
            attention_mask,
            max_new_tokens=50,    # Consistent with successful generate.py test
            temperature=0.7,      # From generate.py defaults (or adjust as preferred)
            top_k=50,             # From generate.py defaults (or adjust as preferred)
            # greedy=False is default in nanoVLM's generate
        )
        print(f"DEBUG (generate_text_for_image): Raw generated_ids: {generated_ids_tensor}")

        # Use processor.batch_decode, as in generate.py
        generated_text_list = processor.batch_decode(generated_ids_tensor, skip_special_tokens=True)
        print(f"DEBUG (generate_text_for_image): Decoded text list: {generated_text_list}")
        generated_text_str = generated_text_list[0] if generated_text_list else ""
        
        # Optional: Clean up prompt if echoed
        cleaned_text_str = generated_text_str
        if prompt_input_str and generated_text_str.startswith(prompt_input_str):
            cleaned_text_str = generated_text_str[len(prompt_input_str):].lstrip(" ,.:")
        print(f"DEBUG (generate_text_for_image): Final cleaned text: '{cleaned_text_str}'")
        return cleaned_text_str.strip()

    except Exception as e:
        print(f"CRITICAL ERROR during generation: {e}")
        import traceback
        traceback.print_exc()
        return f"Error during generation: {str(e)}. Check logs."

# --- Gradio Interface ---
description_md = """
## nanoVLM-222M Interactive Demo
Upload an image and type a prompt to get a description or answer from the model.
This Space uses the `lusxvr/nanoVLM-222M` model weights with the `huggingface/nanoVLM` model code.
"""
iface = None
# Only define the interface if the model and processor loaded successfully
if VisionLanguageModel and model and processor:
    try:
        print("DEBUG: Defining Gradio interface...")
        iface = gr.Interface(
            fn=generate_text_for_image,
            inputs=[
                gr.Image(type="pil", label="Upload Image"),
                gr.Textbox(label="Your Prompt / Question", info="e.g., 'describe this image in detail'")
            ],
            outputs=gr.Textbox(label="Generated Text", show_copy_button=True),
            title="nanoVLM-222M Demo",
            description=description_md,
            allow_flagging="never" # No examples or caching for now to keep it simple
        )
        print("DEBUG: Gradio interface defined successfully.")
    except Exception as e:
        print(f"CRITICAL ERROR defining Gradio interface: {e}")
        import traceback
        traceback.print_exc()
else:
    print("WARN: Model and/or processor did not load. Gradio interface will not be created.")


# --- Launch Gradio App ---
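# Binding to 0.0.0.0 on port 7860 matches the default host/port a Hugging Face Space expects for a Gradio app.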
if __name__ == "__main__":
    print("DEBUG: Entered __main__ block for Gradio launch.")
    if iface is not None:
        print("DEBUG: Attempting to launch Gradio interface...")
        try:
            iface.launch(server_name="0.0.0.0", server_port=7860)
            print("DEBUG: Gradio launch command issued.")
        except Exception as e:
            print(f"CRITICAL ERROR launching Gradio interface: {e}")
            import traceback
            traceback.print_exc()
    else:
        print("CRITICAL ERROR: Gradio interface (iface) is None or not defined due to loading errors. Cannot launch.")