# tools/utils/upload/request_create.py
# Uploaded by Adinosaur via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 1c980b1 (verified).
import json
import sys
import random
from collections import defaultdict
def collect_dataset_info(file_path):
    """Group the lines of a JSONL file by dataset.

    A record's dataset name is the part of its ``custom_id`` before the first
    ``-``. Blank lines are skipped silently. Lines that are not valid JSON,
    are not JSON objects, lack a ``custom_id``, or whose ``custom_id`` is not a
    string are reported on stderr and skipped (they are never assigned to a
    dataset).

    Args:
        file_path: Path to the input JSONL file, read as UTF-8.

    Returns:
        A tuple ``(dataset_lines, order)``: ``dataset_lines`` maps each dataset
        name to the list of 1-based line numbers that belong to it, and
        ``order`` lists the dataset names in order of first appearance.
    """
    dataset_lines = defaultdict(list)
    order = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():
                # Tolerate blank lines such as a trailing empty line.
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                print(f"Error: Invalid JSON at line {line_num}", file=sys.stderr)
                continue
            # A valid JSON line may still be a list/number/string rather than an object.
            if not isinstance(data, dict) or 'custom_id' not in data:
                print(f"Error: Missing 'custom_id' at line {line_num}", file=sys.stderr)
                continue
            custom_id = data['custom_id']
            if not isinstance(custom_id, str):
                print(f"Error: Invalid custom_id format at line {line_num}", file=sys.stderr)
                continue
            dataset = custom_id.split('-', 1)[0]
            # `in` on a defaultdict does not create the key, so this detects first sight.
            if dataset not in dataset_lines:
                order.append(dataset)
            dataset_lines[dataset].append(line_num)
    return dataset_lines, order
# Every dataset is guaranteed at least this many records in the sample.
_MIN_PER_DATASET = 5


def _fail(message):
    """Print ``message`` to stderr and exit with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _largest_remainder_split(total, weights):
    """Split the integer ``total`` into parts proportional to ``weights``.

    Uses the largest-remainder method with exact integer arithmetic (no float
    rounding). When ``total <= sum(weights)``, the parts sum to ``total`` and
    no part exceeds its weight. Ties in the remainder go to the earlier index.
    Returns all zeros when ``total`` or ``sum(weights)`` is 0.
    """
    weight_sum = sum(weights)
    if total == 0 or weight_sum == 0:
        return [0] * len(weights)
    parts = [total * w // weight_sum for w in weights]
    remainders = [total * w % weight_sum for w in weights]
    leftover = total - sum(parts)
    # Hand the leftover units to the largest remainders; only non-zero
    # remainders can be chosen, so a part never exceeds its weight.
    ranked = sorted(range(len(weights)), key=lambda i: (-remainders[i], i))
    for i in ranked[:leftover]:
        parts[i] += 1
    return parts


def _write_selected_lines(input_file, output_file, line_numbers):
    """Copy the given 1-based ``line_numbers`` of ``input_file`` to ``output_file``, in file order."""
    wanted = set(line_numbers)
    remaining = len(wanted)
    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:
        for line_num, line in enumerate(infile, 1):
            if remaining == 0:
                break
            if line_num in wanted:
                outfile.write(line)
                remaining -= 1


def main():
    """Command-line entry point: stratified random sampling of a JSONL file.

    Usage: ``python <script> <input.jsonl> <output.jsonl> <N>``

    Records are grouped by dataset (see ``collect_dataset_info``). Every
    dataset receives at least 5 records. The remaining ``N - 5 * k`` records
    (``k`` = number of datasets) are distributed in proportion to how many
    records each dataset has beyond those 5. Selected records are written to
    the output in their original file order. Errors are printed to stderr and
    the process exits with status 1.
    """
    if len(sys.argv) != 4:
        _fail(f"Usage: python {sys.argv[0]} <input.jsonl> <output.jsonl> <N>")
    input_file, output_file = sys.argv[1], sys.argv[2]
    try:
        n_target = int(sys.argv[3])
    except ValueError:
        _fail("Error: N must be an integer.")

    dataset_info, dataset_order = collect_dataset_info(input_file)
    k = len(dataset_info)
    if k == 0:
        _fail("Error: No datasets found in the input file.")

    for dataset in dataset_order:
        if len(dataset_info[dataset]) < _MIN_PER_DATASET:
            _fail(f"Error: Dataset '{dataset}' has fewer than {_MIN_PER_DATASET} samples.")

    total_samples = sum(len(lines) for lines in dataset_info.values())
    min_samples = _MIN_PER_DATASET * k
    if not min_samples <= n_target <= total_samples:
        _fail(f"Error: N must be between {min_samples} and {total_samples}.")

    # Headroom of each dataset beyond the guaranteed minimum, and how many
    # extra samples must be spread over that headroom.
    available = [len(dataset_info[d]) - _MIN_PER_DATASET for d in dataset_order]
    total_available = sum(available)
    extra = n_target - min_samples
    # Defensive: implied by the range check above (total_available == total_samples - min_samples).
    if extra > total_available:
        _fail(f"Error: Cannot allocate {extra} samples from available {total_available}.")

    extra_counts = _largest_remainder_split(extra, available)
    sample_counts = {}
    for dataset, alloc, avail in zip(dataset_order, extra_counts, available):
        if alloc > avail:
            _fail(f"Error: Allocation for dataset '{dataset}' exceeds available samples.")
        sample_counts[dataset] = _MIN_PER_DATASET + alloc

    print("\nSampling Distribution:")
    for dataset in dataset_order:
        print(f" - {dataset}: {sample_counts[dataset]} samples")
    total_sampled = sum(sample_counts.values())
    print(f"Total samples: {total_sampled} (target: {n_target})")
    if total_sampled != n_target:
        _fail(f"Error: Total sampled count mismatch ({total_sampled} vs {n_target})")

    selected_lines = []
    for dataset in dataset_order:
        selected_lines.extend(random.sample(dataset_info[dataset], sample_counts[dataset]))

    _write_selected_lines(input_file, output_file, selected_lines)
    print(f"\nSuccessfully sampled {n_target} records to {output_file}.")
if __name__ == "__main__":
    # Only run the command-line sampler when executed as a script, not on import.
    main()