import csv
import enum
import json
import logging
import os
import re
import string
import sys
import unicodedata
from typing import Any, Dict, List, NewType, Union

import numpy as np
import openai
import pandas as pd
import requests
import yaml
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from tqdm import tqdm
from yaml.loader import SafeLoader

from tasks import qa, summarization, ner, nli

# from models.model_completion import gpt3x_completion, gemini_completion

class LanguageType(enum.Enum):
    Low = "Low"
    High = "High"


class ModelType(enum.Enum):
    English = "English"
    Multilingual = "Multilingual"

def get_entities_gpt3_long(prompt):
    response = openai.ChatCompletion.create(
        engine="chatgpt", temperature=0, messages=[{"role": "user", "content": prompt}]
    )
    return response["choices"][0]["message"]["content"]

def gpt3x_completion(
    prompt: Union[str, List[Dict[str, str]]],
) -> str:
    # Placeholder: supply your OpenAI key here or via the environment
    os.environ["OPENAI_API_KEY"] = ""

    def get_entities_chatGPT(final_prompt):
        response = openai.ChatCompletion.create(
            engine="gpt35-16k",
            temperature=0,
            messages=[{"role": "user", "content": final_prompt}],
        )
        return response["choices"][0]["message"]["content"]

    return get_entities_chatGPT(final_prompt=prompt)

def mixtral_completion(prompt):
    url = "https://api.together.xyz/v1/chat/completions"

    # Together API key (placeholder: replace with your actual key)
    together_api_key = ""

    # Request payload
    payload = {
        "temperature": 0,
        "max_tokens": 30,
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "messages": [{"role": "user", "content": f"{prompt}"}],
    }

    # Request headers
    headers = {
        "Authorization": f"Bearer {together_api_key}",
        "Content-Type": "application/json",
    }

    response = requests.post(url, json=payload, headers=headers)

    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]
    else:
        # Report the failure and return None explicitly
        print(f"Error: {response.status_code} - {response.text}")
        return None

XQUAD_LANG2CODES = {
    "bengali": "bn",
    "korean": "ko",
    "swahili": "sw",
    "english": "en",
    "indonesian": "id",
    "arabic": "ar",
    "finnish": "fi",
    "telugu": "te",
    "russian": "ru",
    "german": "de",
    "greek": "el",
    "hindi": "hi",
    "vietnamese": "vi",
    "romanian": "ro",
}

INDICQA_LANG2CODES = {
    "indicqa": "as",
    "bengali": "bn",
    "gujarati": "gu",
    "hindi": "hi",
    "kannada": "kn",
    "malayalam": "ml",
    "marathi": "mr",
    "odia": "or",
    "punjabi": "pa",
    "tamil": "ta",
    "telugu": "te",
    "assamese": "as",
}

PUNCT = {
    chr(i)
    for i in range(sys.maxunicode)
    if unicodedata.category(chr(i)).startswith("P")
}.union(string.punctuation)

WHITESPACE_LANGS = ["en", "es", "hi", "vi", "de", "ar"]
MIXED_SEGMENTATION_LANGS = ["zh"]

TYDIQA_LANG2CODES = {
    "bengali": "bn",
    "korean": "ko",
    "swahili": "sw",
    "english": "en",
    "indonesian": "id",
    "arabic": "ar",
    "finnish": "fi",
    "telugu": "te",
    "russian": "ru",
    "assamese": "as",
    "persian": "fa",
}

logger = logging.getLogger("Xlsum_task")

LANGUAGE_TO_SUFFIX = {
    "chinese_simplified": "zh-CN",
    "french": "fr",
    "portuguese": "pt",
    "english": "en",
    "arabic": "ar",
    "hindi": "hi",
    "indonesian": "id",
    "amharic": "am",
    "bengali": "bn",
    "telugu": "te",
    "burmese": "my",
    "german": "de",
    "greek": "el",
    "tamil": "ta",
    "assamese": "as",
    "vietnamese": "vi",
    "russian": "ru",
    "romanian": "ro",
    "malayalam": "ml",
    "persian": "fa",
}

PARAMS = NewType("PARAMS", Dict[str, Any])


def read_parameters(args_path) -> PARAMS:
    with open(args_path) as f:
        args = yaml.load(f, Loader=SafeLoader)
    return args

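# Illustrative parameters file (the keys are inferred from their usage in
# process_test_example below; the exact schema is an assumption):
#
#     # params.yaml
#     selected_language: hindi
#     model: gpt35-16k
#     response_logger_root: ./responses
#
#     params = read_parameters("params.yaml")
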
def load_qa_dataset(dataset_name, lang, split, translate_test=False, limit=5):
    if dataset_name == "indicqa":
        if split != "train":
            dataset = load_dataset(
                "ai4bharat/IndicQA", f"indicqa.{INDICQA_LANG2CODES[lang]}"
            )[split]
        else:
            dataset = load_dataset("squad_v2")[split]
    elif dataset_name == "xquad":
        if split != "train":
            dataset = load_dataset("xquad", f"xquad.{XQUAD_LANG2CODES[lang]}")[
                "validation"
            ]
        else:
            dataset = load_dataset("squad")[split]
    elif dataset_name == "tydiqa":
        dataset = load_dataset("tydiqa", "secondary_task")[split]
        dataset = dataset.map(
            lambda example: {"lang": TYDIQA_LANG2CODES[example["id"].split("-")[0]]}
        )
        dataset = dataset.filter(lambda example: example["lang"] == lang)
    elif dataset_name == "mlqa":
        if split == "train":
            print("No Training Data for MLQA, switching to validation!")
            split = "validation"
        if translate_test:
            dataset_name = f"mlqa-translate-test.{lang}"
        else:
            dataset_name = f"mlqa.{lang}.{lang}"
        dataset = load_dataset("mlqa", dataset_name)[split]
    else:
        raise NotImplementedError()
    return dataset.select(np.arange(limit))

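# Illustrative usage: the first five XQuAD Hindi validation examples.
#     test_data = load_qa_dataset("xquad", lang="hindi", split="validation", limit=5)
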
def construct_prompt(
    instruction: str,
    test_example: dict,
    ic_examples: List[dict],
    zero_shot: bool,
    lang: str,
    config: Dict[Any, Any],
):
    example_prompt = PromptTemplate(
        input_variables=["context", "question", "answers"],
        template="Context: {context}\nQuestion: {question}\nAnswers: {answers}",
    )
    zero_shot_template = (
        instruction + "\n<Context>: {context} \n<Question>: {question} "
    )
    prompt = (
        FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="<Context>: {context} \n<Question>: {question} \nAnswers: ?",
            input_variables=["question", "context"],
        )
        if not zero_shot
        else PromptTemplate(
            input_variables=["question", "context"], template=zero_shot_template
        )
    )
    label = test_example["answers"]
    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
    return (
        prompt.format(
            question=test_example["question"], context=test_example["context"]
        ),
        label,
    )

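# Illustrative zero-shot call (the example dict is hypothetical; config["input"]
# names the language the prompt should be posed in):
#     prompt, label = construct_prompt(
#         instruction="Answer the question.",
#         test_example={"question": "Q?", "context": "C.", "answers": ["A"]},
#         ic_examples=[],
#         zero_shot=True,
#         lang="english",
#         config={"input": "english"},
#     )
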
def dump_metrics(
    lang: str, config: Dict[str, str], f1: float, em: float, metric_logger_path: str
):
    # Write a header row only if the metrics file is newly created
    file_exists = os.path.exists(metric_logger_path)
    with open(metric_logger_path, "a", newline="") as f:
        csvwriter = csv.writer(f, delimiter=",")
        if not file_exists:
            header = ["Language", "Prefix", "Input", "Context", "Output", "F1", "Em"]
            csvwriter.writerow(header)
        csvwriter.writerow(
            [
                lang,
                config["prefix"],
                config["input"],
                config["context"][0],
                config["output"],
                f1,
                em,
            ]
        )

def dump_predictions(idx, response, label, response_logger_file):
    obj = {"q_idx": idx, "prediction": response, "label": label}
    with open(response_logger_file, "a") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")

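# Each call appends one JSON line to the log file, e.g. (values illustrative):
#     {"q_idx": 3, "prediction": "New Delhi", "label": ["New Delhi"]}
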
def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    translator = EasyGoogleTranslate(
        source_language="en",
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=50,
    )
    return translator.translate(basic_instruction)

def _translate_prediction_to_output_language(
    prediction: str, prediction_language: str, output_language: str
) -> str:
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[prediction_language],
        target_language=LANGUAGE_TO_SUFFIX[output_language],
        timeout=10,
    )
    return translator.translate(prediction)

def create_instruction(lang: str, expected_output: str):
    basic_instruction = (
        "Answer the <Question> below based only on the given <Context>. "
        "Follow these instructions:\n"
        "1. The answer should include only words from the given context\n"
        "2. The answer must include up to 5 words\n"
        "3. The answer should be as short as possible\n"
        f"4. The answer must be in {expected_output} only, not in another language!"
    )
    return (
        basic_instruction
        if lang == "english"
        else _translate_instruction(basic_instruction, target_language=lang)
    )

def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
):
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[str(src_language).lower()],
        target_language=LANGUAGE_TO_SUFFIX[str(target_language).lower()],
        timeout=30,
    )
    # The context is translated in 2000-character chunks; only the first 6000
    # characters are kept.
    return {
        "question": translator.translate(example["question"]),
        "context": translator.translate(example["context"][:2000])
        + translator.translate(example["context"][2000:4000])
        + translator.translate(example["context"][4000:6000]),
        "answers": translator.translate(example["answers"][0]),
    }

def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Selects few-shot examples from a training dataset.

    Args:
        train_dataset (Dataset): Training dataset.
        few_shot_size (int): Number of few-shot examples.
        context (List[str]): Target language of each in-context example.
        selection_criteria (str): How to select few-shot examples. Choices: [random, first_k].
        lang (str): Source language of the dataset.

    Returns:
        List[Dict[str, Union[str, int]]]: Selected examples.
    """
    selected_examples = []
    example_idxs = []
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )
    ic_examples = [
        {
            "question": train_dataset[idx]["question"],
            "context": train_dataset[idx]["context"],
            "answers": train_dataset[idx]["answers"]["text"],
        }
        for idx in example_idxs
    ]
    # Keep each example in the source language, or translate it into the
    # language requested for that in-context slot
    for idx, ic_language in enumerate(context):
        if ic_language == lang:
            selected_examples.append(ic_examples[idx])
        else:
            selected_examples.append(
                _translate_example(
                    example=ic_examples[idx],
                    src_language=lang,
                    target_language=ic_language,
                )
            )
    return selected_examples

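# Illustrative usage (test_data from load_qa_dataset above): two in-context
# examples, the second translated from Hindi to English.
#     examples = choose_few_shot_examples(
#         train_dataset=test_data,
#         few_shot_size=2,
#         context=["hindi", "english"],
#         selection_criteria="first_k",
#         lang="hindi",
#     )
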
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(PUNCT)  # set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

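# Example behaviour (illustrative):
#     normalize_answer("The  Taj Mahal!")  ->  "taj mahal"
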
def process_test_example(
    test_data, config_header, idx, test_example, config, zero_shot, lang, params
):
    try:
        instruction = create_instruction(
            lang=config["prefix"], expected_output=config["output"]
        )
        text_example = {
            "question": test_example["question"],
            "context": test_example["context"],
            "answers": test_example["answers"]["text"],
        }
        ic_examples = []
        if not zero_shot:
            ic_examples = choose_few_shot_examples(
                train_dataset=test_data,
                few_shot_size=len(config["context"]),
                context=config["context"],
                selection_criteria="random",
                lang=params["selected_language"],
            )
        prompt, label = construct_prompt(
            instruction=instruction,
            test_example=text_example,
            ic_examples=ic_examples,
            zero_shot=zero_shot,
            lang=lang,
            config=config,
        )
        pred = gpt3x_completion(prompt=prompt)
        print(pred)
        logger.info("Saving prediction to persistent volume")
        os.makedirs(
            f"{params['response_logger_root']}/{params['model']}/{lang}", exist_ok=True
        )
        dump_predictions(
            idx=idx,
            response=pred,
            label=label,
            response_logger_file=f"{params['response_logger_root']}/{params['model']}/{lang}/{config_header}.csv",
        )
    except Exception as e:
        print(f"Error processing example {idx}: {e}")

def run_one_configuration(selected_language, config, zero_shot, dataset_name, limit=10):
    test_data = load_qa_dataset(
        dataset_name=dataset_name,
        lang=selected_language,
        split="validation" if dataset_name == "xquad" else "test",
        limit=limit,
    )
    for idx, test_example in enumerate(tqdm(test_data)):
        try:
            instruction = create_instruction(
                lang=config["prefix"], expected_output=config["output"]
            )
            text_example = {
                "question": test_example["question"],
                "context": test_example["context"],
                "answers": test_example["answers"]["text"],
            }
            ic_examples = []
            if not zero_shot:
                ic_examples = choose_few_shot_examples(
                    train_dataset=test_data,
                    few_shot_size=len(config["context"]),
                    context=config["context"],
                    selection_criteria="random",
                    lang=selected_language,
                )
            prompt, label = construct_prompt(
                instruction=instruction,
                test_example=text_example,
                ic_examples=ic_examples,
                zero_shot=zero_shot,
                lang=selected_language,
                config=config,
            )
            pred = gpt3x_completion(prompt=prompt)
            # Returns the first prediction that succeeds; examples that raise
            # are skipped via the except branch below
            return pred
        except Exception as e:
            print(f"Found an exception {e}, continue to the next example")
            continue

| QA = "QA" | |
| SUMMARIZATION = "Summarization" | |
| NLI = "NLI" | |
| NER = "NER" | |
def construct_generic_prompt(
    task, instruction, test_example, zero_shot, num_examples, selected_language, dataset, config
):
    if task == SUMMARIZATION:
        prompt = summarization.construct_prompt(
            instruction=instruction,
            test_example=test_example,
            zero_shot=zero_shot,
            dataset=dataset,
            num_examples=num_examples,
            lang=str(selected_language).lower(),
            config=config,
        )
    elif task == NER:
        prompt = ner.construct_prompt(
            instruction=instruction,
            test_example=test_example,
            zero_shot=zero_shot,
            num_examples=num_examples,
            lang=str(selected_language).lower(),
            config=config,
        )
    elif task == QA:
        prompt = qa.construct_prompt(
            instruction=instruction,
            test_example=test_example,
            zero_shot=zero_shot,
            num_examples=num_examples,
            lang=str(selected_language).lower(),
            config=config,
            # dataset_name=dataset
        )
    else:
        prompt = nli.construct_prompt(
            instruction=instruction,
            test_example=test_example,
            zero_shot=zero_shot,
            num_examples=num_examples,
            lang=str(selected_language).lower(),
            config=config,
        )
    return prompt

def _get_language_type(language: str):
    # A language counts as low-resource if its word count in the reference CSV
    # falls below the fixed threshold
    df = pd.read_csv("utils/languages_by_word_count.csv")
    number_of_words = df[df["Language"] == language]["number of words"].iloc[0]
    return LanguageType.Low if number_of_words < 150276400 else LanguageType.High

class Config:
    def __init__(self, prefix="source", context="source", examples="source", output="source"):
        self.prefix = prefix
        self.context = context
        self.examples = examples
        self.output = output

    def set(self, prefix=None, context=None, examples=None, output=None):
        if prefix:
            self.prefix = prefix
        if context:
            self.context = context
        if examples:
            self.examples = examples
        if output:
            self.output = output

    def to_dict(self):
        return {
            "prefix": self.prefix,
            "context": self.context,
            "examples": self.examples,
            "output": self.output,
        }

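# Illustrative usage:
#     cfg = Config()
#     cfg.set(prefix="english", output="english")
#     cfg.to_dict()
#     # -> {'prefix': 'english', 'context': 'source',
#     #     'examples': 'source', 'output': 'english'}
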
def recommend_config(task, lang, model_type):
    language_type = _get_language_type(lang)
    config = Config()
    if task == QA:
        if model_type == ModelType.English.value:
            config.set(prefix="source", context="source", examples="source", output="source")
        else:
            config.set(prefix="english", context="source", examples="source", output="source")
    if task == NER:
        if model_type == ModelType.English.value:
            config.set(prefix="source", context="source", examples="source", output="source")
        elif language_type == LanguageType.High:
            config.set(prefix="english", context="source", examples="source", output="source")
        else:
            config.set(prefix="english", context="source", examples="source", output="english")
    if task == NLI:
        if model_type == ModelType.English.value:
            config.set(prefix="source", context="source", examples="source", output="source")
        elif language_type == LanguageType.High:
            config.set(prefix="english", context="source", examples="english")
        else:
            config.set(prefix="english", context="english", examples="english")
    if task == SUMMARIZATION:
        config.set(context="english")
    return config.to_dict()
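
# Minimal driver sketch (the config dict is hypothetical; run_one_configuration
# and construct_prompt together read the keys "prefix", "input", "context", and
# "output" — note that recommend_config itself does not emit an "input" key):
#     config = {
#         "prefix": "english",
#         "input": "hindi",
#         "context": ["hindi"],
#         "output": "hindi",
#     }
#     pred = run_one_configuration(
#         selected_language="hindi",
#         config=config,
#         zero_shot=False,
#         dataset_name="xquad",
#         limit=5,
#     )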