taekooktea committed
Commit f54eedc · verified · 1 Parent(s): 63b9d67

Update utils.py

Files changed (1):
  1. utils.py +16 -22
utils.py CHANGED
@@ -2,32 +2,28 @@ from transformers import AutoTokenizer
 from optimum.onnxruntime import ORTModelForCausalLM, ORTOptions
 from config import MODEL_NAME, MAX_NEW_TOKENS, TEMPERATURE, MAX_INPUT_LENGTH
 
-# Model loading: CPU-only speed configuration (INT8 quantization + memory optimization, no wasted compute)
 options = ORTOptions(
-    enable_int8=True,  # Core: INT8 quantization roughly halves the CPU workload
-    enable_dynamic_quantization=True,  # Dynamic quantization adapts to varying input lengths
-    enable_cpu_mem_optimization=True,  # New: optimize CPU memory allocation to avoid stalls
-    enable_flash_attention=False,  # Key: FlashAttention is unsupported on CPU; disabling it skips the detection cost
-    enable_sequential_execution=True  # Suits single-/low-core CPUs; avoids thread-switching waste
+    enable_int8=True,
+    enable_dynamic_quantization=True,
+    enable_cpu_mem_optimization=True,
+    enable_flash_attention=False,
+    enable_sequential_execution=True
 )
 
-# Load the ONNX model (explicitly target the CPU, skipping device detection)
 model = ORTModelForCausalLM.from_pretrained(
     MODEL_NAME,
     from_transformers=True,
     ort_options=options,
-    device_map="cpu",  # Pin to the CPU to avoid device-allocation overhead
-    trust_remote_code=True  # Compatible with Phi-3-mini's ONNX format; avoids load errors
+    device_map="cpu",
+    trust_remote_code=True
 )
 tokenizer = AutoTokenizer.from_pretrained(
     MODEL_NAME,
     trust_remote_code=True,
-    padding_side="left"  # Improves padding efficiency for batched inference
+    padding_side="left"
 )
 
-# Inference function (aligned with the model configuration, no redundant computation)
 def generate_response(input_texts):
-    # Input handling: keep the token count lean, no redundancy
     inputs = tokenizer(
         input_texts,
         return_tensors="pt",
@@ -36,18 +32,16 @@ def generate_response(input_texts):
         max_length=MAX_INPUT_LENGTH,
         add_special_tokens=True
     )
-
-    # Generation: speed mode (single beam + early stopping, no random sampling)
     outputs = model.generate(
         **inputs,
         max_new_tokens=MAX_NEW_TOKENS,
         temperature=TEMPERATURE,
-        do_sample=False,  # Disable random sampling to reduce CPU compute
-        num_beams=1,  # Single-beam search, 50%+ faster than multi-beam
-        early_stopping=True,  # Stop at the end-of-sequence token; no wasted work
-        use_cache=True,  # Enable the KV cache to reuse earlier computation
-        pad_token_id=tokenizer.eos_token_id  # Use EOS as the pad token to avoid warnings
+        do_sample=False,
+        num_beams=1,
+        early_stopping=True,
+        use_cache=True,
+        pad_token_id=tokenizer.eos_token_id
     )
-
-    # Decode the output: skip special tokens and return
-    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+__all__ = ["generate_response", "tokenizer"]
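
For context, utils.py imports four constants from config.py, a file that is not part of this commit. A hypothetical config.py consistent with those imports might look like the sketch below; every value is an illustrative placeholder rather than something taken from the repository (the removed comments mention Phi-3-mini, hence the model name guess).

# Hypothetical config.py, shown only to make the imports in utils.py concrete.
# All values are illustrative placeholders, not taken from the commit.
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"  # placeholder; the removed comments reference Phi-3-mini
MAX_NEW_TOKENS = 128     # placeholder cap on generated tokens per call
TEMPERATURE = 1.0        # placeholder; has no effect here, since do_sample=False
MAX_INPUT_LENGTH = 512   # placeholder truncation limit for prompts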
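
A minimal smoke test of the exported API could then look like the following sketch. It assumes the installed optimum.onnxruntime build actually exposes the ORTOptions flags used above, which not every release does.

# Hypothetical usage example; not part of the commit.
from utils import generate_response

prompts = [
    "Explain ONNX Runtime in one sentence.",
    "Why quantize a model to INT8 for CPU inference?",
]
# padding_side="left" in utils.py left-pads the shorter prompt so both
# sequences end at the same position before decoding begins.
for prompt, reply in zip(prompts, generate_response(prompts)):
    print(f"> {prompt}\n{reply}\n")

With do_sample=False and num_beams=1, generation reduces to plain greedy decoding, so the output of this test is deterministic across runs.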