Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| import json | |
| import os, shutil | |
| import random | |
| from PIL import Image | |
| import jax | |
| from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer | |
| from huggingface_hub import hf_hub_download | |
# Local directory that receives a flat copy of the checkpoint files.
model_dir = './models/'
os.makedirs(model_dir, exist_ok=True)

# Names of the files copied from the checkpoint repository: model
# config/weights plus the tokenizer and preprocessor configuration.
files_to_download = [
    "config.json",
    "flax_model.msgpack",
    "merges.txt",
    "special_tokens_map.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "vocab.json",
    "preprocessor_config.json",
]

# Fetch each file through the Hugging Face Hub client, then copy the
# returned local path into model_dir under its bare file name.
for fn in files_to_download:
    file_path = hf_hub_download(
        "ydshieh/vit-gpt2-coco-en-ckpts",
        f"ckpt_epoch_3_step_6900/{fn}",
    )
    shutil.copyfile(file_path, os.path.join(model_dir, fn))
# Load the captioning model and its pre/post-processing components from the
# local directory populated above.
model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)

# Generation settings forwarded to model.generate() on every call.
max_length = 16
num_beams = 4
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)
def generate(pixel_values):
    """Generate token ids for the given pixel values.

    Calls ``model.generate`` with the module-level ``gen_kwargs``
    (``max_length`` and ``num_beams``) and returns the ``sequences``
    attribute of its result.
    """
    generation = model.generate(pixel_values, **gen_kwargs)
    return generation.sequences
def predict(image):
    """Return the caption generated for a single image.

    Args:
        image: An image object exposing ``mode`` and ``convert`` (a PIL
            image, as passed by ``_compile``). Non-RGB images are converted
            to RGB before feature extraction.

    Returns:
        The first decoded caption, with surrounding whitespace removed.
    """
    rgb_image = image if image.mode == "RGB" else image.convert(mode="RGB")

    encoded = feature_extractor(images=rgb_image, return_tensors="np")
    token_ids = generate(encoded.pixel_values)

    # Special tokens are dropped while decoding, so only caption text remains.
    captions = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return captions[0].strip()
def _compile():
    """Run one prediction on a bundled sample image and discard the result.

    NOTE(review): judging by the name, this presumably warms up / compiles
    the generation path before the first real request -- confirm.

    The image is opened in a ``with`` block so the underlying file is closed
    even when ``predict`` raises (the previous explicit ``close()`` was
    skipped on error).
    """
    image_path = 'samples/val_000000039769.jpg'
    with Image.open(image_path) as image:
        predict(image)


# Executed once at import time.
_compile()
sample_dir = './samples/'

# Ids of the bundled sample images, parsed from file names of the form
# "COCO_val2017_<id>.jpg". The leading "None" string is kept as the first
# entry -- presumably a placeholder choice for a selector; verify against
# the caller.
sample_image_ids = (
    "None",
    *(
        int(name.replace('COCO_val2017_', '').replace('.jpg', ''))
        for name in os.listdir(sample_dir)
        if name.startswith('COCO_val2017_')
    ),
)

# Image ids loaded from the JSON file shipped alongside the samples.
ids_json_path = os.path.join(sample_dir, "coco-val2017-img-ids.json")
with open(ids_json_path, "r", encoding="UTF-8") as fp:
    coco_2017_val_image_ids = json.load(fp)
def get_random_image_id():
    """Return one image id drawn at random from ``coco_2017_val_image_ids``."""
    # Unpack the single-element sample instead of indexing into it.
    (picked_id,) = random.sample(coco_2017_val_image_ids, k=1)
    return picked_id
