Transformers · Safetensors
captcha_transformer
Gavin-chen committed (verified)
Commit 49f8f09 · 1 Parent(s): 859beab

Upload 3 files

圖片驗證碼識別.py ADDED
@@ -0,0 +1,296 @@
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader, Dataset
+ from torch.amp import autocast, GradScaler
+ from torchvision import transforms
+ from datasets import load_dataset
+ import numpy as np
+ import string
+ import math
+ from tqdm import tqdm
+ import os
+ import json
+ from torch.optim.lr_scheduler import LambdaLR
+ 
+ # --- 1. Configuration ---
+ class CFG:
+     # Dataset and character set
+     dataset_name = "gary109/captcha-synth-v3"
+     characters = string.digits + string.ascii_lowercase + string.ascii_uppercase  # '0123...abc...ABC...'
+ 
+     # Image size
+     img_width = 200
+     img_height = 50
+ 
+     # Model hyperparameters
+     d_model = 256           # Transformer feature dimension (embedding dim)
+     nhead = 8               # number of attention heads
+     num_encoder_layers = 4  # number of Transformer encoder layers
+     dim_feedforward = 1024  # feed-forward network dimension
+ 
+     # Training hyperparameters
+     epochs = 10
+     batch_size = 128
+     lr = 1e-4
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ # --- 2. Data preparation ---
+ 
+ # PyTorch Dataset
+ class CaptchaDataset(Dataset):
+     def __init__(self, hf_dataset, char_to_id, transform=None):
+         self.dataset = hf_dataset
+         self.transform = transform
+         self.char_to_id = char_to_id
+ 
+     def __len__(self):
+         return len(self.dataset)
+ 
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         image = item['image'].convert("L")  # convert to grayscale
+         label = item['text']
+ 
+         if self.transform:
+             image = self.transform(image)
+ 
+         # Encode the text label as a sequence of character IDs
+         label_encoded = [self.char_to_id[char] for char in label]
+ 
+         return image, torch.tensor(label_encoded, dtype=torch.long)
+ 
+ # Collate function for the DataLoader; handles labels of varying length
+ def collate_fn(batch):
+     images, labels = zip(*batch)
+     images = torch.stack(images, 0)
+ 
+     # Pad the labels
+     label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
+     padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
+ 
+     return images, padded_labels, label_lengths
+ 
+ # Image transforms
+ transform = transforms.Compose([
+     transforms.Resize((CFG.img_height, CFG.img_width)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5,), (0.5,))  # normalize to [-1, 1]
+ ])
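For reference, a minimal sketch of what collate_fn produces; the two fake samples below (a 4-character and a 6-character label) are made up for illustration:

# Hypothetical illustration of collate_fn's padding behaviour.
fake_batch = [
    (torch.zeros(1, 50, 200), torch.tensor([3, 7, 1, 9])),
    (torch.zeros(1, 50, 200), torch.tensor([5, 2, 8, 4, 6, 1])),
]
images, padded_labels, label_lengths = collate_fn(fake_batch)
print(images.shape)   # torch.Size([2, 1, 50, 200])
print(padded_labels)  # the shorter label is padded with 0, the CTC blank ID
print(label_lengths)  # tensor([4, 6])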
+ # --- 3. Model architecture (CNN + Transformer) ---
+ 
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]
+         self.register_buffer('pe', pe)
+ 
+     def forward(self, x):
+         x = x + self.pe[:x.size(0), :]
+         return self.dropout(x)
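The buffer built here is the standard sinusoidal encoding from "Attention Is All You Need": since div_term equals $10000^{-2i/d_{\text{model}}}$, the code computes

$$PE_{(pos,\,2i)} = \sin\big(pos \cdot 10000^{-2i/d_{\text{model}}}\big), \qquad PE_{(pos,\,2i+1)} = \cos\big(pos \cdot 10000^{-2i/d_{\text{model}}}\big)$$

so each step along the CNN's width axis gets a unique code the attention layers can use to distinguish positions.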
+ class CaptchaTransformer(nn.Module):
+     def __init__(self, num_classes):
+         super(CaptchaTransformer, self).__init__()
+         # CNN backbone
+         self.cnn = nn.Sequential(
+             nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(256),
+             nn.ReLU(),
+             nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.AdaptiveMaxPool2d((1, None)),  # collapse the height dimension to 1
+             nn.Conv2d(256, CFG.d_model, kernel_size=3, stride=1, padding=1),
+             nn.BatchNorm2d(CFG.d_model),
+             nn.ReLU()
+         )
+ 
+         self.pos_encoder = PositionalEncoding(CFG.d_model)
+         encoder_layer = nn.TransformerEncoderLayer(
+             d_model=CFG.d_model, nhead=CFG.nhead,
+             dim_feedforward=CFG.dim_feedforward, dropout=0.1
+         )
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=CFG.num_encoder_layers)
+ 
+         # Output layer
+         self.output_layer = nn.Linear(CFG.d_model, num_classes)
+ 
+     def forward(self, x):
+         # x shape: [batch_size, channels, height, width]
+         x = self.cnn(x)
+         # x shape: [batch_size, d_model, 1, new_width]
+ 
+         # Prepare the Transformer input:
+         # (W, N, E) -> (sequence_length, batch_size, embedding_dim)
+         x = x.squeeze(2)        # squeeze out the height dimension
+         x = x.permute(2, 0, 1)  # [width, batch_size, d_model]
+ 
+         x = self.pos_encoder(x)
+         x = self.transformer_encoder(x)
+ 
+         # x shape: [width, batch_size, d_model]
+         output = self.output_layer(x)
+ 
+         # CTC loss expects log-probabilities
+         return nn.functional.log_softmax(output, dim=2)
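As a quick sanity check (a sketch, assuming the CFG values above): the two MaxPool2d layers halve the width twice, so a 200-pixel-wide image becomes a 50-step sequence for the CTC loss.

# Dummy forward pass to verify the sequence length the CTC loss will see.
model = CaptchaTransformer(num_classes=63)  # 62 characters + 1 CTC blank
dummy = torch.randn(4, 1, CFG.img_height, CFG.img_width)  # [N, C, H, W]
out = model(dummy)
print(out.shape)  # torch.Size([50, 4, 63]) -> [seq_len = 200/4, batch, classes]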
+ # --- 4. Training and validation ---
+ 
+ def greedy_decode(preds, id_to_char_map):
+     decoded_texts = []
+     # preds shape: [seq_len, batch_size, num_classes]
+     preds = preds.permute(1, 0, 2)  # -> [batch_size, seq_len, num_classes]
+     pred_indices = torch.argmax(preds, dim=2)
+ 
+     for indices in pred_indices:
+         text = []
+         last_char_id = 0
+         for char_id in indices:
+             char_id = char_id.item()
+             if char_id != 0 and char_id != last_char_id:  # skip blanks and consecutive repeats
+                 text.append(id_to_char_map[char_id])
+             last_char_id = char_id
+         decoded_texts.append("".join(text))
+     return decoded_texts
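The decode rule is the standard CTC collapse: drop blanks (ID 0), then merge consecutive repeats. A toy check, assuming the made-up mapping {1: 'a', 2: 'b'}:

# Frame-wise argmax IDs 1,1,0,1,2,2,0,0 should collapse to "aab";
# the blank between the two 1s is what allows the doubled letter.
toy = torch.zeros(8, 1, 3)  # [seq_len, batch, num_classes]
for t, cls in enumerate([1, 1, 0, 1, 2, 2, 0, 0]):
    toy[t, 0, cls] = 1.0
print(greedy_decode(toy, {1: 'a', 2: 'b'}))  # ['aab']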
+ char_to_id, id_to_char, VOCAB_SIZE = {}, {}, 0
+ def main():
+     global char_to_id, id_to_char, VOCAB_SIZE
+     print(f"Using device: {CFG.device}")
+ 
+     # Load the training split
+     print("Loading dataset from Hugging Face Hub...")
+     train_hf_dataset = load_dataset(CFG.dataset_name, split="train")
+ 
+     # Load the validation split
+     val_hf_dataset = load_dataset(CFG.dataset_name, split="validation")
+     print("Generating vocabulary from the dataset...")
+     vocab_path = "vocab.json"
+     if os.path.exists(vocab_path):
+         print(f"Loading vocabulary from {vocab_path}...")
+         with open(vocab_path, 'r', encoding='utf-8') as f:
+             characters = json.load(f)
+     else:
+         # 1. Scan the dataset, collecting every character into a set to ensure uniqueness
+         all_chars = set()
+         total_samples = len(train_hf_dataset)
+         for i in tqdm(range(total_samples), desc="Scanning labels"):
+             label = train_hf_dataset[i]['text']
+             all_chars.update(list(label))
+         total_samples = len(val_hf_dataset)
+         for i in tqdm(range(total_samples), desc="Scanning labels"):
+             label = val_hf_dataset[i]['text']
+             all_chars.update(list(label))
+         # 2. Convert the set to a sorted list so the ordering is identical on every
+         #    run -- this is essential for reproducibility!
+         characters = sorted(list(all_chars))
+         with open(vocab_path, 'w', encoding='utf-8') as f:
+             json.dump(characters, f, ensure_ascii=False, indent=2)
+         print(f"Vocabulary saved to {vocab_path}")
+     print(f"Unique characters found: {''.join(characters)}")
+     CFG.characters = "".join(characters)  # kept on CFG for easy inspection
+ 
+     # Build the character <-> ID mappings
+     char_to_id = {char: i + 1 for i, char in enumerate(CFG.characters)}
+     id_to_char = {i + 1: char for i, char in enumerate(CFG.characters)}
+     VOCAB_SIZE = len(CFG.characters) + 1  # +1 for the CTC blank token at index 0
+     print(f"Total unique characters: {VOCAB_SIZE - 1}")
+     train_dataset = CaptchaDataset(train_hf_dataset, char_to_id=char_to_id, transform=transform)
+     val_dataset = CaptchaDataset(val_hf_dataset, char_to_id=char_to_id, transform=transform)
+ 
+     train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=4)
+     val_loader = DataLoader(val_dataset, batch_size=CFG.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=4)
+ 
+     print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
+     # Initialize the model, loss function, and optimizer
+     model = CaptchaTransformer(num_classes=VOCAB_SIZE).to(CFG.device)
+     criterion = nn.CTCLoss(blank=0, zero_infinity=True)
+     optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr)
+     warmup_steps = 1000
+     def warmup_lambda(current_step):
+         if current_step < warmup_steps:
+             # During warmup the LR multiplier rises linearly from 0 to 1.0
+             return float(current_step) / float(max(1, warmup_steps))
+         # After warmup, hold the learning rate constant (multiplier 1.0)
+         return 1.0
+     scheduler = LambdaLR(optimizer, lr_lambda=warmup_lambda)
+     scaler = GradScaler()
+     # Training loop
+     for epoch in range(CFG.epochs):
+         model.train()  # switch back to training mode at the start of each epoch
+         train_loss = 0
+         loop = tqdm(train_loader, leave=True)
+ 
+         for i, (images, labels, label_lengths) in enumerate(loop):
+             images = images.to(CFG.device)
+             labels = labels.to(CFG.device)
+             label_lengths = label_lengths.to(CFG.device)
+ 
+             # 1. Zero the gradients inside the loop, with set_to_none=True
+             optimizer.zero_grad(set_to_none=True)
+ 
+             # 2. Let autocast wrap only the forward pass and loss computation
+             with autocast(device_type=CFG.device, dtype=torch.bfloat16):
+                 preds = model(images)
+                 input_lengths = torch.full(size=(preds.size(1),), fill_value=preds.size(0), dtype=torch.long)
+                 loss = criterion(preds, labels, input_lengths, label_lengths)
+ 
+             # 3. The scaler operations and optimizer step stay outside autocast
+             scaler.scale(loss).backward()
+             scaler.unscale_(optimizer)  # unscale the gradients before clipping
+             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
+             scaler.step(optimizer)
+             scaler.update()
+ 
+             scheduler.step()  # the LR must be updated every step
+ 
+             train_loss += loss.item()
+             loop.set_description(f"Epoch [{epoch+1}/{CFG.epochs}]")
+             loop.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
+ 
+         # Validation loop
+         model.eval()
+         correct_predictions = 0
+         total_predictions = 0
+         is_printed = False
+         with torch.no_grad():
+             for images, labels, label_lengths in tqdm(val_loader, desc="Validation", leave=True):
+                 with autocast(device_type=CFG.device, dtype=torch.bfloat16):
+                     images = images.to(CFG.device)
+                     preds = model(images)
+ 
+                 decoded_preds = greedy_decode(preds, id_to_char)
+                 # Convert the padded labels back to text
+                 original_texts = []
+                 for label, length in zip(labels, label_lengths):
+                     original_texts.append("".join([id_to_char[l.item()] for l in label[:length]]))
+ 
+                 for pred, target in zip(decoded_preds, original_texts):
+                     if not is_printed:
+                         print(pred, target)  # show one sample prediction per epoch
+                         is_printed = True
+                     if pred == target:
+                         correct_predictions += 1
+                     total_predictions += 1
+ 
+         accuracy = correct_predictions / total_predictions
+         print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Val Accuracy: {accuracy:.4f}")
+ 
+     # Save the model
+     torch.save(model.state_dict(), "captcha_transformer.pth")
+     print("Model saved to captcha_transformer.pth")
+ 
+ 
+ if __name__ == "__main__":
+     main()
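For completeness, a minimal inference sketch against the artifacts this script saves; it reuses the definitions above, and the input path captcha.png is only a placeholder:

# Load the vocabulary and checkpoint, then decode a single image.
from PIL import Image

with open("vocab.json", "r", encoding="utf-8") as f:
    characters = json.load(f)
id_to_char = {i + 1: c for i, c in enumerate(characters)}

model = CaptchaTransformer(num_classes=len(characters) + 1).to(CFG.device)
model.load_state_dict(torch.load("captcha_transformer.pth", map_location=CFG.device))
model.eval()

img = transform(Image.open("captcha.png").convert("L")).unsqueeze(0).to(CFG.device)
with torch.no_grad():
    print(greedy_decode(model(img), id_to_char))  # e.g. ['aB3k9']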
微調驗證碼識別模型.py ADDED
@@ -0,0 +1,248 @@
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader, Dataset
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ from torch.amp import autocast, GradScaler
+ from torchvision import transforms
+ from datasets import load_dataset
+ import numpy as np
+ import string
+ import math
+ from tqdm import tqdm
+ import os
+ import json
+ 
+ # ===================================================================
+ # Nearly all definitions are identical to the original file and are
+ # copied over directly, which guarantees the weights load correctly
+ # ===================================================================
+ 
+ # --- 1. Configuration (fine-tuning) ---
+ class CFG_FINETUNE:
+     # Paths of the model checkpoint and vocabulary to load
+     model_path = "captcha_transformer_best_finetune.pth"  # make sure this is the filename of your best saved model
+     vocab_path = "vocab.json"
+ 
+     # Dataset (unchanged)
+     dataset_name = "gary109/captcha-synth-v3"
+ 
+     # Image size (unchanged)
+     img_width = 200
+     img_height = 50
+ 
+     # Model hyperparameters (must match the original model exactly!)
+     d_model = 256
+     nhead = 8
+     num_encoder_layers = 4
+     dim_feedforward = 1024
+ 
+     # Fine-tuning hyperparameters (this is the key part!)
+     epochs = 5       # fine-tuning usually needs only a few epochs
+     batch_size = 32
+     lr = 1e-5        # use a much smaller learning rate!
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ # --- 2. Required classes and functions (copied from the original file) ---
+ # (all necessary definitions are duplicated here so this script runs standalone)
+ 
+ class CaptchaDataset(Dataset):
+     def __init__(self, hf_dataset, char_to_id, transform=None):
+         self.dataset = hf_dataset
+         self.transform = transform
+         self.char_to_id = char_to_id
+     def __len__(self):
+         return len(self.dataset)
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         image = item['image'].convert("L")
+         label = item.get('label') or item.get('text')
+         if self.transform:
+             image = self.transform(image)
+         label_encoded = [self.char_to_id[char] for char in label]
+         return image, torch.tensor(label_encoded, dtype=torch.long)
+ 
+ def collate_fn(batch):
+     images, labels = zip(*batch)
+     images = torch.stack(images, 0)
+     label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
+     padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
+     return images, padded_labels, label_lengths
+ 
+ transform = transforms.Compose([
+     # Mild geometric distortion to simulate touching and deformed characters
+     transforms.RandomAffine(
+         degrees=8,             # random rotation within ±8 degrees
+         translate=(0.1, 0.1),  # random translation up to 10%
+         scale=(0.9, 1.1),      # random scaling by ±10%
+         shear=5                # random shear
+     ),
+     transforms.RandomPerspective(distortion_scale=0.2, p=0.5),  # random perspective warp
+     # Resize
+     transforms.Resize((CFG_FINETUNE.img_height, CFG_FINETUNE.img_width)),
+     # Color jitter
+     transforms.ColorJitter(brightness=0.4, contrast=0.4),
+     # Convert to tensor
+     transforms.ToTensor(),
+     # Random erasing (key!) to simulate occluding lines or broken strokes
+     # Note: this operation must come after ToTensor
+     transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.3, 3.3), value=0),
+     # Normalize
+     transforms.Normalize((0.5,), (0.5,))
+ ])
+ val_transform = transforms.Compose([
+     transforms.Resize((CFG_FINETUNE.img_height, CFG_FINETUNE.img_width)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5,), (0.5,))
+ ])
+ class PositionalEncoding(nn.Module):
+     # (copied in full from the original file)
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+         self.register_buffer('pe', pe)
+     def forward(self, x):
+         x = x + self.pe[:x.size(0), :]
+         return self.dropout(x)
+ 
+ class CaptchaTransformer(nn.Module):
+     # (copied in full from the original file)
+     def __init__(self, num_classes):
+         super(CaptchaTransformer, self).__init__()
+         self.cnn = nn.Sequential(
+             nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
+             nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(),
+             nn.AdaptiveMaxPool2d((1, None)),
+             nn.Conv2d(256, CFG_FINETUNE.d_model, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(CFG_FINETUNE.d_model), nn.ReLU()
+         )
+         self.pos_encoder = PositionalEncoding(CFG_FINETUNE.d_model)
+         encoder_layer = nn.TransformerEncoderLayer(d_model=CFG_FINETUNE.d_model, nhead=CFG_FINETUNE.nhead, dim_feedforward=CFG_FINETUNE.dim_feedforward, dropout=0.1)
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=CFG_FINETUNE.num_encoder_layers)
+         self.output_layer = nn.Linear(CFG_FINETUNE.d_model, num_classes)
+     def forward(self, x):
+         x = self.cnn(x)
+         x = x.squeeze(2)
+         x = x.permute(2, 0, 1)
+         x = self.pos_encoder(x)
+         x = self.transformer_encoder(x)
+         output = self.output_layer(x)
+         return nn.functional.log_softmax(output, dim=2)
+ 
+ def greedy_decode(preds, id_to_char_map):
+     # (copied in full from the original file)
+     decoded_texts = []
+     preds = preds.permute(1, 0, 2)
+     pred_indices = torch.argmax(preds, dim=2)
+     for indices in pred_indices:
+         text = []
+         last_char_id = 0
+         for char_id in indices:
+             char_id = char_id.item()
+             if char_id != 0 and char_id != last_char_id:
+                 text.append(id_to_char_map[char_id])
+             last_char_id = char_id
+         decoded_texts.append("".join(text))
+     return decoded_texts
+ 
+ # --- 3. Fine-tuning main routine ---
+ def finetune():
+     print(f"Starting fine-tuning process on device: {CFG_FINETUNE.device}")
+ 
+     # --- Load the vocabulary ---
+     with open(CFG_FINETUNE.vocab_path, 'r', encoding='utf-8') as f:
+         characters = json.load(f)
+ 
+     char_to_id = {char: i + 1 for i, char in enumerate(characters)}
+     id_to_char = {i + 1: char for i, char in enumerate(characters)}
+     VOCAB_SIZE = len(characters) + 1
+     print(f"Vocabulary loaded. Size: {VOCAB_SIZE - 1}")
+ 
+     # --- Prepare the data ---
+     print("Loading dataset for fine-tuning...")
+     train_hf_dataset = load_dataset(CFG_FINETUNE.dataset_name, split="train")
+     val_hf_dataset = load_dataset(CFG_FINETUNE.dataset_name, split="validation")
+     train_dataset = CaptchaDataset(train_hf_dataset, char_to_id, transform=transform)
+     val_dataset = CaptchaDataset(val_hf_dataset, char_to_id, transform=val_transform)
+     train_loader = DataLoader(train_dataset, batch_size=CFG_FINETUNE.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=8, pin_memory=True)
+     val_loader = DataLoader(val_dataset, batch_size=CFG_FINETUNE.batch_size * 8, shuffle=False, collate_fn=collate_fn, num_workers=8, pin_memory=True)
+ 
+     # --- Key step: initialize the model and load the pre-trained weights ---
+     model = CaptchaTransformer(num_classes=VOCAB_SIZE).to(CFG_FINETUNE.device)
+     print(f"Loading pre-trained weights from: {CFG_FINETUNE.model_path}")
+     model.load_state_dict(torch.load(CFG_FINETUNE.model_path, map_location=CFG_FINETUNE.device))
+     print("Weights loaded successfully.")
+ 
+     # --- Set up a fresh optimizer and LR scheduler ---
+     optimizer = torch.optim.AdamW(model.parameters(), lr=CFG_FINETUNE.lr)
+     scaler = GradScaler()
+ 
+     # Cosine-annealing learning-rate schedule
+     total_steps = len(train_loader) * CFG_FINETUNE.epochs
+     scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-11)  # decay the LR smoothly toward zero
+ 
+     best_accuracy = 0.8979  # start from the best accuracy you already know!
+ 
+     # --- Fine-tuning loop ---
+     for epoch in range(CFG_FINETUNE.epochs):
+         model.train()
+         train_loss = 0
+         loop = tqdm(train_loader, leave=True)
+ 
+         for i, (images, labels, label_lengths) in enumerate(loop):
+             images, labels, label_lengths = images.to(CFG_FINETUNE.device), labels.to(CFG_FINETUNE.device), label_lengths.to(CFG_FINETUNE.device)
+             optimizer.zero_grad(set_to_none=True)
+ 
+             with autocast(device_type=CFG_FINETUNE.device, dtype=torch.bfloat16):
+                 preds = model(images)
+                 input_lengths = torch.full(size=(preds.size(1),), fill_value=preds.size(0), dtype=torch.long)
+                 loss = nn.CTCLoss(blank=0, zero_infinity=True)(preds, labels, input_lengths, label_lengths)
+ 
+             scaler.scale(loss).backward()
+             scaler.unscale_(optimizer)
+             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
+             scaler.step(optimizer)
+             scaler.update()
+             scheduler.step()  # update the learning rate
+ 
+             train_loss += loss.item()
+             loop.set_description(f"Fine-tune Epoch [{epoch+1}/{CFG_FINETUNE.epochs}]")
+             loop.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])
+ 
+         # Validation loop
+         model.eval()
+         correct_predictions, total_predictions = 0, 0
+         with torch.no_grad():
+             for images, labels, label_lengths in tqdm(val_loader, desc="Validation"):
+                 images = images.to(CFG_FINETUNE.device)
+                 with autocast(device_type=CFG_FINETUNE.device, dtype=torch.bfloat16):
+                     preds = model(images)
+ 
+                 decoded_preds = greedy_decode(preds, id_to_char)
+                 original_texts = ["".join([id_to_char[l.item()] for l in label[:length]]) for label, length in zip(labels, label_lengths)]
+ 
+                 for pred, target in zip(decoded_preds, original_texts):
+                     if pred == target:
+                         correct_predictions += 1
+                     total_predictions += 1
+ 
+         accuracy = correct_predictions / total_predictions
+         print(f"Epoch {epoch+1}, Train Loss: {train_loss/len(train_loader):.4f}, Val Accuracy: {accuracy:.4f}")
+ 
+         # Save improved models
+         if accuracy > best_accuracy:
+             best_accuracy = accuracy
+             # Use a new filename so the original best model isn't overwritten
+             torch.save(model.state_dict(), "captcha_transformer_best_finetune.pth")
+             print(f"🎉 New best fine-tuned model saved with accuracy: {best_accuracy:.4f}")
+ 
+ if __name__ == "__main__":
+     finetune()
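The CosineAnnealingLR schedule used above follows the closed form lr(t) = eta_min + (lr0 − eta_min) · (1 + cos(πt/T_max)) / 2. A standalone sketch of the decay, assuming total_steps = 5000:

# Trace the cosine-annealed LR at a few checkpoints.
import math
lr0, eta_min, T_max = 1e-5, 1e-11, 5000
for t in [0, 1250, 2500, 3750, 5000]:
    lr_t = eta_min + (lr0 - eta_min) * (1 + math.cos(math.pi * t / T_max)) / 2
    print(t, f"{lr_t:.2e}")  # 1.00e-05, 8.54e-06, 5.00e-06, 1.46e-06, ~1e-11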
找錯誤識別圖片.py ADDED
@@ -0,0 +1,311 @@
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader, Dataset
+ from torch.amp import autocast
+ from torchvision import transforms
+ from datasets import load_dataset
+ import os
+ import json
+ import math
+ from tqdm import tqdm
+ import re
+ from collections import Counter
+ import Levenshtein
+ from torchvision.transforms import functional as F
+ from PIL import Image
+ # ===================================================================
+ # This is a standalone script, so all necessary definitions are
+ # copied over from the earlier files
+ # ===================================================================
+ 
+ # --- 1. Configuration ---
+ class CFG_ANALYSIS:
+     # Paths of the model and vocabulary to load (use your best fine-tuned model)
+     model_path = "captcha_transformer_best_finetune.pth"
+     vocab_path = "vocab.json"
+ 
+     # Dataset (the validation split is used)
+     dataset_name = "gary109/captcha-synth-v3"
+ 
+     # Folder for the misrecognized images
+     output_dir = "error_analysis_results"
+ 
+     # Model hyperparameters (must match training exactly!)
+     d_model = 256
+     nhead = 8
+     num_encoder_layers = 4
+     dim_feedforward = 1024
+     img_width = 200
+     img_height = 50
+ 
+     # Inference settings
+     batch_size = 1024  # inference can use a much larger batch size for speed
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ # --- 2. Required classes and functions (copied from the earlier files) ---
+ 
+ class CaptchaDataset(Dataset):
+     def __init__(self, hf_dataset, char_to_id, transform=None):
+         self.dataset = hf_dataset
+         self.transform = transform
+         self.char_to_id = char_to_id
+     def __len__(self):
+         return len(self.dataset)
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         image = item['image'].convert("L")
+         label = item.get('label') or item.get('text')
+         if self.transform:
+             image = self.transform(image)
+         label_encoded = [self.char_to_id[char] for char in label]
+         return image, torch.tensor(label_encoded, dtype=torch.long)
+ 
+ def collate_fn(batch):
+     images, labels = zip(*batch)
+     images = torch.stack(images, 0)
+     label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long)
+     padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=0)
+     return images, padded_labels, label_lengths
+ 
+ class PadAndResize:
+     """
+     A custom transform that resizes an image to the target size while keeping
+     the original aspect ratio; the leftover area is padded with a fill color.
+     """
+     def __init__(self, output_size, fill=0):
+         """
+         :param output_size: a (height, width) tuple
+         :param fill: pixel value used for padding (0 = black, 255 = white)
+         """
+         self.output_size = output_size
+         self.fill = fill
+ 
+     def __call__(self, img):
+         # Target and original sizes
+         target_h, target_w = self.output_size
+         original_w, original_h = img.size
+ 
+         # Aspect ratios
+         target_ratio = target_w / target_h
+         original_ratio = original_w / original_h
+ 
+         if original_ratio > target_ratio:
+             # The image is "wider" than the target: scale based on width
+             new_w = target_w
+             new_h = int(new_w / original_ratio)
+             img = F.resize(img, (new_h, new_w))
+ 
+             # Height that still needs padding
+             pad_h = target_h - new_h
+             # Split the padding between top and bottom
+             padding = (0, pad_h // 2, 0, target_h - new_h - (pad_h // 2))
+         else:
+             # The image is "taller" than the target (or has the same ratio): scale based on height
+             new_h = target_h
+             new_w = int(new_h * original_ratio)
+             img = F.resize(img, (new_h, new_w))
+ 
+             # Width that still needs padding
+             pad_w = target_w - new_w
+             # Split the padding between left and right
+             padding = (pad_w // 2, 0, target_w - new_w - (pad_w // 2), 0)
+ 
+         # Apply the padding
+         return F.pad(img, padding, self.fill)
+ 
+ transform = transforms.Compose([
+     PadAndResize((CFG_ANALYSIS.img_height, CFG_ANALYSIS.img_width), fill=0),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5,), (0.5,))
+ ])
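A quick check of the letterboxing behaviour (a sketch; the 300×40 input size is arbitrary):

# A 300x40 image is wider than the 200x50 target, so it is scaled to 200x26
# and padded top/bottom to reach the target height.
from PIL import Image
img = Image.new("L", (300, 40), color=255)  # PIL size is (width, height)
out = PadAndResize((CFG_ANALYSIS.img_height, CFG_ANALYSIS.img_width), fill=0)(img)
print(out.size)  # (200, 50)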
+ 
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+         self.register_buffer('pe', pe)
+     def forward(self, x):
+         x = x + self.pe[:x.size(0), :]
+         return self.dropout(x)
+ 
+ class CaptchaTransformer(nn.Module):
+     def __init__(self, num_classes):
+         super(CaptchaTransformer, self).__init__()
+         self.cnn = nn.Sequential(
+             nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),
+             nn.MaxPool2d(kernel_size=2, stride=2),
+             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
+             nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1), nn.ReLU(),
+             nn.AdaptiveMaxPool2d((1, None)),
+             nn.Conv2d(256, CFG_ANALYSIS.d_model, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(CFG_ANALYSIS.d_model), nn.ReLU()
+         )
+         self.pos_encoder = PositionalEncoding(CFG_ANALYSIS.d_model)
+         encoder_layer = nn.TransformerEncoderLayer(d_model=CFG_ANALYSIS.d_model, nhead=CFG_ANALYSIS.nhead, dim_feedforward=CFG_ANALYSIS.dim_feedforward, dropout=0.1)
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=CFG_ANALYSIS.num_encoder_layers)
+         self.output_layer = nn.Linear(CFG_ANALYSIS.d_model, num_classes)
+     def forward(self, x):
+         x = self.cnn(x)
+         x = x.squeeze(2)
+         x = x.permute(2, 0, 1)
+         x = self.pos_encoder(x)
+         x = self.transformer_encoder(x)
+         output = self.output_layer(x)
+         return nn.functional.log_softmax(output, dim=2)
+ 
+ def greedy_decode(preds, id_to_char_map):
+     decoded_texts = []
+     preds = preds.permute(1, 0, 2)
+     pred_indices = torch.argmax(preds, dim=2)
+     for indices in pred_indices:
+         text = []
+         last_char_id = 0
+         for char_id in indices:
+             char_id = char_id.item()
+             if char_id != 0 and char_id != last_char_id:
+                 text.append(id_to_char_map[char_id])
+             last_char_id = char_id
+         decoded_texts.append("".join(text))
+     return decoded_texts
+ 
+ # --- 3. Error-analysis main routine ---
+ def analyze_errors():
+     print("--- Starting Quantitative Error Analysis ---")
+ 
+     # --- Setup (same as before) ---
+     CFG_ANALYSIS.output_dir = "error_analysis_v2_results"
+     os.makedirs(CFG_ANALYSIS.output_dir, exist_ok=True)
+ 
+     with open(CFG_ANALYSIS.vocab_path, 'r', encoding='utf-8') as f:
+         characters = json.load(f)
+     id_to_char = {i + 1: char for i, char in enumerate(characters)}
+     VOCAB_SIZE = len(characters) + 1
+ 
+     val_hf_dataset = load_dataset(CFG_ANALYSIS.dataset_name, split="validation")
+     char_to_id = {c: i + 1 for i, c in enumerate(characters)}
+     val_torch_dataset = CaptchaDataset(val_hf_dataset, char_to_id, transform=transform)
+     val_loader = DataLoader(val_torch_dataset, batch_size=CFG_ANALYSIS.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=8, pin_memory=True)
+ 
+     model = CaptchaTransformer(num_classes=VOCAB_SIZE).to(CFG_ANALYSIS.device)
+     model.load_state_dict(torch.load(CFG_ANALYSIS.model_path, map_location=CFG_ANALYSIS.device))
+     model.eval()
+     print("Model loaded successfully.")
+ 
+     error_counts = Counter()
+     confusion_matrix = {char: Counter() for char in characters}
+ 
+     # --- Lists to collect the results across all batches ---
+     all_preds_list = []
+     all_labels_list = []
+     all_label_lengths_list = []
+     all_indices_list = []  # new: records each sample's original dataset index
+ 
+     # ==========================================================
+     # ============[ Phase 1: GPU-Intensive Pass ]============
+     # ==========================================================
+     # This loop does nothing but model inference, so it runs very fast
+     with torch.no_grad():
+         for batch_idx, (images, labels, label_lengths) in enumerate(tqdm(val_loader, desc="Phase 1: GPU Inference")):
+             images = images.to(CFG_ANALYSIS.device)
+             with autocast(device_type=CFG_ANALYSIS.device, dtype=torch.bfloat16):
+                 preds = model(images)
+ 
+             # Move the results from GPU to CPU RAM and store them
+             all_preds_list.append(preds.cpu())
+             all_labels_list.append(labels)
+             all_label_lengths_list.append(label_lengths)
+ 
+             # Record the original index of every sample in this batch
+             start_idx = batch_idx * CFG_ANALYSIS.batch_size
+             end_idx = start_idx + len(images)
+             all_indices_list.extend(range(start_idx, end_idx))
+ 
+     # ==========================================================
+     # ============[ Phase 2: CPU-Intensive Analysis ]===========
+     # ==========================================================
+     # The GPU is done; now process everything we collected on the CPU in one pass
+     print("\nPhase 2: Analyzing results on CPU...")
+ 
+     # 1. The prediction tensors all share the same sequence length, so they can safely be concatenated
+     all_preds_tensor = torch.cat(all_preds_list, dim=1)
+     decoded_preds = greedy_decode(all_preds_tensor, id_to_char)
+ 
+     # 2. Label lengths vary, so handle the labels batch by batch
+     original_texts = []
+     for labels, label_lengths in zip(all_labels_list, all_label_lengths_list):
+         batch_texts = ["".join([id_to_char[l.item()] for l in label[:length]]) for label, length in zip(labels, label_lengths)]
+         original_texts.extend(batch_texts)
+ 
+     # Loop over the results on the CPU, classifying errors and saving images
+     for i in tqdm(range(len(decoded_preds)), desc="Phase 2: Classifying errors and saving files"):
+         pred = decoded_preds[i]
+         target = original_texts[i]
+ 
+         if pred != target:
+             # --- Error classification and statistics (same as before) ---
+             error_type = "unknown"
+             if len(pred) != len(target):
+                 error_type = "length_mismatch"
+             else:
+                 distance = Levenshtein.distance(pred, target)
+                 if distance == 1:
+                     error_type = "substitution"
+                     for j in range(len(pred)):
+                         if pred[j] != target[j]:
+                             confusion_matrix[target[j]][pred[j]] += 1
+                             break
+                 elif distance > 1:
+                     error_type = "complex_error"
+             error_counts[error_type] += 1
+ 
+             # --- Save to disk (same as before, but using the recorded indices) ---
+             error_dir = os.path.join(CFG_ANALYSIS.output_dir, error_type)
+             os.makedirs(error_dir, exist_ok=True)
+ 
+             original_idx = all_indices_list[i]
+             original_pil_image = val_hf_dataset[original_idx]['image']
+ 
+             pred_sanitized = re.sub(r'[\\/*?:"<>|]', "", pred) or "EMPTY"
+             target_sanitized = re.sub(r'[\\/*?:"<>|]', "", target)
+             filename = f"idx{original_idx}_pred_{pred_sanitized}_target_{target_sanitized}.png"
+             filepath = os.path.join(error_dir, filename)
+             original_pil_image.save(filepath)
+ 
+     # --- 3. Build and print the statistics report ---
+     total_errors = sum(error_counts.values())
+     report = "--- Error Analysis Report ---\n\n"
+     report += f"Total Errors Found: {total_errors}\n\n"
+     report += "Error Type Distribution:\n"
+     for error_type, count in error_counts.most_common():
+         percentage = (count / total_errors) * 100
+         report += f"- {error_type:<20}: {count:>5} errors ({percentage:.2f}%)\n"
+ 
+     # Find the 15 most common substitution errors
+     substitution_pairs = []
+     for target_char, pred_counter in confusion_matrix.items():
+         for pred_char, count in pred_counter.items():
+             if count > 0:
+                 substitution_pairs.append(((target_char, pred_char), count))
+ 
+     # Sort by count
+     top_15_substitutions = sorted(substitution_pairs, key=lambda item: item[1], reverse=True)[:15]
+ 
+     report += "\n\nTop 15 Character Substitution Errors (Target -> Prediction):\n"
+     for (target, pred), count in top_15_substitutions:
+         report += f"- '{target}' -> '{pred}': {count:>5} times\n"
+ 
+     print("\n" + report)
+ 
+     # Write the report to a file
+     with open(os.path.join(CFG_ANALYSIS.output_dir, "report.txt"), "w", encoding="utf-8") as f:
+         f.write(report)
+ 
+     print(f"\nAnalysis complete. Report saved to '{os.path.join(CFG_ANALYSIS.output_dir, 'report.txt')}'")
+     print("You can now review the categorized images in the results folder.")
+ 
+ if __name__ == "__main__":
+     analyze_errors()
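Beyond exact-match accuracy, the same Levenshtein dependency can also report a character error rate. A small sketch, assuming lists shaped like the decoded_preds and original_texts built in analyze_errors():

# Character error rate: total edit distance / total target characters.
import Levenshtein

def character_error_rate(preds, targets):
    total_dist = sum(Levenshtein.distance(p, t) for p, t in zip(preds, targets))
    total_chars = sum(len(t) for t in targets)
    return total_dist / max(1, total_chars)

print(character_error_rate(["aab3", "x7yz"], ["aab8", "x7yz"]))  # 0.125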