| """ | |
| Split long conversations based on certain max length. | |
| Usage: python3 -m fastchat.data.split_long_conversation \ | |
| --in sharegpt_clean.json \ | |
| --out sharegpt_split.json \ | |
| --model-name-or-path $<model-name> | |
| """ | |
import argparse
from concurrent.futures import ProcessPoolExecutor
import json
from typing import Dict, Sequence, Optional

import transformers
from tqdm import tqdm

def make_sample(sample, start_idx, end_idx):
    """Build a new sample from a half-open slice of conversation turns."""
    # Slices must cover whole human/gpt rounds, i.e. an even number of turns.
    assert (end_idx - start_idx) % 2 == 0
    return {
        "id": sample["id"] + "_" + str(start_idx),
        "model": sample.get("model", ""),
        "conversations": sample["conversations"][start_idx:end_idx],
    }
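
# Example (hypothetical turns t0..t3): slicing turns [2, 4) out of a 4-turn
# sample yields a new record whose id encodes the split offset:
#
#     >>> s = {"id": "abc123", "conversations": [t0, t1, t2, t3]}
#     >>> make_sample(s, 2, 4)
#     {'id': 'abc123_2', 'model': '', 'conversations': [t2, t3]}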

# Globals shared with worker processes; set per run in split_all().
tokenizer = max_length = None

def split_one_sample(sample):
    tokenized_lens = []
    conversations = sample["conversations"]
    # Drop a trailing unpaired turn so only whole human/gpt rounds remain.
    conversations = conversations[: len(conversations) // 2 * 2]
    for c in conversations:
        # +6 adds rough headroom for the role and separator tokens that the
        # conversation template will add around each turn.
        length = len(tokenizer(c["value"]).input_ids) + 6
        tokenized_lens.append(length)

    start_idx = 0
    cur_len = 0

    if len(conversations) % 2 != 0 or len(conversations) < 2:
        return []

    new_samples = []
    for i in range(0, len(conversations), 2):
        tmp_len = tokenized_lens[i] + tokenized_lens[i + 1]
        if cur_len + tmp_len > max_length:
            # This round would overflow: emit the rounds accumulated so far
            # and start a new sample at the current round.
            new_samples.append(make_sample(sample, start_idx, i))
            start_idx = i
            cur_len = 0
        elif i == len(conversations) - 2:
            # Last round and it fits: emit everything accumulated so far.
            new_samples.append(make_sample(sample, start_idx, i + 2))
        cur_len += tmp_len

    return new_samples
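
# Worked example (hypothetical lengths): with max_length = 100 and per-turn
# token counts [40, 30, 20, 25, 50, 45], the round sums are 70, 45, 95.
#   - round 0 (70) fits: cur_len = 70
#   - round 1 (45) would overflow (70 + 45 > 100): emit turns [0, 2),
#     restart at index 2, cur_len = 45
#   - round 2 (95) would overflow (45 + 95 > 100): emit turns [2, 4),
#     restart at index 4; being the last round, it is never emitted, since
#     the overflow branch takes precedence over the last-round branch.
# So a final round that pushes past max_length is dropped, not truncated.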

def worker(input_data):
    # Runs in a worker process: split every sample in one chunk.
    result = []
    for sample in input_data:
        result.extend(split_one_sample(sample))
    return result

def split_all(content, begin, end, tokenizer_, max_length_):
    """
    Keep as many whole rounds of each conversation as fit within the
    max token length constraint.
    """
    global tokenizer, max_length
    tokenizer = tokenizer_
    max_length = max_length_

    content = content[begin:end]
    new_content = []

    # Split content into chunks of 1000 samples and fan them out to workers.
    chunks = [content[i : i + 1000] for i in range(0, len(content), 1000)]
    with ProcessPoolExecutor() as executor:
        for result in tqdm(executor.map(worker, chunks), total=len(chunks)):
            new_content.extend(result)

    return new_content
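
# Portability note: the globals set above reach workers only on platforms
# that fork child processes (Linux). Under the spawn start method (Windows,
# macOS default), workers re-import this module and see tokenizer = None.
# A minimal sketch of a portable alternative using ProcessPoolExecutor's
# initializer hook (the _init_worker name is hypothetical):
#
#     def _init_worker(tokenizer_, max_length_):
#         global tokenizer, max_length
#         tokenizer, max_length = tokenizer_, max_length_
#
#     with ProcessPoolExecutor(
#         initializer=_init_worker, initargs=(tokenizer_, max_length_)
#     ) as executor:
#         ...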

def filter_invalid_roles(content):
    # Keep only samples whose turns strictly alternate human -> gpt,
    # starting with a human turn.
    new_content = []
    for i, c in enumerate(content):
        roles = ["human", "gpt"]
        if len(c["conversations"]) <= 0:
            continue

        valid = True
        for j, s in enumerate(c["conversations"]):
            if s["from"] != roles[j % 2]:
                valid = False
                break

        if valid:
            new_content.append(c)

    return new_content
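
# Example: [{"from": "human", ...}, {"from": "gpt", ...}] is kept, while a
# conversation that opens with a "gpt" turn, or contains two consecutive
# "human" turns, is dropped entirely (not repaired).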

def main(args):
    content = json.load(open(args.in_file, "r"))
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.max_length,
        padding_side="right",
        use_fast=False,
    )
    new_content = split_all(content, args.begin, args.end, tokenizer, args.max_length)
    new_content = filter_invalid_roles(new_content)

    print(f"#in: {len(content)}, #out: {len(new_content)}")
    json.dump(new_content, open(args.out_file, "w"), indent=2, ensure_ascii=False)
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--in-file", type=str, required=True) | |
| parser.add_argument("--out-file", type=str, default="sharegpt_split.json") | |
| parser.add_argument("--begin", type=int) | |
| parser.add_argument("--end", type=int) | |
| parser.add_argument("--model-name-or-path", type=str, required=True) | |
| parser.add_argument("--max-length", type=int, default=2048) | |
| args = parser.parse_args() | |
| main(args) | |
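
# Example run (hypothetical paths; any Hugging Face model whose tokenizer
# matches your training setup works):
#
#     python3 -m fastchat.data.split_long_conversation \
#         --in-file sharegpt_clean.json \
#         --out-file sharegpt_split.json \
#         --model-name-or-path /path/to/model \
#         --max-length 2048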