AIvry commited on
Commit
22ed625
·
verified ·
1 Parent(s): e82dbc3

Upload models.py

Browse files
Files changed (1) hide show
  1. models.py +359 -332
models.py CHANGED
@@ -1,333 +1,360 @@
1
- import queue
2
- import threading
3
- import gc
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from transformers import (
8
- HubertModel,
9
- Wav2Vec2FeatureExtractor,
10
- Wav2Vec2Model,
11
- WavLMModel,
12
- ASTModel,
13
- AutoFeatureExtractor,
14
- )
15
-
16
- from config import BATCH_SIZE, ENERGY_HOP_MS, ENERGY_WIN_MS, SR
17
- from utils import get_gpu_count
18
-
19
-
20
- class BalancedDualGPUModel:
21
-
22
- def __init__(self, model_name, layer, max_gpus=None):
23
- self.layer = layer
24
- self.models = []
25
- self.extractors = []
26
- self.devices = []
27
- ngpu = get_gpu_count(max_gpus)
28
-
29
- for gpu_id in range(min(ngpu, 2)):
30
- device = f"cuda:{gpu_id}"
31
- self.devices.append(device)
32
- ckpt, cls, _ = get_model_config(layer)[model_name]
33
- if cls is ASTModel:
34
- extractor = AutoFeatureExtractor.from_pretrained(ckpt)
35
- else:
36
- extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
37
-
38
- attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
39
- model = cls.from_pretrained(
40
- ckpt,
41
- output_hidden_states=True,
42
- use_safetensors=True,
43
- torch_dtype=torch.float16,
44
- low_cpu_mem_usage=True,
45
- attn_implementation=attn_impl
46
- )
47
- model.eval()
48
- model = model.to(device)
49
-
50
- for param in model.parameters():
51
- param.requires_grad = False
52
-
53
- self.extractors.append(extractor)
54
- self.models.append(model)
55
-
56
- self.gpu_queues = [queue.Queue() for _ in range(len(self.devices))]
57
- self.result_queue = queue.Queue()
58
- self.workers = []
59
- for i in range(len(self.devices)):
60
- worker = threading.Thread(target=self._gpu_worker, args=(i,))
61
- worker.daemon = True
62
- worker.start()
63
- self.workers.append(worker)
64
-
65
- def _gpu_worker(self, gpu_id):
66
- device = self.devices[gpu_id]
67
- model = self.models[gpu_id]
68
- extractor = self.extractors[gpu_id]
69
- while True:
70
- task = self.gpu_queues[gpu_id].get()
71
- if task is None:
72
- break
73
- signals, masks, use_mlm, task_id = task
74
- try:
75
- inputs = extractor(
76
- signals, sampling_rate=SR, return_tensors="pt", padding=True
77
- )
78
- input_values = inputs.input_values.to(device, non_blocking=True)
79
-
80
- torch.cuda.empty_cache()
81
-
82
- orig_mode = model.training
83
- model.train() if use_mlm else model.eval()
84
- with torch.no_grad():
85
- with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
86
- hs = model(
87
- input_values, output_hidden_states=True
88
- ).hidden_states[self.layer]
89
- model.train(orig_mode)
90
-
91
- B, T, D = hs.shape
92
- keep = []
93
- for b in range(B):
94
- mask_b = masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
95
- mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
96
- keep.append(hs[b][mask_t].cpu())
97
-
98
- # Aggressive cleanup
99
- del hs, input_values, inputs
100
- torch.cuda.empty_cache()
101
-
102
- if keep:
103
- L_max = max(x.shape[0] for x in keep)
104
- keep_padded = [
105
- F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in keep
106
- ]
107
- result = torch.stack(keep_padded, dim=0)
108
- else:
109
- result = torch.empty(0, 0, 0)
110
- self.result_queue.put((task_id, result))
111
- except Exception as e:
112
- self.result_queue.put((task_id, e))
113
- finally:
114
- # Always clear cache after processing
115
- torch.cuda.empty_cache()
116
-
117
- def process_batch(self, signals, masks, use_mlm=False):
118
- if not signals:
119
- return torch.empty(0, 0, 0)
120
- batch_size = len(signals)
121
- split = (batch_size + len(self.devices) - 1) // len(self.devices)
122
- results = {}
123
- task_id = 0
124
- for i in range(0, batch_size, split):
125
- end = min(i + split, batch_size)
126
- gpu_id = (i // split) % len(self.devices)
127
- self.gpu_queues[gpu_id].put(
128
- (signals[i:end], masks[i:end], use_mlm, task_id)
129
- )
130
- task_id += 1
131
- for _ in range(task_id):
132
- tid, result = self.result_queue.get()
133
- if isinstance(result, Exception):
134
- raise result
135
- results[tid] = result
136
- parts = [results[i] for i in range(task_id) if results[i].numel() > 0]
137
- return torch.cat(parts, dim=0) if parts else torch.empty(0, 0, 0)
138
-
139
- def cleanup(self):
140
- """Explicit cleanup method"""
141
- for q in self.gpu_queues:
142
- q.put(None)
143
- for w in self.workers:
144
- w.join(timeout=5.0)
145
- for model in self.models:
146
- del model
147
- for extractor in self.extractors:
148
- del extractor
149
- self.models.clear()
150
- self.extractors.clear()
151
- torch.cuda.empty_cache()
152
- gc.collect()
153
-
154
- def __del__(self):
155
- self.cleanup()
156
-
157
-
158
- # NO CACHE - we need to clean up models properly between runs
159
- def get_model_config(layer):
160
- return {
161
- "raw": (None, None, None),
162
- "wavlm": ("microsoft/wavlm-large", WavLMModel, layer),
163
- "wav2vec2": ("facebook/wav2vec2-large-lv60", Wav2Vec2Model, layer),
164
- "hubert": ("facebook/hubert-large-ll60k", HubertModel, layer),
165
- "wavlm_base": ("microsoft/wavlm-base", WavLMModel, layer),
166
- "wav2vec2_base": ("facebook/wav2vec2-base", Wav2Vec2Model, layer),
167
- "hubert_base": ("facebook/hubert-base-ls960", HubertModel, layer),
168
- "wav2vec2_xlsr": ("facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model, layer),
169
- "ast": ("MIT/ast-finetuned-audioset-10-10-0.4593", ASTModel, layer),
170
- }
171
-
172
-
173
- # Store loaded models globally to properly manage them
174
- _loaded_models = {}
175
-
176
-
177
- def load_model(name, layer, max_gpus=None):
178
- global _loaded_models
179
-
180
- # Clean up any previously loaded models first
181
- if _loaded_models:
182
- for key, model_data in _loaded_models.items():
183
- if isinstance(model_data, tuple) and len(model_data) == 2:
184
- if isinstance(model_data[0], BalancedDualGPUModel):
185
- model_data[0].cleanup()
186
- elif isinstance(model_data[0], tuple):
187
- # Single GPU model
188
- _, model = model_data[0]
189
- del model
190
- _loaded_models.clear()
191
- torch.cuda.empty_cache()
192
- gc.collect()
193
-
194
- if name.lower() in {"raw", "waveform"}:
195
- return "raw", layer
196
-
197
- ngpu = get_gpu_count(max_gpus)
198
- if ngpu > 1:
199
- model = BalancedDualGPUModel(name, layer, max_gpus)
200
- _loaded_models[name] = (model, layer)
201
- return model, layer
202
- else:
203
- ckpt, cls, layer_eff = get_model_config(layer)[name]
204
- if cls is ASTModel:
205
- extractor = AutoFeatureExtractor.from_pretrained(ckpt)
206
- else:
207
- extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)
208
-
209
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
210
- attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
211
- model = cls.from_pretrained(
212
- ckpt,
213
- output_hidden_states=True,
214
- use_safetensors=True,
215
- torch_dtype=torch.float16,
216
- low_cpu_mem_usage=True,
217
- attn_implementation=attn_impl
218
- )
219
- model.eval()
220
- model = model.to(device)
221
-
222
- for param in model.parameters():
223
- param.requires_grad = False
224
-
225
- model_tuple = ((extractor, model), layer_eff)
226
- _loaded_models[name] = model_tuple
227
- return (extractor, model), layer_eff
228
-
229
-
230
- def cleanup_all_models():
231
- """Call this at the end of each experiment to ensure complete cleanup"""
232
- global _loaded_models
233
- if _loaded_models:
234
- for key, model_data in _loaded_models.items():
235
- if isinstance(model_data, tuple) and len(model_data) == 2:
236
- if isinstance(model_data[0], BalancedDualGPUModel):
237
- model_data[0].cleanup()
238
- elif isinstance(model_data[0], tuple):
239
- # Single GPU model
240
- _, model = model_data[0]
241
- del model
242
- _loaded_models.clear()
243
- torch.cuda.empty_cache()
244
- gc.collect()
245
-
246
-
247
- def embed_batch_raw(signals, masks_audio):
248
- win = int(ENERGY_WIN_MS * SR / 1000)
249
- hop = int(ENERGY_HOP_MS * SR / 1000)
250
- reps, L_max = [], 0
251
- for sig_np, mask_np in zip(signals, masks_audio):
252
- x = torch.as_tensor(sig_np[:-1], dtype=torch.float32)
253
- frames = x.unfold(0, win, hop)
254
- mask = torch.as_tensor(mask_np[: len(frames)], dtype=torch.bool)
255
- keep = frames[mask] if mask.any() else frames[:1]
256
- reps.append(keep)
257
- L_max = max(L_max, keep.size(0))
258
- reps = [F.pad(r, (0, 0, 0, L_max - r.size(0))) for r in reps]
259
- return torch.stack(reps, dim=0)
260
-
261
-
262
- def embed_batch_single_gpu(
263
- signals, masks_audio, extractor, model, layer, use_mlm=False
264
- ):
265
- if not signals:
266
- return torch.empty(0, 0, 0)
267
- device = next(model.parameters()).device
268
-
269
- max_batch = 2
270
- all_keeps = []
271
-
272
- for i in range(0, len(signals), max_batch):
273
- batch_signals = signals[i:i + max_batch]
274
- batch_masks = masks_audio[i:i + max_batch]
275
-
276
- inputs = extractor(batch_signals, sampling_rate=SR, return_tensors="pt", padding=True)
277
- input_values = inputs.input_values.to(device, non_blocking=True)
278
-
279
- orig_mode = model.training
280
- model.train() if use_mlm else model.eval()
281
-
282
- with torch.no_grad():
283
- with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
284
- hs = model(input_values, output_hidden_states=True).hidden_states[layer]
285
- model.train(orig_mode)
286
-
287
- B, T, D = hs.shape
288
- for b in range(B):
289
- mask_b = batch_masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
290
- mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
291
- all_keeps.append(hs[b][mask_t].cpu())
292
-
293
- # Aggressive cleanup
294
- del hs, input_values, inputs
295
- torch.cuda.empty_cache()
296
-
297
- if all_keeps:
298
- L_max = max(x.shape[0] for x in all_keeps)
299
- keep_padded = [F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in all_keeps]
300
- result = torch.stack(keep_padded, dim=0)
301
- # Clean up intermediate lists
302
- del all_keeps, keep_padded
303
- return result
304
- else:
305
- return torch.empty(0, 0, 0)
306
-
307
-
308
- def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
309
- if model_wrapper == "raw":
310
- return embed_batch_raw(signals, masks_audio)
311
- if isinstance(model_wrapper, BalancedDualGPUModel):
312
- all_embeddings = []
313
- batch_size = min(BATCH_SIZE, 2)
314
- for i in range(0, len(signals), batch_size):
315
- batch_emb = model_wrapper.process_batch(
316
- signals[i: i + batch_size], masks_audio[i: i + batch_size], use_mlm
317
- )
318
- if batch_emb.numel() > 0:
319
- all_embeddings.append(batch_emb)
320
- # Clear cache after each batch
321
- torch.cuda.empty_cache()
322
-
323
- if all_embeddings:
324
- result = torch.cat(all_embeddings, dim=0)
325
- del all_embeddings
326
- return result
327
- else:
328
- return torch.empty(0, 0, 0)
329
- else:
330
- extractor, model = model_wrapper
331
- return embed_batch_single_gpu(
332
- signals, masks_audio, extractor, model, layer, use_mlm=use_mlm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  )
 
1
+ import queue
2
+ import threading
3
+ import gc
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from transformers import (
8
+ HubertModel,
9
+ Wav2Vec2FeatureExtractor,
10
+ Wav2Vec2Model,
11
+ WavLMModel,
12
+ ASTModel,
13
+ AutoFeatureExtractor,
14
+ )
15
+
16
+ from config import BATCH_SIZE, ENERGY_HOP_MS, ENERGY_WIN_MS, SR
17
+ from utils import get_gpu_count
18
+
19
+
20
class BalancedDualGPUModel:
    """Replicate a HuggingFace audio encoder on up to two GPUs and balance
    embedding work across them.

    One model copy (plus its feature extractor) is loaded per device, and a
    daemon worker thread per device consumes tasks from that device's queue.
    Results come back through a shared result queue keyed by task id, so
    ``process_batch`` can reassemble them in submission order.
    """

    def __init__(self, model_name, layer, max_gpus=None):
        self.layer = layer
        self.models = []
        self.extractors = []
        self.devices = []
        # Initialise worker bookkeeping up front so cleanup()/__del__ never
        # hit an AttributeError when __init__ aborts early (e.g. the
        # RuntimeError below, or a failed checkpoint download).
        self.gpu_queues = []
        self.workers = []
        self.result_queue = queue.Queue()

        ngpu = get_gpu_count(max_gpus)
        # This class should only be used when GPUs are available.
        if ngpu == 0:
            raise RuntimeError("BalancedDualGPUModel requires at least 1 GPU")

        for gpu_id in range(min(ngpu, 2)):
            device = f"cuda:{gpu_id}"
            self.devices.append(device)
            ckpt, cls, _ = get_model_config(layer)[model_name]
            # AST ships its own (spectrogram) feature extractor; the wav2vec2
            # extractor covers the remaining waveform models.
            if cls is ASTModel:
                extractor = AutoFeatureExtractor.from_pretrained(ckpt)
            else:
                extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)

            # WavLM/AST are run with eager attention; the others use SDPA.
            attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"
            # Load in float32 for compatibility; inference speed is recovered
            # via autocast(float16) inside the worker loop.
            model = cls.from_pretrained(
                ckpt,
                output_hidden_states=True,
                use_safetensors=True,
                torch_dtype=torch.float32,
                low_cpu_mem_usage=True,
                attn_implementation=attn_impl,
            )
            model.eval()
            model = model.to(device)

            # Inference only: freeze all parameters.
            for param in model.parameters():
                param.requires_grad = False

            self.extractors.append(extractor)
            self.models.append(model)

        self.gpu_queues = [queue.Queue() for _ in range(len(self.devices))]
        for i in range(len(self.devices)):
            worker = threading.Thread(target=self._gpu_worker, args=(i,))
            worker.daemon = True
            worker.start()
            self.workers.append(worker)

    def _gpu_worker(self, gpu_id):
        """Worker loop for one device.

        Pulls ``(signals, masks, use_mlm, task_id)`` tasks off this device's
        queue until a ``None`` sentinel arrives. Each result (or the raised
        exception) is put on the shared result queue under its task id.
        """
        device = self.devices[gpu_id]
        model = self.models[gpu_id]
        extractor = self.extractors[gpu_id]
        while True:
            task = self.gpu_queues[gpu_id].get()
            if task is None:  # sentinel from cleanup()
                break
            signals, masks, use_mlm, task_id = task
            try:
                inputs = extractor(
                    signals, sampling_rate=SR, return_tensors="pt", padding=True
                )
                input_values = inputs.input_values.to(device, non_blocking=True)

                torch.cuda.empty_cache()

                # Optionally run the encoder in train mode (use_mlm) and
                # restore the previous mode afterwards.
                orig_mode = model.training
                model.train() if use_mlm else model.eval()
                with torch.no_grad():
                    # Autocast to float16 only on an actual CUDA device.
                    if torch.cuda.is_available() and device.startswith('cuda'):
                        with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                            hs = model(
                                input_values, output_hidden_states=True
                            ).hidden_states[self.layer]
                    else:
                        hs = model(
                            input_values, output_hidden_states=True
                        ).hidden_states[self.layer]
                model.train(orig_mode)

                B, T, D = hs.shape
                keep = []
                for b in range(B):
                    # Resample the audio-domain mask to the model frame rate
                    # and keep only the selected frames.
                    # NOTE(review): assumes masks[b] is a 1-D torch tensor.
                    mask_b = masks[b].float().unsqueeze(0).unsqueeze(0).to(device)
                    mask_t = F.interpolate(mask_b, size=T, mode="nearest")[0, 0].bool()
                    keep.append(hs[b][mask_t].cpu())

                # Aggressive cleanup of the large intermediates.
                del hs, input_values, inputs
                torch.cuda.empty_cache()

                if keep:
                    # Pad every item to the longest kept length so they stack.
                    L_max = max(x.shape[0] for x in keep)
                    keep_padded = [
                        F.pad(x, (0, 0, 0, L_max - x.shape[0])) for x in keep
                    ]
                    result = torch.stack(keep_padded, dim=0)
                else:
                    result = torch.empty(0, 0, 0)
                self.result_queue.put((task_id, result))
            except Exception as e:
                # Ship the exception to the consumer; process_batch re-raises.
                self.result_queue.put((task_id, e))
            finally:
                # Always clear cache after processing.
                torch.cuda.empty_cache()

    def process_batch(self, signals, masks, use_mlm=False):
        """Split a batch across the devices and gather the embeddings.

        Returns a (B, L_max, D) tensor of stacked per-item embeddings, or an
        empty (0, 0, 0) tensor when there is nothing to embed. Re-raises any
        exception a worker hit while processing its shard.
        """
        if not signals:
            return torch.empty(0, 0, 0)
        batch_size = len(signals)
        # Ceil-divide so every device gets at most one contiguous shard.
        split = (batch_size + len(self.devices) - 1) // len(self.devices)
        results = {}
        task_id = 0
        for i in range(0, batch_size, split):
            end = min(i + split, batch_size)
            gpu_id = (i // split) % len(self.devices)
            self.gpu_queues[gpu_id].put(
                (signals[i:end], masks[i:end], use_mlm, task_id)
            )
            task_id += 1
        for _ in range(task_id):
            tid, result = self.result_queue.get()
            if isinstance(result, Exception):
                raise result
            results[tid] = result
        parts = [results[i] for i in range(task_id) if results[i].numel() > 0]
        return torch.cat(parts, dim=0) if parts else torch.empty(0, 0, 0)

    def cleanup(self):
        """Stop the workers and release the models.

        Safe to call more than once and safe on a partially constructed
        instance (e.g. when __init__ raised before the workers existed).
        """
        for q in getattr(self, "gpu_queues", []):
            q.put(None)  # sentinel: tells the worker loop to exit
        for w in getattr(self, "workers", []):
            w.join(timeout=5.0)
        # Clearing the lists drops the only references held here; the
        # per-item `del` loops in the previous version only unbound the
        # loop variable and freed nothing.
        if hasattr(self, "models"):
            self.models.clear()
        if hasattr(self, "extractors"):
            self.extractors.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def __del__(self):
        # Best-effort teardown; exceptions must never escape a finalizer.
        try:
            self.cleanup()
        except Exception:
            pass
167
+
168
+
169
# Rebuilt on every call (no caching) so models can be torn down cleanly
# between runs.
def get_model_config(layer):
    """Return a mapping from model alias to (checkpoint, model class, layer)."""
    checkpoints = {
        "wavlm": ("microsoft/wavlm-large", WavLMModel),
        "wav2vec2": ("facebook/wav2vec2-large-lv60", Wav2Vec2Model),
        "hubert": ("facebook/hubert-large-ll60k", HubertModel),
        "wavlm_base": ("microsoft/wavlm-base", WavLMModel),
        "wav2vec2_base": ("facebook/wav2vec2-base", Wav2Vec2Model),
        "hubert_base": ("facebook/hubert-base-ls960", HubertModel),
        "wav2vec2_xlsr": ("facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model),
        "ast": ("MIT/ast-finetuned-audioset-10-10-0.4593", ASTModel),
    }
    # "raw" is the waveform pipeline: no checkpoint, no model class.
    config = {"raw": (None, None, None)}
    for alias, (ckpt, cls) in checkpoints.items():
        config[alias] = (ckpt, cls, layer)
    return config
182
+
183
+
184
# Registry of models loaded by load_model(), keyed by model name. Kept at
# module level so the previous run's model can be torn down (see load_model
# and cleanup_all_models) before the next one is loaded.
_loaded_models = {}
186
+
187
+
188
def load_model(name, layer, max_gpus=None):
    """Load an embedding model, releasing any previously loaded one first.

    Args:
        name: model alias (a key of get_model_config), or "raw"/"waveform"
            for the raw-waveform pipeline.
        layer: hidden-state layer index to extract.
        max_gpus: optional cap on the number of GPUs to use.

    Returns:
        ("raw", layer) for the raw pipeline;
        (BalancedDualGPUModel, layer) when more than one GPU is available;
        ((extractor, model), layer_eff) for the single-device path.
    """
    global _loaded_models

    # Release any previously loaded models so repeated experiments do not
    # accumulate GPU memory. Delegates to cleanup_all_models() instead of
    # duplicating its teardown logic (the two inline copies had already
    # drifted apart between revisions).
    cleanup_all_models()

    if name.lower() in {"raw", "waveform"}:
        return "raw", layer

    ngpu = get_gpu_count(max_gpus)

    # Only use BalancedDualGPUModel when we actually have multiple GPUs.
    if ngpu > 1:
        model = BalancedDualGPUModel(name, layer, max_gpus)
        _loaded_models[name] = (model, layer)
        return model, layer

    ckpt, cls, layer_eff = get_model_config(layer)[name]
    # AST ships its own (spectrogram) feature extractor.
    if cls is ASTModel:
        extractor = AutoFeatureExtractor.from_pretrained(ckpt)
    else:
        extractor = Wav2Vec2FeatureExtractor.from_pretrained(ckpt)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # WavLM/AST are run with eager attention; the others use SDPA.
    attn_impl = "eager" if cls in (WavLMModel, ASTModel) else "sdpa"

    # float32 keeps CPU inference working; GPU speed is recovered via
    # autocast(float16) at inference time.
    model = cls.from_pretrained(
        ckpt,
        output_hidden_states=True,
        use_safetensors=True,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl,
    )
    model.eval()
    model = model.to(device)

    # Inference only: freeze all parameters.
    for param in model.parameters():
        param.requires_grad = False

    _loaded_models[name] = ((extractor, model), layer_eff)
    return (extractor, model), layer_eff
244
+
245
+
246
def cleanup_all_models():
    """Tear down every registered model; call at the end of each experiment."""
    global _loaded_models
    if not _loaded_models:
        return
    for entry in _loaded_models.values():
        if not (isinstance(entry, tuple) and len(entry) == 2):
            continue
        wrapped = entry[0]
        if isinstance(wrapped, BalancedDualGPUModel):
            # Multi-GPU wrapper owns worker threads; use its cleanup hook.
            wrapped.cleanup()
        elif isinstance(wrapped, tuple):
            # Single-device entry: (extractor, model).
            _, mdl = wrapped
            del mdl
    _loaded_models.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
262
+
263
+
264
def embed_batch_raw(signals, masks_audio):
    """Frame raw waveforms into energy windows and keep the masked frames.

    Args:
        signals: sequence of 1-D waveform arrays.
        masks_audio: per-signal boolean frame masks.

    Returns:
        (B, L_max, win) tensor of kept frames, zero-padded per item to the
        longest kept length; (0, 0, 0) when signals is empty.
    """
    if not signals:
        # Match the other embedders, which return an empty (0, 0, 0) tensor
        # instead of letting torch.stack crash on an empty list.
        return torch.empty(0, 0, 0)
    win = int(ENERGY_WIN_MS * SR / 1000)
    hop = int(ENERGY_HOP_MS * SR / 1000)
    reps, L_max = [], 0
    for sig_np, mask_np in zip(signals, masks_audio):
        # NOTE(review): sig_np[:-1] drops the final sample — presumably to
        # align the frame count with the precomputed mask; confirm upstream.
        x = torch.as_tensor(sig_np[:-1], dtype=torch.float32)
        frames = x.unfold(0, win, hop)
        mask = torch.as_tensor(mask_np[: len(frames)], dtype=torch.bool)
        # Keep at least one frame so padding/stacking stays well-formed.
        keep = frames[mask] if mask.any() else frames[:1]
        reps.append(keep)
        L_max = max(L_max, keep.size(0))
    reps = [F.pad(r, (0, 0, 0, L_max - r.size(0))) for r in reps]
    return torch.stack(reps, dim=0)
277
+
278
+
279
def embed_batch_single_gpu(
    signals, masks_audio, extractor, model, layer, use_mlm=False
):
    """Embed a batch on a single device (GPU or CPU) in micro-batches of 2.

    Extracts hidden states at `layer`, keeps only the frames selected by each
    item's audio mask, and returns them padded and stacked as (B, L_max, D);
    (0, 0, 0) when there is nothing to embed.
    """
    if not signals:
        return torch.empty(0, 0, 0)
    device = next(model.parameters()).device
    is_cuda = device.type == 'cuda'

    chunk = 2  # tiny micro-batches keep peak memory low
    kept = []

    for start in range(0, len(signals), chunk):
        sigs = signals[start:start + chunk]
        msks = masks_audio[start:start + chunk]

        feats = extractor(sigs, sampling_rate=SR, return_tensors="pt", padding=True)
        values = feats.input_values.to(device, non_blocking=is_cuda)

        # Optionally enable train mode (use_mlm); restored after the forward.
        prev_training = model.training
        model.train(use_mlm)

        with torch.no_grad():
            if is_cuda:
                # float16 autocast is GPU-only; CPU runs the model directly.
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    hidden = model(values, output_hidden_states=True).hidden_states[layer]
            else:
                hidden = model(values, output_hidden_states=True).hidden_states[layer]

        model.train(prev_training)

        n_items, n_frames, _ = hidden.shape
        for idx in range(n_items):
            # Resample the audio-domain mask to the model frame rate and keep
            # only the selected frames, moved to CPU immediately.
            m = msks[idx].float().unsqueeze(0).unsqueeze(0).to(device)
            frame_mask = F.interpolate(m, size=n_frames, mode="nearest")[0, 0].bool()
            kept.append(hidden[idx][frame_mask].cpu())

        # Drop the large intermediates before the next micro-batch.
        del hidden, values, feats
        if is_cuda:
            torch.cuda.empty_cache()

    if not kept:
        return torch.empty(0, 0, 0)
    longest = max(t.shape[0] for t in kept)
    padded = [F.pad(t, (0, 0, 0, longest - t.shape[0])) for t in kept]
    return torch.stack(padded, dim=0)
332
+
333
+
334
def embed_batch(signals, masks_audio, model_wrapper, layer, use_mlm=False):
    """Dispatch embedding to the raw, single-device, or dual-GPU path."""
    if model_wrapper == "raw":
        return embed_batch_raw(signals, masks_audio)

    if not isinstance(model_wrapper, BalancedDualGPUModel):
        # Single-device wrapper: (extractor, model).
        extractor, model = model_wrapper
        return embed_batch_single_gpu(
            signals, masks_audio, extractor, model, layer, use_mlm=use_mlm
        )

    # Dual-GPU path: feed the wrapper in small slices and collect the
    # non-empty results.
    chunks = []
    step = min(BATCH_SIZE, 2)
    for start in range(0, len(signals), step):
        emb = model_wrapper.process_batch(
            signals[start: start + step], masks_audio[start: start + step], use_mlm
        )
        if emb.numel() > 0:
            chunks.append(emb)
        # Clear cache between slices to keep peak memory down.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not chunks:
        return torch.empty(0, 0, 0)
    return torch.cat(chunks, dim=0)