|
|
import json |
|
|
import sys |
|
|
import random |
|
|
from collections import defaultdict |
|
|
|
|
|
def collect_dataset_info(file_path):
    """Scan a JSONL file and group its line numbers by dataset name.

    Each usable line is a JSON object with a string ``custom_id`` field; the
    dataset name is the part of ``custom_id`` before the first ``'-'``.
    Unusable lines are reported on stderr and skipped, but they still count
    towards line numbering so the returned numbers match the physical file.

    Args:
        file_path: Path of the JSONL file to scan (read as UTF-8).

    Returns:
        A ``(dataset_lines, order)`` tuple. ``dataset_lines`` is a
        ``defaultdict(list)`` mapping each dataset name to the 1-based line
        numbers of its records, in file order. ``order`` lists the dataset
        names in order of first appearance.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    dataset_lines = defaultdict(list)
    order = []

    # JSON text is UTF-8 by specification; do not depend on the locale.
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            try:
                data = json.loads(line.strip())
            except json.JSONDecodeError:
                print(f"Error: Invalid JSON at line {line_num}", file=sys.stderr)
                continue

            # Valid JSON that is not an object (list, number, string, ...)
            # cannot carry a custom_id; report it instead of crashing.
            if not isinstance(data, dict) or 'custom_id' not in data:
                print(f"Error: Missing 'custom_id' at line {line_num}", file=sys.stderr)
                continue

            custom_id = data['custom_id']
            if not isinstance(custom_id, str):
                print(f"Error: Invalid custom_id format at line {line_num}", file=sys.stderr)
                continue

            dataset = custom_id.split('-', 1)[0]
            if dataset not in dataset_lines:
                order.append(dataset)
            dataset_lines[dataset].append(line_num)

    return dataset_lines, order
|
|
|
|
|
def _fail(message):
    """Print *message* to stderr and terminate with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _allocate_extra(available, order, extra):
    """Split *extra* samples across datasets in proportion to *available*.

    Largest-remainder method in exact integer arithmetic: every dataset gets
    the floor of its proportional share, then the units still unassigned go
    to the datasets with the largest remainders (ties broken by position in
    *order*).

    Args:
        available: Mapping of dataset name to the number of records it can
            still contribute.
        order: Dataset names; the result is aligned with this sequence.
        extra: Total number of additional samples to distribute.

    Returns:
        A list of ints, one per entry of *order*, summing to *extra* whenever
        ``0 <= extra <= sum(available.values())``.
    """
    total_available = sum(available[name] for name in order)
    if total_available == 0:
        return [0] * len(order)

    # divmod keeps quotient and remainder exact; float division could land
    # just below an integer and be truncated by one.
    shares = [divmod(extra * available[name], total_available) for name in order]
    counts = [quotient for quotient, _ in shares]
    leftover = extra - sum(counts)

    by_remainder = sorted(range(len(order)), key=lambda i: (-shares[i][1], i))
    for idx in by_remainder[:leftover]:
        counts[idx] += 1
    return counts


def main():
    """Command-line entry point: stratified random sampling of a JSONL file.

    Usage: ``python sample_datasets.py <input.jsonl> <output.jsonl> <N>``

    Records are grouped into datasets by ``collect_dataset_info``. Every
    dataset contributes at least 5 records; the remaining ``N - 5 * k``
    records (``k`` = number of datasets) are distributed in proportion to
    what each dataset has left. The chosen records are written to the output
    file in their original file order, and the per-dataset distribution is
    printed to stdout.

    Exits with status 1 and a message on stderr when the arguments are
    invalid, the input cannot be read, the output path is the input path, a
    dataset has fewer than 5 records, or N is outside the feasible range.
    """
    import os  # local import: only needed for the path sanity check below

    min_per_dataset = 5

    if len(sys.argv) != 4:
        _fail("Usage: python sample_datasets.py <input.jsonl> <output.jsonl> <N>")

    input_file, output_file = sys.argv[1], sys.argv[2]
    try:
        target = int(sys.argv[3])
    except ValueError:
        _fail("Error: N must be an integer.")

    # Opening the output with 'w' truncates it, which would destroy the
    # input if both arguments name the same file.
    if os.path.realpath(input_file) == os.path.realpath(output_file):
        _fail("Error: Output file must be different from the input file.")

    try:
        dataset_info, dataset_order = collect_dataset_info(input_file)
    except OSError as exc:
        _fail(f"Error: Cannot read input file '{input_file}': {exc}")

    k = len(dataset_info)
    if k == 0:
        _fail("Error: No datasets found in the input file.")

    for dataset in dataset_order:
        if len(dataset_info[dataset]) < min_per_dataset:
            _fail(f"Error: Dataset '{dataset}' has fewer than {min_per_dataset} samples.")

    total_samples = sum(len(lines) for lines in dataset_info.values())
    min_samples = min_per_dataset * k
    if not min_samples <= target <= total_samples:
        _fail(f"Error: N must be between {min_samples} and {total_samples}.")

    # Records each dataset can give beyond its guaranteed minimum. Because
    # target <= total_samples, the extra amount never exceeds their sum.
    available = {
        dataset: len(dataset_info[dataset]) - min_per_dataset
        for dataset in dataset_order
    }
    extra = target - min_samples
    extra_counts = _allocate_extra(available, dataset_order, extra)

    sample_counts = {}
    for dataset, alloc in zip(dataset_order, extra_counts):
        # Defensive: proportional allocation should never exceed capacity.
        if alloc > available[dataset]:
            _fail(f"Error: Allocation for dataset '{dataset}' exceeds available samples.")
        sample_counts[dataset] = min_per_dataset + alloc

    print("\nSampling Distribution:")
    for dataset in dataset_order:
        print(f" - {dataset}: {sample_counts[dataset]} samples")
    total_sampled = sum(sample_counts.values())
    print(f"Total samples: {total_sampled} (target: {target})")

    if total_sampled != target:
        _fail(f"Error: Total sampled count mismatch ({total_sampled} vs {target})")

    selected_lines = []
    for dataset in dataset_order:
        selected_lines.extend(
            random.sample(dataset_info[dataset], sample_counts[dataset])
        )
    # Sorted so a single forward pass over the input can copy them in order.
    selected_lines.sort()

    current_idx = 0
    total_selected = len(selected_lines)
    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:
        for line_num, line in enumerate(infile, 1):
            if current_idx >= total_selected:
                break  # everything selected has been written
            if line_num == selected_lines[current_idx]:
                outfile.write(line)
                current_idx += 1

    print(f"\nSuccessfully sampled {target} records to {output_file}.")
|
|
|
|
|
# Run the sampler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()