Update app.py
Browse files
app.py
CHANGED
|
@@ -6,17 +6,17 @@ from torchvision.transforms.functional import InterpolationMode
|
|
| 6 |
from transformers import AutoModel, AutoTokenizer
|
| 7 |
import gradio as gr
|
| 8 |
|
|
|
|
| 9 |
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 10 |
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 11 |
|
| 12 |
# Build the image transform
|
| 13 |
def build_transform(input_size):
|
| 14 |
-
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
| 15 |
transform = T.Compose([
|
| 16 |
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
| 17 |
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
| 18 |
T.ToTensor(),
|
| 19 |
-
T.Normalize(mean=
|
| 20 |
])
|
| 21 |
return transform
|
| 22 |
|
|
@@ -60,10 +60,9 @@ path = 'OpenGVLab/InternVL2_5-78B'
|
|
| 60 |
model = AutoModel.from_pretrained(
|
| 61 |
path,
|
| 62 |
torch_dtype=torch.bfloat16,
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
).eval().cuda()
|
| 67 |
|
| 68 |
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
|
| 69 |
|
|
|
|
| 6 |
from transformers import AutoModel, AutoTokenizer
|
| 7 |
import gradio as gr
|
| 8 |
|
| 9 |
+
# Constants for ImageNet preprocessing
|
| 10 |
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
| 11 |
IMAGENET_STD = (0.229, 0.224, 0.225)
|
| 12 |
|
| 13 |
# Build the image transform
|
| 14 |
def build_transform(input_size):
|
|
|
|
| 15 |
transform = T.Compose([
|
| 16 |
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
| 17 |
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
| 18 |
T.ToTensor(),
|
| 19 |
+
T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
|
| 20 |
])
|
| 21 |
return transform
|
| 22 |
|
|
|
|
| 60 |
model = AutoModel.from_pretrained(
|
| 61 |
path,
|
| 62 |
torch_dtype=torch.bfloat16,
|
| 63 |
+
trust_remote_code=True,
|
| 64 |
+
device_map="auto" # Use device map for efficient memory handling
|
| 65 |
+
)
|
|
|
|
| 66 |
|
| 67 |
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
|
| 68 |
|