# kooktaeeee / utils.py
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForCausalLM
import onnxruntime
from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, MAX_INPUT_LENGTH
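
# config.py is not included in this file; a minimal sketch of what it is
# expected to provide (the values below are illustrative assumptions, not the
# real configuration):
#
#     MODEL_NAME = "distilgpt2"   # any causal LM id on the Hub
#     MAX_NEW_TOKENS = 128        # completion length cap
#     MAX_INPUT_LENGTH = 512      # prompt truncation length
#     TEMPERATURE = 0.7           # unused under greedy decoding (see below)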

# optimum.onnxruntime has no ORTOptions class; runtime tuning is done through
# onnxruntime.SessionOptions and passed to from_pretrained() below.
session_options = onnxruntime.SessionOptions()
session_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL
session_options.enable_cpu_mem_arena = True
session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

model = ORTModelForCausalLM.from_pretrained(
    MODEL_NAME,
    export=True,                      # export the PyTorch checkpoint to ONNX on load
    session_options=session_options,
    provider="CPUExecutionProvider",  # ORT models take a provider rather than device_map
    trust_remote_code=True,
)
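
# The original enable_int8 / enable_dynamic_quantization flags have no
# from_pretrained() equivalent in optimum; dynamic int8 quantization is a
# separate step with ORTQuantizer. A rough sketch (save_dir and the reload
# step are assumptions that depend on the exported model layout):
#
#     from optimum.onnxruntime import ORTQuantizer
#     from optimum.onnxruntime.configuration import AutoQuantizationConfig
#
#     qconfig = AutoQuantizationConfig.avx2(is_static=False, per_channel=False)
#     ORTQuantizer.from_pretrained(model).quantize(
#         save_dir="onnx_int8", quantization_config=qconfig
#     )
#     model = ORTModelForCausalLM.from_pretrained("onnx_int8", ...)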

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="left",              # left padding keeps batched generation aligned
)
# Many causal-LM tokenizers ship without a pad token; padding=True below needs one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def generate_response(input_texts):
    """Greedily generate completions for a string or a batch of strings."""
    inputs = tokenizer(
        input_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=MAX_INPUT_LENGTH,
        add_special_tokens=True,
    )
    # Greedy decoding: temperature and early_stopping only apply to sampling or
    # beam search, so passing them here would only trigger transformers warnings.
    outputs = model.generate(
        **inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,
        num_beams=1,
        use_cache=True,
        pad_token_id=tokenizer.pad_token_id,
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
__all__ = ["generate_response", "tokenizer"]
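
# Minimal usage sketch (assumes config.py points at a small CPU-friendly causal
# LM; the prompt is purely illustrative):
if __name__ == "__main__":
    for reply in generate_response(["Hello! What can you do?"]):
        print(reply)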