|
|
import json |
|
|
import random |
|
|
import os |
|
|
|
|
|
|
|
|
# Maps the integer object counts 1-4 to their English words; used when
# phrasing the "Change the number of ..." edit instructions.
NUM_TO_WORD = dict(
    zip((1, 2, 3, 4), ("one", "two", "three", "four"))
)
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
from diffusers import FluxPipeline |
|
|
from flow_grpo.diffusers_patch.flux_pipeline_with_logprob import pipeline_with_logprob |
|
|
import importlib |
|
|
|
|
|
# Hugging Face Hub id of the text-to-image checkpoint to load.
model_id = "black-forest-labs/FLUX.1-dev"
# Inference device; this script assumes a CUDA GPU is present.
device = "cuda"

# Load the pipeline once at import time in bfloat16 and move it to `device`.
# process_jsonl() below generates every image with this module-level `pipe`.
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
).to(device)
|
|
|
|
|
def process_jsonl(input_file, output_file, image_directory):
    """Generate one image per input record and emit count-editing instructions.

    Each non-blank line of ``input_file`` is parsed as a JSON object. The
    fields read are ``tag``, ``t2i_prompt``, ``include[0]["class"]`` and
    ``include[0]["count"]``. For every valid record this function:

    1. Renders ``t2i_prompt`` with the module-level ``pipe`` at 1024x1024 and
       saves it to ``<image_directory>/image_<i>.jpg``, where ``i`` is the
       0-based line index in ``input_file``.
    2. Writes one JSONL record per count in ``NUM_TO_WORD`` that differs from
       the original count. Each record carries an instruction such as
       ``"Change the number of dog in the image to two."`` plus the image path.

    Args:
        input_file (str): Path of the input JSONL file.
        output_file (str): Path of the output JSONL file (overwritten).
        image_directory (str): Directory where generated images are saved;
            created if it does not already exist.

    Lines that are not valid JSON, are not JSON objects, or lack a required
    field are reported on stdout and skipped before any image is generated.
    Errors raised by the pipeline itself are not caught.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(image_directory, exist_ok=True)

    with open(input_file, 'r', encoding='utf-8') as infile, \
            open(output_file, 'w', encoding='utf-8') as outfile:
        for i, line in enumerate(infile):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) are not records.
                continue

            try:
                data = json.loads(line)
                # Read every required field *before* running the diffusion
                # pipeline, so a malformed record fails fast instead of after
                # a full 50-step generation (and leaves no orphan image).
                tag = data["tag"]
                t2i_prompt = data["t2i_prompt"]
                original_count = data["include"][0]["count"]
                class_name = data["include"][0]["class"]
            except (json.JSONDecodeError, KeyError, IndexError, TypeError) as e:
                # TypeError covers lines whose JSON is not an object/list
                # of the expected shape (e.g. a bare string).
                print(f"处理第 {i+1} 行时出错: {e}")
                continue

            image = pipe(
                t2i_prompt,
                height=1024,
                width=1024,
                guidance_scale=3.5,
                num_inference_steps=50,
                max_sequence_length=512,
            ).images[0]
            image_path = os.path.join(image_directory, f"image_{i}.jpg")
            image.save(image_path)

            # Iterate target counts in a fixed ascending order; drawing them
            # from NUM_TO_WORD guarantees each target has a word form.
            for num in sorted(NUM_TO_WORD):
                if num == original_count:
                    continue
                new_data = {
                    "tag": tag,
                    "include": [{"class": class_name, "count": num}],
                    # Excluding num + 1 presumably encodes "at most num" for
                    # the downstream evaluator -- confirm against consumer.
                    "exclude": [{"class": class_name, "count": num + 1}],
                    "t2i_prompt": t2i_prompt,
                    "prompt": f"Change the number of {class_name} in the image to {NUM_TO_WORD[num]}.",
                    "image": image_path,
                }
                outfile.write(json.dumps(new_data, ensure_ascii=False) + '\n')
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: fixed paths, resolved relative to the current
    # working directory.
    src_jsonl = "metadata.jsonl"
    dst_jsonl = "output.jsonl"
    img_dir = "generated_images"

    process_jsonl(
        input_file=src_jsonl,
        output_file=dst_jsonl,
        image_directory=img_dir,
    )

    print(f"处理完成!结果已保存到 '{dst_jsonl}',图片路径保存在 '{img_dir}' 目录。")