import json import sys import random from collections import defaultdict def collect_dataset_info(file_path): """收集数据集信息,包括每个数据集的行号列表和首次出现顺序""" dataset_lines = defaultdict(list) order = [] seen = set() with open(file_path, 'r') as f: for line_num, line in enumerate(f, 1): try: data = json.loads(line.strip()) custom_id = data['custom_id'] dataset = custom_id.split('-')[0] if dataset not in seen: order.append(dataset) seen.add(dataset) dataset_lines[dataset].append(line_num) except json.JSONDecodeError: print(f"Error: Invalid JSON at line {line_num}", file=sys.stderr) except KeyError: print(f"Error: Missing 'custom_id' at line {line_num}", file=sys.stderr) except IndexError: print(f"Error: Invalid custom_id format at line {line_num}", file=sys.stderr) return dataset_lines, order def main(): if len(sys.argv) != 4: print("Usage: python sample_datasets.py ") sys.exit(1) input_file = sys.argv[1] output_file = sys.argv[2] try: N = int(sys.argv[3]) except ValueError: print("Error: N must be an integer.") sys.exit(1) # 收集数据集信息 dataset_info, dataset_order = collect_dataset_info(input_file) k = len(dataset_info) if k == 0: print("Error: No datasets found in the input file.") sys.exit(1) # 检查每个数据集是否有至少5个样本 for dataset, lines in dataset_info.items(): if len(lines) < 5: print(f"Error: Dataset '{dataset}' has fewer than 5 samples.") sys.exit(1) total_samples = sum(len(lines) for lines in dataset_info.values()) min_samples = 5 * k if N < min_samples or N > total_samples: print(f"Error: N must be between {min_samples} and {total_samples}.") sys.exit(1) # 计算可用样本数和剩余需要分配的样本数 available = {dataset: len(lines) - 5 for dataset, lines in dataset_info.items()} total_available = sum(available.values()) R = N - 5 * k if R > total_available: print(f"Error: Cannot allocate {R} samples from available {total_available}.") sys.exit(1) # 计算每个数据集分配的剩余样本数 allocations = [] sum_avail = total_available if total_available != 0 else 1 # 避免除以零 for dataset in dataset_order: avail = available[dataset] alloc_float = R * avail / sum_avail allocations.append(alloc_float) integer_part = [int(alloc) for alloc in allocations] remainders = [alloc - int_part for alloc, int_part in zip(allocations, integer_part)] remainder_total = R - sum(integer_part) # 分配余数 remainder_indices = sorted(enumerate(remainders), key=lambda x: (-x[1], x[0])) for i in range(remainder_total): idx = remainder_indices[i][0] integer_part[idx] += 1 # 计算每个数据集的最终采样数 sample_counts = {} for i, dataset in enumerate(dataset_order): alloc = integer_part[i] if alloc > available[dataset]: print(f"Error: Allocation for dataset '{dataset}' exceeds available samples.") sys.exit(1) sample_counts[dataset] = 5 + alloc # 打印采样分布信息(新增部分) print("\nSampling Distribution:") total_sampled = 0 for dataset in dataset_order: count = sample_counts[dataset] total_sampled += count print(f" - {dataset}: {count} samples") print(f"Total samples: {total_sampled} (target: {N})") # 验证总数正确性 if total_sampled != N: print(f"Error: Total sampled count mismatch ({total_sampled} vs {N})") sys.exit(1) # 随机选择行号 selected_lines = [] for dataset in dataset_order: lines = dataset_info[dataset] count = sample_counts[dataset] selected = random.sample(lines, count) selected_lines.extend(selected) selected_lines.sort() # 写入输出文件 current_idx = 0 total_selected = len(selected_lines) with open(input_file, 'r') as infile, open(output_file, 'w') as outfile: for line_num, line in enumerate(infile, 1): if current_idx >= total_selected: break if line_num == selected_lines[current_idx]: outfile.write(line) current_idx += 1 print(f"\nSuccessfully sampled {N} records to {output_file}.") if __name__ == "__main__": main()