from crazy_functions.ipc_fns.mp import run_in_subprocess_with_timeout

def force_breakdown(txt, limit, get_token_fn):
    """ When the text cannot be split on punctuation or blank lines,
    fall back to the most brute-force cut: the longest prefix under the limit.
    """
    for i in reversed(range(len(txt))):
        if get_token_fn(txt[:i]) < limit:
            return txt[:i], txt[i:]
    return "Tiktoken unknown error", "Tiktoken unknown error"

def maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage):
    """ Speed trick: when remain_txt_to_cut grows beyond `_max`, the text after
    `_max` is parked in remain_txt_to_cut_storage; when remain_txt_to_cut shrinks
    below `_min`, part of the stored text is taken back out.
    """
    _min = int(5e4)
    _max = int(1e5)
    # print(len(remain_txt_to_cut), len(remain_txt_to_cut_storage))
    if len(remain_txt_to_cut) < _min and len(remain_txt_to_cut_storage) > 0:
        remain_txt_to_cut = remain_txt_to_cut + remain_txt_to_cut_storage
        remain_txt_to_cut_storage = ""
    if len(remain_txt_to_cut) > _max:
        remain_txt_to_cut_storage = remain_txt_to_cut[_max:] + remain_txt_to_cut_storage
        remain_txt_to_cut = remain_txt_to_cut[:_max]
    return remain_txt_to_cut, remain_txt_to_cut_storage
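
# Illustrative sketch (not part of the original module) of the two-buffer
# trick: text beyond _max (1e5 chars) is parked in storage, and storage is
# merged back once the working text drops below _min (5e4 chars).
def _demo_maintain_storage():
    head, stored = maintain_storage("x" * 150_000, "")
    assert len(head) == 100_000 and len(stored) == 50_000
    head, stored = maintain_storage("y" * 10_000, stored)
    assert len(head) == 60_000 and stored == ""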

def cut(limit, get_token_fn, txt_tocut, must_break_at_empty_line, break_anyway=False):
    """ Split text into fragments whose token counts stay within `limit`.
    """
    res = []
    total_len = len(txt_tocut)
    fin_len = 0
    remain_txt_to_cut = txt_tocut
    remain_txt_to_cut_storage = ""
    # Speed trick: text beyond `_max` is parked in remain_txt_to_cut_storage
    # and fed back in as the working text shrinks (see maintain_storage)
    remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
    while True:
        if get_token_fn(remain_txt_to_cut) <= limit:
            # The remaining text is within the token limit; no further cutting needed
            res.append(remain_txt_to_cut)
            fin_len += len(remain_txt_to_cut)
            break
        else:
            # The remaining text exceeds the token limit, so cut it
            lines = remain_txt_to_cut.split('\n')
            # Estimate a split point: e.g. if the text is 8000 tokens over 1000
            # lines and the limit is 2000, start scanning from around line 250
            estimated_line_cut = limit / get_token_fn(remain_txt_to_cut) * len(lines)
            estimated_line_cut = int(estimated_line_cut)
            # Scan downward from the estimate for a suitable split line (cnt)
            cnt = 0
            for cnt in reversed(range(estimated_line_cut)):
                if must_break_at_empty_line:
                    # Only empty lines (i.e. \n\n in the original text) qualify as split points
                    if lines[cnt] != "":
                        continue
                prev = "\n".join(lines[:cnt])
                post = "\n".join(lines[cnt:])
                if get_token_fn(prev) < limit:
                    break
            if cnt == 0:
                # No suitable split point was found
                if break_anyway:
                    # Brute-force cutting is allowed
                    prev, post = force_breakdown(remain_txt_to_cut, limit, get_token_fn)
                else:
                    # Brute-force cutting is not allowed; give up
                    raise RuntimeError(f"A single line of text is too long to split! {remain_txt_to_cut}")
            # Collect the finished fragment
            res.append(prev)
            fin_len += len(prev)
            # Prepare the next iteration
            remain_txt_to_cut = post
            remain_txt_to_cut, remain_txt_to_cut_storage = maintain_storage(remain_txt_to_cut, remain_txt_to_cut_storage)
            progress = fin_len / total_len
            print(f'Splitting text: {int(progress*100)}%')
        if len(remain_txt_to_cut.strip()) == 0:
            break
    return res
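
# Illustrative sketch (not part of the original module): with len() as the
# token counter and a limit of 10, cut() breaks on line boundaries and every
# returned fragment stays within the limit. Note that the '\n' separating two
# fragments is consumed by the split/join, so joining the fragments back does
# not reproduce the input byte-for-byte.
def _demo_cut():
    pieces = cut(10, len, "aaaa\nbbbb\ncccc\ndddd\n", must_break_at_empty_line=False)
    assert pieces == ["aaaa", "bbbb", "cccc\ndddd\n"]
    assert all(len(p) <= 10 for p in pieces)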

def breakdown_text_to_satisfy_token_limit_(txt, limit, llm_model="gpt-3.5-turbo"):
    """ Try several splitting strategies in turn until every fragment satisfies the token limit.
    """
    from request_llms.bridge_all import model_info
    enc = model_info[llm_model]['tokenizer']
    def get_token_fn(txt): return len(enc.encode(txt, disallowed_special=()))
    try:
        # Attempt 1: split on double newlines (\n\n)
        return cut(limit, get_token_fn, txt, must_break_at_empty_line=True)
    except RuntimeError:
        try:
            # Attempt 2: split on single newlines (\n)
            return cut(limit, get_token_fn, txt, must_break_at_empty_line=False)
        except RuntimeError:
            try:
                # Attempt 3: split on English periods (.)
                res = cut(limit, get_token_fn, txt.replace('.', '。\n'), must_break_at_empty_line=False)  # the Chinese full stop here is deliberate: it serves as a temporary marker
                return [r.replace('。\n', '.') for r in res]
            except RuntimeError:
                try:
                    # Attempt 4: split on Chinese full stops (。)
                    res = cut(limit, get_token_fn, txt.replace('。', '。。\n'), must_break_at_empty_line=False)
                    return [r.replace('。。\n', '。') for r in res]
                except RuntimeError:
                    # Attempt 5: nothing else worked, cut anywhere
                    return cut(limit, get_token_fn, txt, must_break_at_empty_line=False, break_anyway=True)
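
# Illustrative sketch (not part of the original module) of the sentence-marker
# trick used in attempts 3 and 4 above: '.' is rewritten to '。\n' so the
# newline-based cut can break at sentence ends, then the marker is mapped back.
# Caveat: cut() consumes the '\n' at fragment boundaries, so a marker landing
# at the very end of a fragment is not restored by the final replace.
def _demo_sentence_marker(limit=20):
    marked = "one. two. three. four.".replace('.', '。\n')
    pieces = cut(limit, len, marked, must_break_at_empty_line=False)
    return [p.replace('。\n', '.') for p in pieces]  # ["one. two。", " three. four."]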

breakdown_text_to_satisfy_token_limit = run_in_subprocess_with_timeout(breakdown_text_to_satisfy_token_limit_, timeout=60)
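# Design note (an assumption about the rationale, not stated in the source):
# tokenizer calls over very large texts are CPU-heavy, and a pathological
# input could stall the scan for a split point. Wrapping the worker in a
# subprocess with a 60-second timeout keeps the host process responsive and
# lets callers fail fast instead of hanging.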

if __name__ == '__main__':
    from crazy_functions.crazy_utils import read_and_clean_pdf_text
    file_content, page_one = read_and_clean_pdf_text("build/assets/at.pdf")
    from request_llms.bridge_all import model_info
    # Double the text five times (32x) to stress-test the splitter
    for i in range(5):
        file_content += file_content
    print(len(file_content))
    TOKEN_LIMIT_PER_FRAGMENT = 2500
    res = breakdown_text_to_satisfy_token_limit(file_content, TOKEN_LIMIT_PER_FRAGMENT)