import nltk
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
import torch

import src.exception.Exception as ExceptionCustom

# Use a pipeline as a high-level helper
from transformers import pipeline

METHOD = "TRANSLATE"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def paraphraseTranslateMethod(requestValue: str, model: str):
    # Make sure the sentence-tokenizer data is available.
    nltk.download('punkt')
    nltk.download('punkt_tab')

    exception = ExceptionCustom.checkForException(requestValue, METHOD)
    if exception:
        return "", exception

    tokenized_sent_list = sent_tokenize(requestValue)

    # Load the requested translation model once, not once per sentence.
    if model == 'roen':
        model_name = "BlackKakapo/opus-mt-ro-en"   # Romanian -> English
    else:
        model_name = "BlackKakapo/opus-mt-en-ro"   # English -> Romanian
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    translation_model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

    result_value = []

    for sentence in tokenized_sent_list:
        encoding = tokenizer(sentence, return_tensors='pt').to(device)

        output = translation_model.generate(
            input_ids=encoding.input_ids,
            do_sample=True,
            max_length=512,
            top_k=90,
            top_p=0.97,
            early_stopping=False
        )

        result = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
        result_value.append(result)

    return " ".join(result_value).strip(), model
def gemma(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
    requestValue = requestValue.replace('\n', ' ')

    messages = [{"role": "user", "content": f"Translate this text to Romanian: {requestValue}"}]

    # Fall back to the default model if no full repo id was given.
    if '/' not in model:
        model = 'Gargaz/gemma-2b-romanian-better'

    # Limit max_new_tokens to roughly 150% of the input length
    # (character count used as a cheap proxy for token count).
    max_new_tokens = int(len(requestValue) + len(requestValue) * 0.5)

    try:
        pipe = pipeline(
            "text-generation",
            model=model,
            device=-1,                      # CPU
            max_new_tokens=max_new_tokens,  # keep short to reduce verbosity
            do_sample=False                 # greedy decoding for determinism
        )

        output = pipe(messages, num_return_sequences=1, return_full_text=False)
        generated_text = output[0]["generated_text"]

        # Keep only the first line of the reply.
        result = generated_text.split('\n', 1)[0]
        return result.strip()
    except Exception as error:
        # Return the error message as text so callers always get a string.
        return str(error)
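
# Illustrative usage sketch (assumption, not part of the original module):
#
#   romanian = gemma("The meeting is scheduled for Monday at 10 AM.")
#   print(romanian)   # a single-line Romanian translation, or an error message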
def gemma_direct(requestValue: str, model: str = 'Gargaz/gemma-2b-romanian-better'):
    # Load the model directly instead of going through the pipeline helper.
    model_name = model if '/' in model else 'Gargaz/gemma-2b-romanian-better'

    prompt = f"Translate this text to Romanian: {requestValue}"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    causal_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

    # Limit max_new_tokens to roughly 150% of the input length, measured in tokens.
    input_ids = tokenizer.encode(requestValue, add_special_tokens=True)
    num_tokens = len(input_ids)
    max_new_tokens = int(num_tokens * 1.5)
    max_new_tokens += max_new_tokens % 2  # round up to an even number

    messages = [{"role": "user", "content": prompt}]

    try:
        inputs = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device)

        outputs = causal_model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Decode only the newly generated tokens, then keep the first line.
        response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
        result = response.split('\n', 1)[0]
        return result.strip()
    except Exception as error:
        # Return the error message as text so callers always get a string.
        return str(error)
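
# Minimal smoke-test sketch, assuming the file is run directly and the models
# above are available locally or downloadable. The sample sentence is an
# assumption for demonstration; nothing here is required by the functions.
if __name__ == "__main__":
    sample = "The weather in Bucharest is sunny today."
    print(gemma_direct(sample))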