Spaces: Runtime error
```python
# Reworked so it runs on Hugging Face Spaces
import torch
from transformers import BertTokenizerFast, Trainer, TrainingArguments
from datasets import load_dataset
from collections import defaultdict

# Load the dataset
dataset_load = load_dataset('Multimodal-Fatima/OK-VQA_train')
dataset = dataset_load['train'].select(range(300))

# Keep only the features we need
selected_features = ['image', 'answers', 'question']
selected_dataset = dataset.map(lambda ex: {feature: ex[feature] for feature in selected_features})
# Label-encode the answer strings as integer IDs
answers_to_id = defaultdict(lambda: len(answers_to_id))
selected_dataset = selected_dataset.map(lambda ex: {
    'answers': [answers_to_id[ans] for ans in ex['answers']],
    'question': ex['question'],
    'image': ex['image']
})

# Reverse mapping from ID back to the answer string
id_to_answers = {v: k for k, v in answers_to_id.items()}
# Row index -> encoded answer list for that row
id_to_labels = {k: ex['answers'] for k, ex in enumerate(selected_dataset)}
selected_dataset = selected_dataset.map(lambda ex: {'answers': id_to_labels.get(ex['answers'][0]),
                                                    'question': ex['question'],
                                                    'image': ex['image']})

# Flatten the examples into plain dicts
flattened_features = []
for ex in selected_dataset:
    flattened_features.append({
        'answers': ex['answers'],
        'question': ex['question'],
        'image': ex['image'],
    })
```
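As an aside, the `defaultdict(lambda: len(answers_to_id))` idiom hands out a fresh consecutive ID the first time an answer string is looked up. A minimal standalone illustration:

```python
from collections import defaultdict

# Each key gets the next consecutive integer the first time it is seen
ids = defaultdict(lambda: len(ids))
print(ids['cat'], ids['dog'], ids['cat'])  # -> 0 1 0
```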
```python
# Load the model
from transformers import AutoModelForSequenceClassification

model_name = 'microsoft/git-base-vqav2'
model = AutoModelForSequenceClassification.from_pretrained(model_name)
```
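Worth noting: `microsoft/git-base-vqav2` is a GIT checkpoint fine-tuned for *generative* VQA, so wrapping it in `AutoModelForSequenceClassification` grafts a classification head onto an architecture that was never trained for one, and is a plausible source of the Space's runtime error. Here is a minimal sketch of the checkpoint's intended, generation-based usage, following the pattern documented for GIT (the blank image is a stand-in for a real one):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

processor = AutoProcessor.from_pretrained('microsoft/git-base-vqav2')
git_model = AutoModelForCausalLM.from_pretrained('microsoft/git-base-vqav2')

image = Image.new('RGB', (224, 224))  # stand-in; use a real PIL image here
question = 'What is shown in the image?'

# GIT generates the answer token by token, conditioned on image + question
pixel_values = processor(images=image, return_tensors='pt').pixel_values
input_ids = processor(text=question, add_special_tokens=False).input_ids
input_ids = torch.tensor([processor.tokenizer.cls_token_id] + input_ids).unsqueeze(0)

generated_ids = git_model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```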
```python
# Tokenize and train the model with Trainer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

def preprocess_function(examples):
    tokenized_inputs = tokenizer(examples['question'], truncation=True, padding=True)
    return {
        'input_ids': tokenized_inputs['input_ids'],
        'attention_mask': tokenized_inputs['attention_mask'],
        # NOTE: placeholder shape tuples, not real image tensors; see the sketch below
        'pixel_values': [(4, 3, 224, 224)] * len(tokenized_inputs['input_ids']),
        'pixel_mask': [1] * len(tokenized_inputs['input_ids']),
        # Use the first encoded answer of each example as its classification label
        'labels': [answers[0] for answers in examples['answers']]
    }

# Map over the encoded dataset from above, not a fresh copy of the raw data
ok_vqa_dataset = selected_dataset.map(preprocess_function, batched=True)
ok_vqa_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])
```
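If the goal is to feed real images to a vision-language checkpoint, `pixel_values` has to come from actual image preprocessing rather than hard-coded shape tuples. A sketch using the checkpoint's own image processor, assuming the dataset rows carry PIL images in the `image` column:

```python
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained('microsoft/git-base-vqav2')

def preprocess_images(examples):
    # Resize and normalize the PIL images into model-ready float tensors
    processed = image_processor(images=examples['image'], return_tensors='pt')
    return {'pixel_values': processed['pixel_values']}

# e.g. ok_vqa_dataset = ok_vqa_dataset.map(preprocess_images, batched=True)
```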
```python
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=20,
    per_device_train_batch_size=4,
    logging_steps=500,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ok_vqa_dataset
)

# Train the model
trainer.train()
```
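One more gap: the inference code below reloads pretrained weights from the Hub, not the weights just fine-tuned here. If the fine-tuned model is wanted in the Gradio app, save it explicitly and load from that path instead (the path is arbitrary):

```python
# Persist the fine-tuned weights and tokenizer for the inference app
trainer.save_model('./results/final')
tokenizer.save_pretrained('./results/final')
```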
```python
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, BertTokenizerFast

# Initialize the model and load the weights
model_name = 'microsoft/git-base-vqav2'  # name of the checkpoint to use
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# The checkpoint ships no BERT tokenizer, so reuse the tokenizer used during training
tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')

# Define the prediction function
def predict_answer(image, question):
    inputs = tokenizer(question, return_tensors='pt')
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Image handling goes here: the input image still needs to be
    # preprocessed into pixel_values before a vision-language model can use it

    # Run the model on the inputs to get a prediction
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Take the label ID with the highest probability
    predicted_label_id = torch.argmax(outputs.logits).item()
    # Map the ID back to an answer string via the mapping built during training
    predicted_label = id_to_answers.get(predicted_label_id, str(predicted_label_id))
    return predicted_label

iface = gr.Interface(
    fn=predict_answer,
    inputs=["image", "text"],
    outputs="text",
    title="Visual Question Answering",
    description="Input an image and a question to get the model's answer.",
    examples=[
        ["https://your_image_url.jpg", "What is shown in the image?"]
    ]
)

iface.launch()
```
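Finally, note that Gradio's `"image"` input arrives as a numpy array by default (at least in recent Gradio versions), so `predict_answer` needs a conversion step before a transformers image processor can consume it. A small sketch:

```python
import numpy as np
from PIL import Image

def to_pil(image):
    # Gradio passes uploaded images as numpy arrays by default
    if isinstance(image, np.ndarray):
        return Image.fromarray(image)
    return image
```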