Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import ast | |
| import json | |
| import re | |
| from pathlib import Path | |
| import requests | |
| from backend.config import get_example_config | |
def group_files_by_index(file_paths, data_type="audio"):
    """Group file paths by the numeric index embedded in their file name.

    Args:
        file_paths: Iterable of path/URL strings to inspect.
        data_type: ``"audio"`` looks for ``audio_<n>.png|wav``, ``"video"``
            looks for ``video_<n>.png|mp4``; any other value falls back to
            the image pattern ``img_<n>.png``.

    Returns:
        A dict mapping each integer index to the list of matching paths (in
        input order), with keys inserted in ascending order. Paths that do
        not match the pattern are dropped.
    """
    # The dot in front of the extension is escaped so that only a literal "."
    # is accepted; unescaped it would match any character (e.g. "audio_1_wav").
    if data_type == "audio":
        pattern = r"audio_(\d+)\.(png|wav)"
    elif data_type == "video":
        pattern = r"video_(\d+)\.(png|mp4)"
    else:
        pattern = r"img_(\d+)\.png"
    index_re = re.compile(pattern)

    grouped_files = {}
    for file_path in file_paths:
        match = index_re.search(file_path)
        if match:
            # int() drops zero padding so "00003" and "3" share one group.
            grouped_files.setdefault(int(match.group(1)), []).append(file_path)

    # Rebuild the dict so iteration follows ascending index order.
    return dict(sorted(grouped_files.items()))
def build_description(
    i, data_none, data_attack, quality_metrics=("psnr", "ssim", "lpips")
):
    """Build the metadata dict shown next to one example file.

    Args:
        i: Position of the file within its group of four variants.
            0 and 1 read from ``data_none``; 2 and 3 read from
            ``data_attack``.
        data_none: Evaluation row for the un-attacked sample.
        data_attack: Evaluation row for the attacked sample.
        quality_metrics: Names of the quality metrics copied from
            ``data_none`` when ``i == 1``.

    Returns:
        * ``i == 0`` / ``i == 2``: ``{"detected": <fake_det>}``.
        * ``i == 1``: quality metrics plus detection, p-value, bit and word
          accuracy, all taken from ``data_none``.
        * ``i == 3``: detection, p-value, bit and word accuracy taken from
          ``data_attack``.
        * any other ``i``: ``None``.

    Note:
        String-encoded ``fake_det`` / ``watermark_det`` values are parsed
        with ``ast.literal_eval`` and written back into the input rows.
    """
    # TODO: handle this at data generation
    for row in (data_none, data_attack):
        for flag in ("fake_det", "watermark_det"):
            if isinstance(row[flag], str):
                row[flag] = ast.literal_eval(row[flag])

    if i == 0:
        return {"detected": data_none["fake_det"]}

    if i == 1:
        # Dynamic metrics come first so the key order of the result is
        # quality metrics followed by the fixed metrics.
        metrics_output = {
            metric: round(float(data_none[metric]), 3) for metric in quality_metrics
        }
        metrics_output.update(
            {
                "detected": data_none["watermark_det"],
                "log10_p_value": round(float(data_none["log10_p_value"]), 3),
                "bit_acc": round(data_none["bit_acc"], 3),
                # Read from data_none like every other value in this branch;
                # previously this was taken from data_attack.
                "word_acc": data_none["word_acc"],
            }
        )
        return metrics_output

    if i == 2:
        return {"detected": data_attack["fake_det"]}

    if i == 3:
        return {
            "detected": data_attack["watermark_det"],
            "log10_p_value": round(float(data_attack["log10_p_value"]), 3),
            "bit_acc": round(data_attack["bit_acc"], 3),
            "word_acc": data_attack["word_acc"],
        }

    # No description is defined for other positions.
    return None
def build_infos(abs_path: "str | Path", datatype: str, dataset_name: str, db_key: str):
    """Collect example file URLs and their metadata for every model and attack.

    The evaluation results are read from
    ``<abs_path><dataset_name>/examples_eval_results.json`` (fetched over
    HTTP(S) when the path starts with ``http://`` or ``https://``, otherwise
    opened as a local file). No separator is inserted between ``abs_path``
    and ``dataset_name``.

    Args:
        abs_path: Base location of the datasets; used as a string prefix.
        datatype: One of ``"audio"``, ``"image"`` or ``"video"``.
        dataset_name: Dataset directory appended to ``abs_path``.
        db_key: Key selecting the dataset inside ``results_data["eval"]``.

    Returns:
        ``{model_name: {attack: [file_info, ...]}}`` where each ``file_info``
        holds ``name``, ``metadata`` and the media URL key(s) for the
        datatype. An empty dict is returned when the results file cannot be
        loaded.

    Raises:
        ValueError: If ``datatype`` is not a supported value.
    """
    if datatype == "audio":
        quality_metrics = ["snr", "sisnr", "stoi", "pesq"]
        extensions = ["wav"]
        datatype_abbr = "audio"
        indices = [0, 1, 3, 4, 5]
        default_attack_name = "identity"
    elif datatype == "image":
        quality_metrics = ["psnr", "ssim", "lpips"]
        extensions = ["png"]
        datatype_abbr = "img"
        indices = list(range(20))
        default_attack_name = "none"
    elif datatype == "video":
        quality_metrics = ["psnr", "ssim", "lpips", "msssim", "vmaf"]
        extensions = ["mp4"]
        datatype_abbr = "video"
        indices = [0, 1, 3, 4, 5]
        default_attack_name = "Identity"
    else:
        raise ValueError(f"Unsupported datatype: {datatype!r}")

    # The path is handled as a plain string prefix (it may be a URL), so
    # coerce it once instead of relying on "+" working for the given type.
    abs_path = str(abs_path)
    eval_results_path = abs_path + f"{dataset_name}/examples_eval_results.json"

    # Determine if eval_results_path is a URL or local file
    if eval_results_path.startswith(("http://", "https://")):
        try:
            # A timeout keeps an unresponsive server from blocking forever.
            response = requests.get(eval_results_path, timeout=30)
        except requests.RequestException as e:
            print(f"Failed to fetch remote file: {e}")
            return {}
        if response.status_code != 200:
            return {}
        try:
            results_data = response.json()
        except ValueError as e:
            print(f"Failed to parse remote file: {e}")
            return {}
    else:
        try:
            with open(eval_results_path, "r", encoding="utf-8") as f:
                results_data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Failed to load local file: {e}")
            return {}

    dataset = results_data["eval"][db_key]

    prefixes = [
        f"attacked_{datatype_abbr}",
        f"attacked_wmd_{datatype_abbr}",
        f"{datatype_abbr}",
        f"wmd_{datatype_abbr}",
    ]
    file_patterns = [
        f"{prefix}_{index:05d}.{ext}"
        for prefix in prefixes
        for index in indices
        for ext in extensions
    ]

    def first_row_by_idx(rows):
        # Index rows by "idx", keeping the first row seen for each index.
        by_idx = {}
        for row in rows:
            by_idx.setdefault(row["idx"], row)
        return by_idx

    infos = {}
    for model_name, model_attacks in dataset.items():
        model_infos = {}
        identity_by_idx = first_row_by_idx(
            model_attacks[default_attack_name]["default"]
        )

        for attack_name, attack_variants_data in model_attacks.items():
            for attack_variant, attack_rows in attack_variants_data.items():
                if attack_variant == "default":
                    attack = attack_name
                else:
                    attack = f"{attack_name}_{attack_variant}"

                if len(attack_rows) == 0:
                    model_infos[attack] = []
                    continue

                attack_by_idx = first_row_by_idx(attack_rows)
                file_paths = [
                    f"{abs_path}{dataset_name}/examples/{datatype}/{model_name}/{attack}/{pattern}"
                    for pattern in file_patterns
                ]

                all_files = []
                for i, files in group_files_by_index(
                    file_paths,
                    data_type=datatype,
                ).items():
                    data_none = identity_by_idx.get(i)
                    data_attack = attack_by_idx.get(i)
                    if data_none is None or data_attack is None:
                        # No evaluation row for this example: skip it rather
                        # than failing the whole listing.
                        continue

                    files = sorted(
                        [(f, Path(f).stem) for f in files], key=lambda x: x[1]
                    )
                    # Sorted stems put the two "attacked_*" files first; move
                    # them to the end so the un-attacked pair comes first.
                    files = files[2:] + files[:2]

                    for variant_i, (file, name) in enumerate(files):
                        file_info = {
                            "name": name,
                            "metadata": build_description(
                                variant_i, data_none, data_attack, quality_metrics
                            ),
                        }
                        if datatype == "audio":
                            file_info["image_url"] = file.replace(".wav", ".png")
                            file_info["audio_url"] = file
                        elif datatype == "video":
                            file_info["video_url"] = file
                        else:
                            file_info["image_url"] = file
                        all_files.append(file_info)

                model_infos[attack] = all_files

        infos[model_name] = model_infos

    return infos
def get_examples_tab(datatype: str):
    """Return the example infos for ``datatype``.

    Looks up the example configuration for the datatype and forwards its
    ``path``, ``dataset_name`` and ``db_key`` entries to ``build_infos``.
    """
    example_cfg = get_example_config(datatype)
    return build_infos(
        example_cfg["path"],
        datatype=datatype,
        dataset_name=example_cfg["dataset_name"],
        db_key=example_cfg["db_key"],
    )