import torch

from contextlib import contextmanager


high_vram = False
gpu = torch.device('cuda')
cpu = torch.device('cpu')

# Touch the GPU once at import time so the CUDA context is initialized, then drop the cache.
torch.zeros((1, 1)).to(gpu, torch.float32)
torch.cuda.empty_cache()

# Models currently resident on the GPU, managed by load_models_to_gpu().
models_in_gpu = []
# Temporarily remove the `quantization_method` attribute (set on bitsandbytes-quantized
# models) so that .to(device) calls are not blocked, and restore it on exit.
@contextmanager
def movable_bnb_model(m):
    if hasattr(m, 'quantization_method'):
        m.quantization_method_backup = m.quantization_method
        del m.quantization_method
    try:
        yield None
    finally:
        if hasattr(m, 'quantization_method_backup'):
            m.quantization_method = m.quantization_method_backup
            del m.quantization_method_backup
    return
def load_models_to_gpu(models):
    global models_in_gpu

    if not isinstance(models, (tuple, list)):
        models = [models]

    # Split the request into models already on the GPU, models that still need to be
    # moved there, and currently-loaded models that were not requested this time.
    models_to_remain = [m for m in set(models) if m in models_in_gpu]
    models_to_load = [m for m in set(models) if m not in models_in_gpu]
    models_to_unload = [m for m in set(models_in_gpu) if m not in models_to_remain]

    if not high_vram:
        # In low-VRAM mode, evict everything that was not requested back to the CPU.
        for m in models_to_unload:
            with movable_bnb_model(m):
                m.to(cpu)
            print('Unload to CPU:', m.__class__.__name__)
        models_in_gpu = models_to_remain

    for m in models_to_load:
        with movable_bnb_model(m):
            m.to(gpu)
        print('Load to GPU:', m.__class__.__name__)

    models_in_gpu = list(set(models_in_gpu + models))
    torch.cuda.empty_cache()
    return
def unload_all_models(extra_models=None):
    global models_in_gpu

    if extra_models is None:
        extra_models = []

    if not isinstance(extra_models, (tuple, list)):
        extra_models = [extra_models]

    models_in_gpu = list(set(models_in_gpu + extra_models))

    return load_models_to_gpu([])
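

# Usage sketch (illustrative addition, not part of the original module). The three
# modules below are placeholder torch.nn.Linear layers standing in for real models
# such as a text encoder, UNet, and VAE; any torch.nn.Module works. With
# high_vram = False, each call to load_models_to_gpu() moves the requested models
# to CUDA and evicts everything else that was tracked back to the CPU.
if __name__ == '__main__':
    text_encoder = torch.nn.Linear(8, 8)  # placeholder model
    unet = torch.nn.Linear(8, 8)          # placeholder model
    vae = torch.nn.Linear(8, 8)           # placeholder model

    load_models_to_gpu([text_encoder, unet])  # both on GPU, tracked in models_in_gpu
    load_models_to_gpu(vae)                   # vae on GPU; text_encoder and unet back on CPU
    unload_all_models()                       # every tracked model moved back to CPU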