# tinyllava/model.py
# minimal loader that uses transformers to load a multimodal model if available.
# This is a thin adapter: it expects model checkpoints on HF that are compatible with transformers.auto.modeling.
# For TinyLLaVA upstream functionality, replace with full repo.
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer


def load_pretrained_model(model_path: str, model_base=None, model_name: str = None):
    """
    Minimal loader:
      - tokenizer: AutoTokenizer.from_pretrained(model_path)
      - model: AutoModelForCausalLM.from_pretrained(model_path), sharded with
        device_map="auto" when CUDA is available, otherwise kept on CPU
      - image_processor: AutoProcessor.from_pretrained(model_path), falling back to a
        known vision processor (BLIP) if the checkpoint ships no processor config
    Returns: tokenizer, model, image_processor, context_len
    """
    if model_name is None:
        model_name = model_path.split("/")[-1]

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)

    # Try to load an image processor / processor; fall back to a BLIP processor if available
    try:
        image_processor = AutoProcessor.from_pretrained(model_path)
    except Exception:
        # Fallback: try a common image processor (BLIP)
        try:
            image_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
        except Exception:
            image_processor = None

    # Load causal LM; place it on GPU(s) automatically when CUDA is present
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto" if torch.cuda.is_available() else None,
    )

    # Context length: use the tokenizer's model_max_length when it is a sane value;
    # fast tokenizers report a very large sentinel when the limit is unknown
    context_len = getattr(tokenizer, "model_max_length", 2048)
    if context_len is None or context_len > 10**6:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
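

# Minimal usage sketch (not part of the loader itself): the checkpoint id below is a
# placeholder assumption; substitute any transformers-compatible causal LM checkpoint
# path or Hugging Face repo id.
if __name__ == "__main__":
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        "path/or/hf-repo-id"  # placeholder: replace with a real checkpoint
    )
    print(f"Loaded {model.__class__.__name__} with context length {context_len}")
    print(f"Image processor available: {image_processor is not None}")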