Spaces: Running on Zero
Update optimization_utils.py

optimization_utils.py  CHANGED  (+3 additions, -28 deletions)
@@ -98,35 +98,10 @@ def capture_component_call(
         captured_call.kwargs = e.kwargs


-# def drain_module_parameters(module: torch.nn.Module):
-#     state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
-#     state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
-#     module.load_state_dict(state_dict, assign=True)
-#     for name, param in state_dict.items():
-#         meta = state_dict_meta[name]
-#         param.data = torch.Tensor([]).to(**meta)
-
 def drain_module_parameters(module: torch.nn.Module):
-    state_dict_meta = {
-        name: {'device': tensor.device, 'dtype': tensor.dtype}
-        for name, tensor in module.state_dict().items()
-    }
-
-    state_dict = {}
-    for name, tensor in module.state_dict().items():
-        try:
-            param = torch.nn.Parameter(torch.empty_like(tensor, device='cpu'))
-        except NotImplementedError:
-            # Fallback: dequantize (or convert) if empty_like isn't implemented
-            param = torch.nn.Parameter(tensor.dequantize().to('cpu') if hasattr(tensor, 'dequantize') else tensor.to('cpu'))
-        state_dict[name] = param
-
+    state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
+    state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
     module.load_state_dict(state_dict, assign=True)
-
     for name, param in state_dict.items():
         meta = state_dict_meta[name]
-        try:
-            param.data = torch.Tensor([]).to(**meta)
-        except NotImplementedError:
-            # Fallback for quantized tensors
-            param.data = (param.dequantize().to(**meta) if hasattr(param, 'dequantize') else torch.Tensor([]).to(**meta))
+        param.data = torch.Tensor([]).to(**meta)
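For readers skimming the diff: drain_module_parameters records the device and dtype of every entry in the module's state dict, swaps in freshly allocated CPU parameters via load_state_dict(..., assign=True), and then points each parameter's .data at a zero-element tensor carrying the recorded device/dtype, so the module keeps its registered parameters and their metadata while the actual weight storage is released. Below is a minimal sketch of how it might be exercised, under my own assumptions (PyTorch 2.1+ for assign=True, a throwaway nn.Linear as the module, and this file being importable); it is an illustration, not part of the commit.

import torch

from optimization_utils import drain_module_parameters  # assumed import path for this file

# Hypothetical example module (not from the commit): ~1.05M parameters.
linear = torch.nn.Linear(1024, 1024)
print(sum(p.numel() for p in linear.parameters()))  # 1049600 elements before draining

drain_module_parameters(linear)

for name, p in linear.named_parameters():
    # Each parameter is now a 0-element tensor that still reports the
    # dtype and device recorded in state_dict_meta.
    print(name, tuple(p.shape), p.dtype, p.device, p.numel())

Because only .data is replaced, the Parameter objects assigned by load_state_dict(..., assign=True) stay registered on the module, so later code can refill them in place.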