Upload 3 files

- 圖片驗證碼識別.py +296 -0
- 微調驗證碼識別模型.py +248 -0
- 找錯誤識別圖片.py +311 -0
圖片驗證碼識別.py
ADDED
@@ -0,0 +1,296 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler
from torchvision import transforms
from datasets import load_dataset
import numpy as np
import string
import math
from tqdm import tqdm
import os
import json
from torch.optim.lr_scheduler import LambdaLR

# --- 1. Configuration ---
class CFG:
    # Dataset and character set
    dataset_name = "gary109/captcha-synth-v3"
    characters = string.digits + string.ascii_lowercase + string.ascii_uppercase  # '0123...abc...ABC...'

    # Image size
    img_width = 200
    img_height = 50

    # Model parameters
    d_model = 256           # Transformer feature (embedding) dimension
    nhead = 8               # number of attention heads
    num_encoder_layers = 4  # number of Transformer encoder layers
    dim_feedforward = 1024  # dimension of the Transformer feed-forward network

    # Training parameters
    epochs = 10
    batch_size = 128
    lr = 1e-4
    device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 2. Data preparation ---

# PyTorch Dataset
class CaptchaDataset(Dataset):
    def __init__(self, hf_dataset, char_to_id, transform=None):
        self.dataset = hf_dataset
        self.transform = transform
        self.char_to_id = char_to_id

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image'].convert("L")  # convert to grayscale
        label = item['text']

        if self.transform:
            image = self.transform(image)

        # Encode the text label as a sequence of character IDs
        label_encoded = [self.char_to_id[char] for char in label]

        return image, torch.tensor(label_encoded, dtype=torch.long)

# Collate function for the DataLoader; handles variable-length labels
def collate_fn(batch):
    images, labels = zip(*batch)
    images = torch.stack(images, 0)

    # Pad the labels
    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)

    return images, padded_labels, label_lengths

# Image transforms
transform = transforms.Compose([
    transforms.Resize((CFG.img_height, CFG.img_width)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # normalize
])

# --- 3. Model architecture (CNN + Transformer) ---

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class CaptchaTransformer(nn.Module):
    def __init__(self, num_classes):
        super(CaptchaTransformer, self).__init__()
        # CNN backbone
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.AdaptiveMaxPool2d((1, None)),
            nn.Conv2d(256, CFG.d_model, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(CFG.d_model),
            nn.ReLU()
        )

        self.pos_encoder = PositionalEncoding(CFG.d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=CFG.d_model, nhead=CFG.nhead,
            dim_feedforward=CFG.dim_feedforward, dropout=0.1
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=CFG.num_encoder_layers)

        # Output layer
        self.output_layer = nn.Linear(CFG.d_model, num_classes)

    def forward(self, x):
        # x shape: [batch_size, channels, height, width]
        x = self.cnn(x)
        # x shape: [batch_size, d_model, new_height, new_width]

        # Prepare the Transformer input:
        # (W, N, E) -> (sequence_length, batch_size, embedding_dim)
        x = x.squeeze(2)        # collapse the height dimension
        x = x.permute(2, 0, 1)  # [width, batch_size, d_model]

        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)

        # x shape: [width, batch_size, d_model]
        output = self.output_layer(x)

        # CTC loss expects log_softmax
        return nn.functional.log_softmax(output, dim=2)

# --- 4. Training and validation ---

def greedy_decode(preds, id_to_char_map):
    decoded_texts = []
    # preds shape: [seq_len, batch_size, num_classes]
    preds = preds.permute(1, 0, 2)  # -> [batch_size, seq_len, num_classes]
    pred_indices = torch.argmax(preds, dim=2)

    for indices in pred_indices:
        text = []
        last_char_id = 0
        for char_id in indices:
            char_id = char_id.item()
            if char_id != 0 and char_id != last_char_id:  # skip blanks and consecutive repeats
                text.append(id_to_char_map[char_id])
            last_char_id = char_id
        decoded_texts.append("".join(text))
    return decoded_texts
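# (A worked example of the collapse rule above, not part of the upload: with
# id_to_char = {1: 'a', 2: 'b'}, the frame-level argmax IDs [1, 1, 0, 1, 2, 2]
# decode to "aab". Consecutive repeats merge, while the blank (0) is dropped and
# resets the repeat check, which is how a genuine double letter survives.)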
char_to_id, id_to_char, VOCAB_SIZE = {}, {}, 0

def main():
    global char_to_id, id_to_char, VOCAB_SIZE
    print(f"Using device: {CFG.device}")

    # Load the dataset
    print("Loading dataset from Hugging Face Hub...")
    train_hf_dataset = load_dataset(CFG.dataset_name, split="train")

    # Use the dedicated validation split
    val_hf_dataset = load_dataset(CFG.dataset_name, split="validation")

    print("Generating vocabulary from the dataset...")
    vocab_path = "vocab.json"
    if os.path.exists(vocab_path):
        print(f"Loading vocabulary from {vocab_path}...")
        with open(vocab_path, 'r', encoding='utf-8') as f:
            characters = json.load(f)
    else:
        # 1. Scan the dataset and collect every character into a set to guarantee uniqueness
        all_chars = set()
        total_samples = len(train_hf_dataset)
        for i in tqdm(range(total_samples), desc="Scanning train labels"):
            label = train_hf_dataset[i]['text']
            all_chars.update(list(label))
        total_samples = len(val_hf_dataset)
        for i in tqdm(range(total_samples), desc="Scanning val labels"):
            label = val_hf_dataset[i]['text']
            all_chars.update(list(label))
        # 2. Convert the set to a sorted list so the ordering is identical on every run.
        #    This is essential for reproducibility!
        characters = sorted(list(all_chars))
        with open(vocab_path, 'w', encoding='utf-8') as f:
            json.dump(characters, f, ensure_ascii=False, indent=2)
        print(f"Vocabulary saved to {vocab_path}")
    print(f"Unique characters found: {''.join(characters)}")
    CFG.characters = "".join(characters)  # kept on CFG for easy inspection

    # Build the character <-> ID mappings
    char_to_id = {char: i + 1 for i, char in enumerate(CFG.characters)}
    id_to_char = {i + 1: char for i, char in enumerate(CFG.characters)}
    VOCAB_SIZE = len(CFG.characters) + 1  # +1 for the CTC blank token at index 0
    print(f"Total unique characters: {VOCAB_SIZE - 1}")

    train_dataset = CaptchaDataset(train_hf_dataset, char_to_id=char_to_id, transform=transform)
    val_dataset = CaptchaDataset(val_hf_dataset, char_to_id=char_to_id, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=4)

    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # Initialize the model, loss function, and optimizer
    model = CaptchaTransformer(num_classes=VOCAB_SIZE).to(CFG.device)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr)

    warmup_steps = 1000
    def warmup_lambda(current_step):
        if current_step < warmup_steps:
            # During warmup, scale the learning rate linearly from 0 up to 1.0
            return float(current_step) / float(max(1, warmup_steps))
        # After warmup, keep the learning rate constant (multiplier of 1.0)
        return 1.0
    scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)
    scaler = GradScaler()

    # Training loop
    for epoch in range(CFG.epochs):
        model.train()  # <<< set at the start of every epoch
        train_loss = 0
        loop = tqdm(train_loader, leave=True)

        for i, (images, labels, label_lengths) in enumerate(loop):
            images = images.to(CFG.device)
            labels = labels.to(CFG.device)
            label_lengths = label_lengths.to(CFG.device)

            # <<< 1. zero_grad inside the loop, with set_to_none=True
            optimizer.zero_grad(set_to_none=True)

            # <<< 2. autocast wraps only the forward pass and the loss
            with autocast(device_type=CFG.device, dtype=torch.bfloat16):
                preds = model(images)
                # Every sample shares the same input length: the CNN output width
                input_lengths = torch.full(size=(preds.size(1),), fill_value=preds.size(0), dtype=torch.long)
                loss = criterion(preds, labels, input_lengths, label_lengths)

            # <<< 3. scaler operations and the optimizer step stay outside autocast
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale the gradients before clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            scaler.step(optimizer)
            scaler.update()

            scheduler.step()  # update the learning rate every step

            train_loss += loss.item()
            loop.set_description(f"Epoch [{epoch+1}/{CFG.epochs}]")
            loop.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

        # Validation loop
        model.eval()
        correct_predictions = 0
        total_predictions = 0
        is_printed = False
        with torch.no_grad():
            for images, labels, label_lengths in tqdm(val_loader, desc="Validation", leave=True):
                with autocast(device_type=CFG.device, dtype=torch.bfloat16):
                    images = images.to(CFG.device)
                    preds = model(images)

                decoded_preds = greedy_decode(preds, id_to_char)
                # Convert the padded labels back to text
                original_texts = []
                for label, length in zip(labels, label_lengths):
                    original_texts.append("".join([id_to_char[l.item()] for l in label[:length]]))

                for pred, target in zip(decoded_preds, original_texts):
                    if not is_printed:
                        print(pred, target)
                        is_printed = True
                    if pred == target:
                        correct_predictions += 1
                    total_predictions += 1

        accuracy = correct_predictions / total_predictions
        print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Val Accuracy: {accuracy:.4f}")

    # Save the model
    torch.save(model.state_dict(), "captcha_transformer.pth")
    print("Model saved to captcha_transformer.pth")


if __name__ == "__main__":
    main()
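A minimal single-image inference sketch, not part of the upload: it assumes the definitions above are importable from 圖片驗證碼識別.py, that a previous run produced captcha_transformer.pth and vocab.json, and that the input path test.png is a placeholder.

    import json
    import torch
    from PIL import Image
    from 圖片驗證碼識別 import CFG, CaptchaTransformer, transform, greedy_decode

    with open("vocab.json", "r", encoding="utf-8") as f:
        characters = json.load(f)
    id_to_char = {i + 1: c for i, c in enumerate(characters)}

    model = CaptchaTransformer(num_classes=len(characters) + 1).to(CFG.device)
    model.load_state_dict(torch.load("captcha_transformer.pth", map_location=CFG.device))
    model.eval()

    image = transform(Image.open("test.png").convert("L")).unsqueeze(0).to(CFG.device)  # [1, 1, 50, 200]
    with torch.no_grad():
        preds = model(image)  # [seq_len, 1, num_classes]
    print(greedy_decode(preds, id_to_char)[0])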
微調驗證碼識別模型.py
ADDED
@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.amp import autocast, GradScaler
from torchvision import transforms
from datasets import load_dataset
import numpy as np
import string
import math
from tqdm import tqdm
import os
import json

# ===================================================================
# Almost every definition is identical to the original file and is
# copied over verbatim, which guarantees the weights load correctly.
# ===================================================================

# --- 1. Configuration (fine-tuning) ---
class CFG_FINETUNE:
    # Paths of the model to load and of the vocabulary
    model_path = "captcha_transformer_best_finetune.pth"  # make sure this is the filename of your best saved model
    vocab_path = "vocab.json"

    # Dataset (unchanged)
    dataset_name = "gary109/captcha-synth-v3"

    # Image size (unchanged)
    img_width = 200
    img_height = 50

    # Model parameters (must match the original model exactly!)
    d_model = 256
    nhead = 8
    num_encoder_layers = 4
    dim_feedforward = 1024

    # Fine-tuning parameters (the key part!)
    epochs = 5       # fine-tuning rarely needs many epochs
    batch_size = 32
    lr = 1e-5        # <<< use a much smaller learning rate!
    device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 2. Required classes and functions (copied from the original file) ---
# (Every necessary definition is duplicated so this script can run standalone.)

class CaptchaDataset(Dataset):
    def __init__(self, hf_dataset, char_to_id, transform=None):
        self.dataset = hf_dataset
        self.transform = transform
        self.char_to_id = char_to_id
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image'].convert("L")
        label = item.get('label') or item.get('text')
        if self.transform:
            image = self.transform(image)
        label_encoded = [self.char_to_id[char] for char in label]
        return image, torch.tensor(label_encoded, dtype=torch.long)

def collate_fn(batch):
    images, labels = zip(*batch)
    images = torch.stack(images, 0)
    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
    return images, padded_labels, label_lengths

transform = transforms.Compose([
    # Mild geometric distortion to simulate touching and deformed characters
    transforms.RandomAffine(
        degrees=8,             # random rotation of up to ±8 degrees
        translate=(0.1, 0.1),  # random translation of up to 10%
        scale=(0.9, 1.1),      # random scaling of up to 10%
        shear=5                # random shear
    ),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),  # random perspective warp
    # Resize
    transforms.Resize((CFG_FINETUNE.img_height, CFG_FINETUNE.img_width)),
    # Color jitter
    transforms.ColorJitter(brightness=0.4, contrast=0.4),
    # Convert to tensor
    transforms.ToTensor(),
    # Random erasing (crucial!) to simulate interference lines or broken strokes.
    # Note: this operation must come after ToTensor.
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0),
    # Normalize
    transforms.Normalize((0.5,), (0.5,))
])
val_transform = transforms.Compose([
    transforms.Resize((CFG_FINETUNE.img_height, CFG_FINETUNE.img_width)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
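# (A hedged sketch, not part of the upload: preview the augmentation pipeline on
# one sample before committing to a fine-tuning run. The x * 0.5 + 0.5 step
# undoes Normalize((0.5,), (0.5,)); the output filename is arbitrary.)
def preview_augmentations(n=16, out_path="augmentation_preview.png"):
    from torchvision.utils import save_image
    sample = load_dataset(CFG_FINETUNE.dataset_name, split="train")[0]['image'].convert("L")
    batch = torch.stack([transform(sample) for _ in range(n)])  # n augmented views of one captcha
    save_image(batch * 0.5 + 0.5, out_path, nrow=4)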
class PositionalEncoding(nn.Module):
    # (copied in full from the original file)
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class CaptchaTransformer(nn.Module):
    # (copied in full from the original file)
    def __init__(self, num_classes):
        super(CaptchaTransformer, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.AdaptiveMaxPool2d((1, None)),
            nn.Conv2d(256, CFG_FINETUNE.d_model, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(CFG_FINETUNE.d_model), nn.ReLU()
        )
        self.pos_encoder = PositionalEncoding(CFG_FINETUNE.d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=CFG_FINETUNE.d_model, nhead=CFG_FINETUNE.nhead, dim_feedforward=CFG_FINETUNE.dim_feedforward, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=CFG_FINETUNE.num_encoder_layers)
        self.output_layer = nn.Linear(CFG_FINETUNE.d_model, num_classes)
    def forward(self, x):
        x = self.cnn(x)
        x = x.squeeze(2)
        x = x.permute(2, 0, 1)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)
        output = self.output_layer(x)
        return nn.functional.log_softmax(output, dim=2)

def greedy_decode(preds, id_to_char_map):
    # (copied in full from the original file)
    decoded_texts = []
    preds = preds.permute(1, 0, 2)
    pred_indices = torch.argmax(preds, dim=2)
    for indices in pred_indices:
        text = []
        last_char_id = 0
        for char_id in indices:
            char_id = char_id.item()
            if char_id != 0 and char_id != last_char_id:
                text.append(id_to_char_map[char_id])
            last_char_id = char_id
        decoded_texts.append("".join(text))
    return decoded_texts

# --- 3. Fine-tuning main routine ---
def finetune():
    print(f"Starting fine-tuning process on device: {CFG_FINETUNE.device}")

    # --- Load the vocabulary ---
    with open(CFG_FINETUNE.vocab_path, 'r', encoding='utf-8') as f:
        characters = json.load(f)

    char_to_id = {char: i + 1 for i, char in enumerate(characters)}
    id_to_char = {i + 1: char for i, char in enumerate(characters)}
    VOCAB_SIZE = len(characters) + 1
    print(f"Vocabulary loaded. Size: {VOCAB_SIZE - 1}")

    # --- Prepare the data ---
    print("Loading dataset for fine-tuning...")
    train_hf_dataset = load_dataset(CFG_FINETUNE.dataset_name, split="train")
    val_hf_dataset = load_dataset(CFG_FINETUNE.dataset_name, split="validation")
    train_dataset = CaptchaDataset(train_hf_dataset, char_to_id, transform=transform)
    val_dataset = CaptchaDataset(val_hf_dataset, char_to_id, transform=val_transform)
    train_loader = DataLoader(train_dataset, batch_size=CFG_FINETUNE.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=CFG_FINETUNE.batch_size * 8, shuffle=False, collate_fn=collate_fn, num_workers=8, pin_memory=True)

    # --- Key step: initialize the model and load the pre-trained weights ---
    model = CaptchaTransformer(num_classes=VOCAB_SIZE).to(CFG_FINETUNE.device)
    print(f"Loading pre-trained weights from: {CFG_FINETUNE.model_path}")
    model.load_state_dict(torch.load(CFG_FINETUNE.model_path, map_location=CFG_FINETUNE.device))
    print("Weights loaded successfully.")

    # --- Set up a fresh optimizer and learning-rate scheduler ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG_FINETUNE.lr)
    scaler = GradScaler()

    # Cosine-annealing learning-rate schedule
    total_steps = len(train_loader) * CFG_FINETUNE.epochs
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-11)  # decay the LR smoothly toward zero

    best_accuracy = 0.8979  # <<< start from the best accuracy you already know!

    # --- Fine-tuning loop ---
    for epoch in range(CFG_FINETUNE.epochs):
        model.train()
        train_loss = 0
        loop = tqdm(train_loader, leave=True)

        for i, (images, labels, label_lengths) in enumerate(loop):
            images, labels, label_lengths = images.to(CFG_FINETUNE.device), labels.to(CFG_FINETUNE.device), label_lengths.to(CFG_FINETUNE.device)
            optimizer.zero_grad(set_to_none=True)

            with autocast(device_type=CFG_FINETUNE.device, dtype=torch.bfloat16):
                preds = model(images)
                input_lengths = torch.full(size=(preds.size(1),), fill_value=preds.size(0), dtype=torch.long)
                loss = nn.CTCLoss(blank=0, zero_infinity=True)(preds, labels, input_lengths, label_lengths)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()  # update the learning rate

            train_loss += loss.item()
            loop.set_description(f"Fine-tune Epoch [{epoch+1}/{CFG_FINETUNE.epochs}]")
            loop.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

        # Validation loop
        model.eval()
        correct_predictions, total_predictions = 0, 0
        with torch.no_grad():
            for images, labels, label_lengths in tqdm(val_loader, desc="Validation"):
                images = images.to(CFG_FINETUNE.device)
                with autocast(device_type=CFG_FINETUNE.device, dtype=torch.bfloat16):
                    preds = model(images)

                decoded_preds = greedy_decode(preds, id_to_char)
                original_texts = ["".join([id_to_char[l.item()] for l in label[:length]]) for label, length in zip(labels, label_lengths)]

                for pred, target in zip(decoded_preds, original_texts):
                    if pred == target:
                        correct_predictions += 1
                    total_predictions += 1

        accuracy = correct_predictions / total_predictions
        print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Val Accuracy: {accuracy:.4f}")

        # Save the model whenever it improves
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            # Use a new filename so the original best model is not overwritten
            torch.save(model.state_dict(), "captcha_transformer_best_finetune.pth")
            print(f"🎉 New best fine-tuned model saved with accuracy: {best_accuracy:.4f}")

if __name__ == "__main__":
    finetune()
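A common fine-tuning variant the script above does not use, sketched under the assumption that it is wired into finetune() right after the weights are loaded: freeze the CNN backbone and update only the Transformer encoder and output head, which reduces memory use and can stabilize the first epochs at very small learning rates.

    # Freeze the CNN backbone; only the Transformer encoder and output head train.
    for param in model.cnn.parameters():
        param.requires_grad = False

    # Hand the optimizer only the parameters that still require gradients.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=CFG_FINETUNE.lr,
    )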
找錯誤識別圖片.py
ADDED
@@ -0,0 +1,311 @@
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast
from torchvision import transforms
from datasets import load_dataset
import os
import json
import math
from tqdm import tqdm
import re
from collections import Counter
import Levenshtein
from torchvision.transforms import functional as F
from PIL import Image

# ===================================================================
# This is a standalone script, so every necessary definition is
# copied over from the earlier files.
# ===================================================================

# --- 1. Configuration ---
class CFG_ANALYSIS:
    # Paths of the model and vocabulary to load (use your best fine-tuned model)
    model_path = "captcha_transformer_best_finetune.pth"
    vocab_path = "vocab.json"

    # Dataset (the validation split is used)
    dataset_name = "gary109/captcha-synth-v3"

    # Folder for the misclassified images
    output_dir = "error_analysis_results"

    # Model parameters (must match training exactly!)
    d_model = 256
    nhead = 8
    num_encoder_layers = 4
    dim_feedforward = 1024
    img_width = 200
    img_height = 50

    # Inference parameters
    batch_size = 1024  # inference tolerates a much larger batch size, which speeds things up
    device = "cuda" if torch.cuda.is_available() else "cpu"

# --- 2. Required classes and functions (copied from the original file) ---

class CaptchaDataset(Dataset):
    def __init__(self, hf_dataset, char_to_id, transform=None):
        self.dataset = hf_dataset
        self.transform = transform
        self.char_to_id = char_to_id
    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item['image'].convert("L")
        label = item.get('label') or item.get('text')
        if self.transform:
            image = self.transform(image)
        label_encoded = [self.char_to_id[char] for char in label]
        return image, torch.tensor(label_encoded, dtype=torch.long)

def collate_fn(batch):
    images, labels = zip(*batch)
    images = torch.stack(images, 0)
    label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
    padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
    return images, padded_labels, label_lengths

class PadAndResize:
    """
    A custom transform that scales an image to the target size while keeping
    the original aspect ratio. The remaining area is filled (padded) with the
    given color.
    """
    def __init__(self, output_size, fill=0):
        """
        :param output_size: a (height, width) tuple
        :param fill: pixel value used for padding (0 = black, 255 = white)
        """
        self.output_size = output_size
        self.fill = fill

    def __call__(self, img):
        # Target and source dimensions
        target_h, target_w = self.output_size
        original_w, original_h = img.size

        # Aspect ratios
        target_ratio = target_w / target_h
        original_ratio = original_w / original_h

        if original_ratio > target_ratio:
            # The source image is "wider" than the target: scale by width
            new_w = target_w
            new_h = int(new_w / original_ratio)
            img = F.resize(img, (new_h, new_w))

            # Height that still needs padding
            pad_h = target_h - new_h
            # Split the padding between top and bottom
            padding = (0, pad_h // 2, 0, target_h - new_h - (pad_h // 2))
        else:
            # The source image is "taller" (or has the same ratio): scale by height
            new_h = target_h
            new_w = int(new_h * original_ratio)
            img = F.resize(img, (new_h, new_w))

            # Width that still needs padding
            pad_w = target_w - new_w
            # Split the padding between left and right
            padding = (pad_w // 2, 0, target_w - new_w - (pad_w // 2), 0)

        # Apply the padding
        return F.pad(img, padding, self.fill)

transform = transforms.Compose([
    PadAndResize((CFG_ANALYSIS.img_height, CFG_ANALYSIS.img_width), fill=0),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
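# (A worked example, not part of the upload: for a 120x40 source and the 200x50
# target, the aspect ratio 3.0 is below the target ratio 4.0, so the height
# branch fires: the image is resized to 150x50 and padded by 25 px on the left
# and 25 px on the right. Note that this letterboxing differs from the plain
# Resize used during training, which simply stretches the image to 200x50.)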
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class CaptchaTransformer(nn.Module):
    def __init__(self, num_classes):
        super(CaptchaTransformer, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.AdaptiveMaxPool2d((1, None)),
            nn.Conv2d(256, CFG_ANALYSIS.d_model, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(CFG_ANALYSIS.d_model), nn.ReLU()
        )
        self.pos_encoder = PositionalEncoding(CFG_ANALYSIS.d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=CFG_ANALYSIS.d_model, nhead=CFG_ANALYSIS.nhead, dim_feedforward=CFG_ANALYSIS.dim_feedforward, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=CFG_ANALYSIS.num_encoder_layers)
        self.output_layer = nn.Linear(CFG_ANALYSIS.d_model, num_classes)
    def forward(self, x):
        x = self.cnn(x)
        x = x.squeeze(2)
        x = x.permute(2, 0, 1)
        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)
        output = self.output_layer(x)
        return nn.functional.log_softmax(output, dim=2)

def greedy_decode(preds, id_to_char_map):
    decoded_texts = []
    preds = preds.permute(1, 0, 2)
    pred_indices = torch.argmax(preds, dim=2)
    for indices in pred_indices:
        text = []
        last_char_id = 0
        for char_id in indices:
            char_id = char_id.item()
            if char_id != 0 and char_id != last_char_id:
                text.append(id_to_char_map[char_id])
            last_char_id = char_id
        decoded_texts.append("".join(text))
    return decoded_texts

# --- 3. Error-analysis main routine ---
def analyze_errors():
    print("--- Starting Quantitative Error Analysis ---")

    # --- Setup (same as before) ---
    CFG_ANALYSIS.output_dir = "error_analysis_v2_results"
    os.makedirs(CFG_ANALYSIS.output_dir, exist_ok=True)

    with open(CFG_ANALYSIS.vocab_path, 'r', encoding='utf-8') as f:
        characters = json.load(f)
    id_to_char = {i + 1: char for i, char in enumerate(characters)}
    VOCAB_SIZE = len(characters) + 1

    val_hf_dataset = load_dataset(CFG_ANALYSIS.dataset_name, split="validation")
    char_to_id = {c: i + 1 for i, c in enumerate(characters)}
    val_torch_dataset = CaptchaDataset(val_hf_dataset, char_to_id, transform=transform)
    val_loader = DataLoader(val_torch_dataset, batch_size=CFG_ANALYSIS.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=8, pin_memory=True)

    model = CaptchaTransformer(num_classes=VOCAB_SIZE).to(CFG_ANALYSIS.device)
    model.load_state_dict(torch.load(CFG_ANALYSIS.model_path, map_location=CFG_ANALYSIS.device))
    model.eval()
    print("Model loaded successfully.")

    error_counts = Counter()
    confusion_matrix = {char: Counter() for char in characters}

    # --- Lists that collect the results from every batch ---
    all_preds_list = []
    all_labels_list = []
    all_label_lengths_list = []
    all_indices_list = []  # <<< new: remember the original index of every sample

    # ==========================================================
    # ============[ Phase 1: GPU-Intensive Pass ]============
    # ==========================================================
    # This loop only runs model inference, so it is very fast
    with torch.no_grad():
        for batch_idx, (images, labels, label_lengths) in enumerate(tqdm(val_loader, desc="Phase 1: GPU Inference")):
            images = images.to(CFG_ANALYSIS.device)
            with autocast(device_type=CFG_ANALYSIS.device, dtype=torch.bfloat16):
                preds = model(images)

            # Move the results from the GPU to CPU RAM and keep them
            all_preds_list.append(preds.cpu())
            all_labels_list.append(labels)
            all_label_lengths_list.append(label_lengths)

            # Record the original index of every sample in this batch
            start_idx = batch_idx * CFG_ANALYSIS.batch_size
            end_idx = start_idx + len(images)
            all_indices_list.extend(range(start_idx, end_idx))

    # ==========================================================
    # ============[ Phase 2: CPU-Intensive Analysis ]===========
    # ==========================================================
    # The GPU is done; now process all the collected results on the CPU in one go
    print("\nPhase 2: Analyzing results on CPU...")

    # 1. The prediction sequence length is fixed, so the batches can be concatenated safely
    all_preds_tensor = torch.cat(all_preds_list, dim=1)
    decoded_preds = greedy_decode(all_preds_tensor, id_to_char)

    # 2. The label lengths vary, so handle the labels batch by batch
    original_texts = []
    for labels, label_lengths in zip(all_labels_list, all_label_lengths_list):
        batch_texts = ["".join([id_to_char[l.item()] for l in label[:length]]) for label, length in zip(labels, label_lengths)]
        original_texts.extend(batch_texts)

    # Now loop over the results on the CPU, classifying errors and saving files
    for i in tqdm(range(len(decoded_preds)), desc="Phase 2: Classifying errors and saving files"):
        pred = decoded_preds[i]
        target = original_texts[i]

        if pred != target:
            # --- Error classification and bookkeeping (same as before) ---
            error_type = "unknown"
            if len(pred) != len(target):
                error_type = "length_mismatch"
            else:
                distance = Levenshtein.distance(pred, target)
                if distance == 1:
                    error_type = "substitution"
                    for j in range(len(pred)):
                        if pred[j] != target[j]:
                            confusion_matrix[target[j]][pred[j]] += 1
                            break
                elif distance > 1:
                    error_type = "complex_error"
            error_counts[error_type] += 1

            # --- Save the image (same as before, but via the recorded index) ---
            error_dir = os.path.join(CFG_ANALYSIS.output_dir, error_type)
            os.makedirs(error_dir, exist_ok=True)

            original_idx = all_indices_list[i]
            original_pil_image = val_hf_dataset[original_idx]['image']

            pred_sanitized = re.sub(r'[\\/*?:"<>|]', "", pred) or "EMPTY"
            target_sanitized = re.sub(r'[\\/*?:"<>|]', "", target)
            filename = f"idx{original_idx}_pred_{pred_sanitized}_target_{target_sanitized}.png"
            filepath = os.path.join(error_dir, filename)
            original_pil_image.save(filepath)

    # --- 3. Build and print the statistics report ---
    total_errors = sum(error_counts.values())
    report = "--- Error Analysis Report ---\n\n"
    report += f"Total Errors Found: {total_errors}\n\n"
    report += "Error Type Distribution:\n"
    for error_type, count in error_counts.most_common():
        percentage = (count / total_errors) * 100
        report += f"- {error_type:<20}: {count:>5} errors ({percentage:.2f}%)\n"

    # Find the 15 most common substitution errors
    substitution_pairs = []
    for target_char, preds in confusion_matrix.items():
        for pred_char, count in preds.items():
            if count > 0:
                substitution_pairs.append(((target_char, pred_char), count))

    # Sort by count
    top_15_substitutions = sorted(substitution_pairs, key=lambda item: item[1], reverse=True)[:15]

    report += "\n\nTop 15 Character Substitution Errors (Target -> Prediction):\n"
    for (target, pred), count in top_15_substitutions:
        report += f"- '{target}' -> '{pred}': {count:>5} times\n"

    print("\n" + report)

    # Write the report to a file
    with open(os.path.join(CFG_ANALYSIS.output_dir, "report.txt"), "w", encoding="utf-8") as f:
        f.write(report)

    print(f"\nAnalysis complete. Report saved to '{os.path.join(CFG_ANALYSIS.output_dir, 'report.txt')}'")
    print("You can now review the categorized images in the results folder.")

if __name__ == "__main__":
    analyze_errors()
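The classification rule in analyze_errors() in miniature (a sketch with made-up strings): any length difference is a length_mismatch, a same-length pair at edit distance 1 is a substitution, and a same-length pair at a larger distance is a complex_error.

    import Levenshtein

    pairs = [("8q3c", "8a3c"), ("8a3", "8a3c"), ("9q3e", "8a3c")]
    for pred, target in pairs:
        if len(pred) != len(target):
            print(pred, target, "-> length_mismatch")
        elif Levenshtein.distance(pred, target) == 1:
            print(pred, target, "-> substitution")
        else:
            print(pred, target, "-> complex_error")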