File size: 22,854 Bytes
d425e71
 
 
 
 
c59be23
 
d425e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf24c8d
d425e71
 
cf24c8d
 
 
d425e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf24c8d
d425e71
 
 
cf24c8d
 
 
d425e71
 
 
 
 
 
 
 
 
 
 
 
cf24c8d
 
 
d425e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf24c8d
d425e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9373294
d425e71
9373294
d425e71
9373294
 
 
 
d425e71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c274a6b
 
 
cf24c8d
 
 
c274a6b
 
 
8bb7cb5
d425e71
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
"""Gradio demo for visualizing VLM first token probability distributions with two images."""

from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
from spaces import GPU

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.figure import Figure
from matplotlib.text import Text
from PIL import Image

from demo.lookup import ModelVariants, get_model_info  # noqa: E402
from src.main import get_model  # noqa: E402
from src.models.base import ModelBase  # noqa: E402
from src.models.config import Config, ModelSelection  # noqa: E402

# Loaded model instances keyed by ``ModelVariants.value``; ``load_model`` checks
# this before building a model, so repeated requests for a variant reuse it.
models_cache: Dict[str, Any] = {}
# Selection recorded by ``load_model`` for the most recently requested model.
current_model_selection: Optional[ModelSelection] = None


def read_layer_spec(spec_file_path: str) -> List[str]:
    """Collect the layer names listed in a model spec file.

    Every non-blank line of the file is one layer name; surrounding whitespace
    is removed. Read failures are reported on stdout and replaced by a single
    placeholder entry so that callers always receive a list.

    Args:
        spec_file_path: Path to the model specification file (UTF-8 text).

    Returns:
        The stripped, non-empty lines in file order, or a one-element
        placeholder list when the file is missing or cannot be read.
    """
    try:
        with open(spec_file_path, 'r', encoding='utf-8') as spec_file:
            raw_lines = spec_file.readlines()
    except FileNotFoundError:
        print(f'Spec file not found: {spec_file_path}')
        return ['Default layer (spec file not found)']
    except Exception as err:
        print(f'Error reading spec file: {str(err)}')
        return ['Default layer (error reading spec)']

    stripped = (raw.strip() for raw in raw_lines)
    return [name for name in stripped if name]


def update_layer_choices(model_choice: str) -> Tuple[gr.Dropdown, gr.Button]:
    """Refresh the module dropdown and Analyze button for a newly chosen VLM.

    Args:
        model_choice: Model name as shown in the VLM dropdown (may be empty).

    Returns:
        A ``(dropdown, button)`` pair. With no model chosen both are hidden;
        on success the dropdown lists the spec-file layers and the button is
        shown; on failure a read-only dropdown shows the problem and the
        button stays hidden.
    """
    if not model_choice:
        return gr.Dropdown(choices=[], visible=False), gr.Button(visible=False)

    def _unavailable(message: str) -> Tuple[gr.Dropdown, gr.Button]:
        # Read-only dropdown that carries the message; Analyze remains hidden.
        disabled_dropdown = gr.Dropdown(
            choices=[message],
            label='Select Module',
            visible=True,
            interactive=False
        )
        hidden_button = gr.Button('Analyze', variant='primary', visible=False)
        return disabled_dropdown, hidden_button

    try:
        # The dropdown shows capitalized enum values; lower-case to map back.
        variant = ModelVariants(model_choice.lower())
        _, _, spec_path = get_model_info(variant)
        available_layers = read_layer_spec(spec_path)

        default_layer = available_layers[0] if available_layers else None
        layer_dropdown = gr.Dropdown(
            choices=available_layers,
            label=f'Select Module for {model_choice}',
            value=default_layer,
            visible=True,
            interactive=True
        )
        analyze_button = gr.Button('Analyze', variant='primary', visible=True)
        return layer_dropdown, analyze_button
    except ValueError:
        return _unavailable('Model not implemented')
    except Exception as err:
        return _unavailable(f'Error: {str(err)}')


def load_model(model_var: ModelVariants, config: Config) -> ModelBase:
    """Return the VLM for ``model_var``, loading it on first use.

    Loaded models are kept in the module-level ``models_cache`` keyed by
    ``model_var.value``; later calls for the same variant reuse that instance.
    In both the cache-hit and fresh-load paths ``current_model_selection`` is
    set to ``config.architecture``, so the global always holds the same type.

    Args:
        model_var: The model to load from ModelVariants enum.
        config: The configuration object with model parameters; its
            ``architecture`` selects the implementation passed to ``get_model``.

    Returns:
        The loaded (or cached) model instance.

    Raises:
        Exception: Whatever ``get_model`` raises; it is logged and re-raised.
    """
    global models_cache, current_model_selection

    model_key = model_var.value
    model_selection = config.architecture

    # Reuse an already-loaded model for this variant.
    if model_key in models_cache:
        current_model_selection = model_selection
        return models_cache[model_key]

    print(f'Loading {model_key} model...')

    try:
        model = get_model(model_selection, config)
    except Exception as e:
        # model_selection is bound before the try, so this message cannot itself fail.
        print(f'Error loading model {model_selection.value}: {str(e)}')
        raise

    models_cache[model_key] = model
    current_model_selection = model_selection

    print(f'{model_selection.value} model loaded successfully!')
    return model


def get_single_image_probabilities(
    instruction: str,
    image: Image.Image,
    vlm: ModelBase,
    model_selection: ModelSelection,
    top_k: int = 8
) -> Tuple[List[str], np.ndarray]:
    """Return the top-k first-token candidates and their probabilities for one image.

    Args:
        instruction: Text instruction for the model.
        image: PIL Image to process.
        vlm: Loaded model wrapper; its ``model.generate`` is asked for one token.
        model_selection: The VLM being used (kept for interface compatibility;
            not used by the computation).
        top_k: Number of top tokens to return. It is coerced to ``int`` and
            capped at the vocabulary size.

    Returns:
        ``(tokens, probabilities)``: decoded token strings and a float32 NumPy
        array of their probabilities, both ordered from most to least likely.
    """
    vlm.model.eval()
    text = vlm._generate_prompt(instruction, has_images=True)
    inputs = vlm._generate_processor_output(text, image)
    for key in inputs:
        if isinstance(inputs[key], torch.Tensor):
            inputs[key] = inputs[key].to(vlm.config.device)

    with torch.no_grad():
        outputs = vlm.model.generate(
            **inputs,
            max_new_tokens=1,  # only the first token's distribution is needed
            output_scores=True,
            return_dict_in_generate=True,
            do_sample=False
        )

    # Scores for the first generated step of the first (only) sequence.
    # Upcast before softmax: the model may run in float16, where softmax loses
    # precision (and is not implemented on CPU in older torch releases).
    first_token_logits = outputs.scores[0][0].float()
    probabilities = torch.softmax(first_token_logits, dim=-1)

    # UI inputs may arrive as floats, and topk raises if k exceeds the vocab size.
    k = min(int(top_k), probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)

    top_tokens = [vlm.processor.tokenizer.decode([token_id]) for token_id in top_indices.tolist()]

    return top_tokens, top_probs.cpu().numpy()


def scale_figure_fonts(fig: Figure, factor: float = 1.5) -> None:
    """Multiply every text size in a Matplotlib Figure by ``factor`` exactly once.

    Axes titles, axis labels, tick labels and ``ax.text(...)`` annotations are
    all descendants of the figure, so one ``fig.findobj(match=Text)`` sweep
    reaches them, along with any figure-level text. A separate per-axes pass
    followed by a figure-wide pass would scale each of them twice
    (i.e. by ``factor ** 2``).

    Args:
        fig: The Matplotlib Figure to modify in place.
        factor: The scaling factor (e.g., 1.5 to increase by 50%).
    """
    scaled_ids = set()
    for text in fig.findobj(match=Text):
        # Guard against an artist being reachable through more than one parent.
        if id(text) in scaled_ids:
            continue
        scaled_ids.add(id(text))
        text.set_fontsize(text.get_fontsize() * factor)


def _draw_token_panel(
    ax: Any,
    tokens: List[str],
    probabilities: np.ndarray,
    title: str,
    color: str,
    edgecolor: str,
    y_upper: float,
    empty_message: str,
) -> None:
    """Draw one panel of token-probability bars, or a placeholder when empty."""
    if len(tokens) == 0:
        # Axes-relative placement keeps the message centred whatever the y-range
        # is. The y-limits are left alone: the panels share them (sharey=True),
        # so changing them here would also rescale the neighbouring panel.
        ax.text(0.5, 0.5, empty_message,
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes)
        ax.set_xlim(0, 1)
        return

    bars = ax.bar(range(len(tokens)), probabilities, color=color,
                  edgecolor=edgecolor, alpha=0.7)
    ax.set_xlabel('Tokens', fontsize=12)
    ax.set_ylabel('Probability', fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha='right')

    # Clamp the value label's baseline so it stays below the top of the y-range.
    for bar, prob in zip(bars, probabilities):
        y = min(bar.get_height() + 0.02 * y_upper, y_upper * 0.98)
        ax.text(bar.get_x() + bar.get_width() / 2., y, f'{prob:.3f}',
                ha='center', va='bottom', fontsize=9)

    ax.grid(axis='y', alpha=0.3)


def create_dual_probability_plot(
    tokens1: List[str], probabilities1: np.ndarray,
    tokens2: List[str], probabilities2: np.ndarray
) -> Figure:
    """Create a side-by-side bar chart comparing first-token probabilities of two images.

    The figure is built with ``matplotlib.figure.Figure`` directly instead of
    pyplot. pyplot keeps every figure in a global registry until it is closed,
    which leaks one figure per request in a long-running server, and its state
    machine is not thread-safe.

    Args:
        tokens1: List of token strings from first image.
        probabilities1: Array of probability values from first image.
        tokens2: List of token strings from second image.
        probabilities2: Array of probability values from second image.

    Returns:
        Matplotlib Figure object (a single "No data" axes when both inputs are empty).
    """
    if len(tokens1) == 0 and len(tokens2) == 0:
        fig = Figure(figsize=(15, 8))
        ax = fig.subplots()
        ax.text(0.5, 0.5, 'No data to display',
                horizontalalignment='center', verticalalignment='center')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        return fig

    # Unify y-range with ~15% headroom, capped at 1.0 (probabilities never exceed it).
    max1 = float(np.max(probabilities1)) if len(tokens1) else 0.0
    max2 = float(np.max(probabilities2)) if len(tokens2) else 0.0
    y_upper = min(1.0, max(max1, max2) * 1.15 + 1e-6)

    fig = Figure(figsize=(20, 12))
    ax1, ax2 = fig.subplots(1, 2, sharey=True)
    # The y-axis is shared, so this one call sets the range for both panels.
    ax1.set_ylim(0, y_upper)

    _draw_token_panel(ax1, tokens1, probabilities1,
                      'Image 1 - First Token Probabilities',
                      'lightcoral', 'darkred', y_upper, 'No data for Image 1')
    _draw_token_panel(ax2, tokens2, probabilities2,
                      'Image 2 - First Token Probabilities',
                      'skyblue', 'navy', y_upper, 'No data for Image 2')

    # Give extra space for rotated tick labels
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.18)

    return fig


def _capture_module_output(
    vlm: ModelBase,
    text: str,
    image: Image.Image,
    store: Dict[str, torch.Tensor],
    label: str,
) -> torch.Tensor:
    """Run one no-grad forward pass on ``image`` and return the hook-captured activation."""
    inputs = vlm._generate_processor_output(text, image)
    for key in inputs:
        if isinstance(inputs[key], torch.Tensor):
            inputs[key] = inputs[key].to(vlm.config.device)

    store.clear()
    with torch.no_grad():
        vlm.model(**inputs)

    if 'activation' not in store:
        raise ValueError(f'Failed to extract features for {label}')
    return store['activation']


def _pool_activation(activation: torch.Tensor, pooling: str) -> torch.Tensor:
    """Reduce a captured activation to a 2-D ``[rows, features]`` float32 tensor."""
    # Upcast so the reductions and the cosine are computed in float32 even for fp16 models.
    emb = activation.float()

    if pooling == 'none':
        pooled = emb.reshape(emb.shape[0], -1)
    elif emb.dim() >= 3:
        # Reduce axis 1 (the sequence axis for a [batch, seq, hidden] output).
        if pooling == 'mean':
            pooled = emb.mean(dim=1)
        elif pooling == 'max':
            pooled = emb.amax(dim=1)
        else:  # 'cls': first position along axis 1
            pooled = emb[:, 0]
    elif emb.dim() == 2:
        # Only one image is processed per forward pass, so a 2-D output is
        # presumably [tokens, hidden] or [1, hidden] (TODO confirm per model).
        # Either way the feature axis is the last one and must be kept.
        # Reducing dim=1 would collapse each row to a scalar and make the
        # cosine degenerate to +/-1.
        if pooling == 'mean':
            pooled = emb.mean(dim=0, keepdim=True)
        elif pooling == 'max':
            pooled = emb.amax(dim=0, keepdim=True)
        else:  # 'cls': first row
            pooled = emb[:1]
    else:
        pooled = emb

    # cosine_similarity below compares along dim=1, so guarantee a 2-D tensor.
    if pooled.dim() == 1:
        pooled = pooled.unsqueeze(0)
    return pooled


def get_module_similarity_pooled(
        vlm: ModelBase,
        module_name: str,
        image1: Image.Image,
        image2: Image.Image,
        instruction: str,
        pooling: str = 'mean'
) -> float:
    """Compute the cosine similarity of one module's outputs for two images.

    A forward hook on ``module_name`` records the module's output during one
    forward pass per image (same prompt for both). The two activations are then
    pooled and compared with cosine similarity. If the module is called several
    times in a single forward pass, the last call's output is the one kept.

    Args:
        vlm: The loaded VLM (ModelBase instance).
        module_name: Name of the module as listed by ``vlm.model.named_modules()``.
        image1: First PIL Image.
        image2: Second PIL Image.
        instruction: Text instruction for the model.
        pooling: Pooling strategy - 'mean', 'max', 'cls', or 'none'.

    Returns:
        Cosine similarity between the pooled embeddings (averaged over rows
        when the pooled tensors have more than one row).

    Raises:
        ValueError: Unknown pooling strategy, module not found, or no activation
            captured during a forward pass.
    """
    # Validate before doing any (expensive) forward passes.
    if pooling not in ('mean', 'max', 'cls', 'none'):
        raise ValueError(f'Unknown pooling strategy: {pooling}')

    embeddings: Dict[str, torch.Tensor] = {}

    def hook_fn(
        module: torch.nn.Module,
        input: Any,
        output: Any
    ) -> None:
        """Forward hook that stores the module output (its first element if a tuple)."""
        captured = output[0] if isinstance(output, tuple) else output
        embeddings['activation'] = captured.detach()

    target_module = dict(vlm.model.named_modules()).get(module_name)
    if target_module is None:
        raise ValueError(f"Module '{module_name}' not found in model")

    hook_handle = target_module.register_forward_hook(hook_fn)
    try:
        vlm.model.eval()
        text = vlm._generate_prompt(instruction, has_images=True)
        embedding1 = _capture_module_output(vlm, text, image1, embeddings, 'image1')
        embedding2 = _capture_module_output(vlm, text, image2, embeddings, 'image2')

        pooled1 = _pool_activation(embedding1, pooling)
        pooled2 = _pool_activation(embedding2, pooling)

        similarity = F.cosine_similarity(pooled1, pooled2, dim=1)
        return float(similarity.mean().cpu().item())
    finally:
        # Always detach the hook so later forward passes are unaffected.
        hook_handle.remove()


@GPU(duration=120)
def process_dual_inputs(
    model_choice: str,
    selected_layer: str,
    instruction: str,
    image1: Optional[Image.Image],
    image2: Optional[Image.Image],
    top_k: int = 8
) -> Tuple[Optional[Figure], str]:
    """Run the first-token and layer-similarity analysis for up to two images.

    Args:
        model_choice: String name of the selected model (as shown in the UI).
        selected_layer: String name of the selected layer/module.
        instruction: Text instruction for the model.
        image1: First PIL Image to process, can be None.
        image2: Second PIL Image to process, can be None.
        top_k: Number of top tokens to display (coerced to ``int``).

    Returns:
        ``(figure, info_text)``. On a validation or processing failure the
        figure is ``None`` and ``info_text`` holds the message. If only the
        cosine-similarity step fails, the plot is still returned and the
        failure is reported inside ``info_text``.
    """
    if image1 is None and image2 is None:
        return None, 'Please upload at least one image.'

    if not instruction or not instruction.strip():
        return None, 'Please provide an instruction.'

    if not model_choice:
        return None, 'Please select a model.'

    if not selected_layer:
        return None, 'Please select a layer.'

    # UI sliders may deliver numbers as floats; torch.topk needs an int.
    top_k = int(top_k)

    # Phase 1: resolve and load the model. Only ValueErrors raised here count as
    # an invalid selection. Later ValueErrors (e.g. an unknown module during
    # feature extraction) must not be mislabelled that way.
    try:
        model_var = ModelVariants(model_choice.lower())
        model_selection, model_path, _ = get_model_info(model_var)
        config = Config(model_selection, model_path, selected_layer, instruction)
        config.model = {
            'torch_dtype': torch.float16,
            'low_cpu_mem_usage': True,
            'device_map': 'auto'
        }
        model = load_model(model_var, config)
    except ValueError as e:
        return None, f'Invalid model selection: {str(e)}'
    except Exception as e:
        return None, f'Error: {str(e)}'

    # Phase 2: first-token distributions for whichever images were provided.
    try:
        if image1 is None:
            tokens1, probs1 = [], np.array([])
        else:
            tokens1, probs1 = get_single_image_probabilities(
                instruction, image1, model, model_selection, top_k
            )

        if image2 is None:
            tokens2, probs2 = [], np.array([])
        else:
            tokens2, probs2 = get_single_image_probabilities(
                instruction, image2, model, model_selection, top_k
            )

        if len(tokens1) == 0 and len(tokens2) == 0:
            return None, 'Error: Could not process the inputs. Please check the model loading.'

        plot = create_dual_probability_plot(tokens1, probs1, tokens2, probs2)
        scale_figure_fonts(plot, factor=1.25)
    except Exception as e:
        return None, f'Error: {str(e)}'

    info_text = f'Model: {model_choice.upper()}\n'
    info_text += f'Top-K: {top_k}\n'
    info_text += f"Instruction: '{instruction}'\n\n"

    if len(tokens1) > 0:
        info_text += f"Image 1 - Top token: '{tokens1[0]}' (probability: {probs1[0]:.4f})\n"
    else:
        info_text += 'Image 1 - No data\n'

    if len(tokens2) > 0:
        info_text += f"Image 2 - Top token: '{tokens2[0]}' (probability: {probs2[0]:.4f})\n"
    else:
        info_text += 'Image 2 - No data\n'

    # Phase 3: layer similarity (needs both images). A failure here keeps the
    # already-computed plot and token info and only reports why similarity is missing.
    if len(tokens1) > 0 and len(tokens2) > 0:
        info_text += f'\nLayer: {selected_layer}\n'
        try:
            similarity = get_module_similarity_pooled(model, selected_layer, image1, image2, instruction)
        except Exception as e:
            info_text += f'Cosine similarity unavailable: {str(e)}\n'
        else:
            info_text += f'Cosine similarity between Image 1 and 2: {similarity:.3f}\n'

    return plot, info_text


def create_demo() -> gr.Blocks:
    """Create and configure the Gradio demo interface for dual image comparison.

    Layout: a left column of inputs (model, module, instruction, top-k slider,
    two image slots, Analyze button) and a right column of outputs (plot and a
    read-only info textbox). Components appear in the order they are created
    inside the ``gr.Blocks`` / ``gr.Row`` / ``gr.Column`` contexts.

    Returns:
        Configured Gradio Blocks interface.
    """
    with gr.Blocks(title='VLM-Lens Visualizer') as demo:
        gr.Markdown("""
        # VLM-Lens (EMNLP 2025 System Demonstration)

        ## [arXiv](https://arxiv.org/abs/2510.02292) | [GitHub](https://github.com/compling-wat/vlm-lens)

        This beta version processes an instruction with up to two images through various VLMs,
        computes cosine similarity between their embeddings at a specified layer,
        and visualizes the probability distribution of the first token in the response for each image.
                    
        **Instructions:**
        1. Select a VLM from the dropdown
        2. Select a layer from the available embedding layers
        3. Upload two images for comparison
        4. Enter your instruction/question about the images
        5. Adjust the number of top tokens to display (1-20)
        6. Click "Analyze" to see the first token probability distributions side by side

        **Note:** You can upload just one image if you prefer single image analysis.
        """)

        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                # Choices are capitalized enum values; handlers lower-case them back.
                model_dropdown = gr.Dropdown(
                    choices=[v.value.capitalize() for v in ModelVariants],
                    label='Select VLM',
                    value=None,
                    interactive=True
                )

                # Hidden until a model is chosen; populated by update_layer_choices.
                layer_dropdown = gr.Dropdown(
                    choices=[],
                    label='Select Module',
                    visible=False,
                    interactive=True
                )

                instruction_input = gr.Textbox(
                    label='Instruction',
                    placeholder='Describe what you see in this image...',
                    lines=3
                )

                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=8,
                    step=1,
                    label='Number of Top Tokens to Display',
                    info='Select how many top probability tokens to show in the visualization'
                )

                with gr.Row():
                    # type='pil' makes handlers receive PIL images (or None when empty).
                    image1_input = gr.Image(
                        label='Upload Image 1',
                        type='pil'
                    )
                    image2_input = gr.Image(
                        label='Upload Image 2',
                        type='pil'
                    )

                # Hidden until update_layer_choices succeeds for the chosen model.
                analyze_btn = gr.Button('Analyze', variant='primary', visible=False)

            # Right column: outputs.
            with gr.Column():
                plot_output = gr.Plot(label='First Token Probability Distribution Comparison')
                info_output = gr.Textbox(
                    label='Analysis Info',
                    lines=8,
                    interactive=False
                )

        # Set up event handlers
        # Changing the model refreshes the module list and Analyze visibility.
        model_dropdown.change(
            fn=update_layer_choices,
            inputs=[model_dropdown],
            outputs=[layer_dropdown, analyze_btn]
        )

        # Input order must match process_dual_inputs' positional parameters.
        analyze_btn.click(
            fn=process_dual_inputs,
            inputs=[model_dropdown, layer_dropdown, instruction_input, image1_input, image2_input, top_k_slider],
            outputs=[plot_output, info_output]
        )

        # Add examples
        # Example rows only prefill the instruction; both image slots are left empty.
        gr.Examples(
            examples=[
                ['What is in this image? Describe in one word.', None, None],
                ['Describe the main object in the picture in one word.', None, None],
                ['What color is the dominant object? Describe in one word.', None, None],
            ],
            inputs=[instruction_input, image1_input, image2_input]
        )
            
    return demo


if __name__ == '__main__':
    # Build the Blocks UI, then serve it on all network interfaces at port 7860
    # and request a public Gradio share link.
    launch_options = dict(share=True, server_name='0.0.0.0', server_port=7860)
    demo = create_demo()
    demo.launch(**launch_options)