Spaces:
Runtime error
Runtime error
| """Modified from https://github.com/kijai/ComfyUI-MochiWrapper | |
| """ | |
| import torch | |
| import torch.nn as nn | |
def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs):
    """Run ``cls.original_forward`` with the module temporarily cast to ``origin_dtype``.

    The module is moved to ``origin_dtype`` for the duration of the call and
    cast back to the dtype its ``weight`` had on entry afterwards, even when
    the forward pass raises.

    Args:
        cls: Module-like object exposing ``weight``, ``to`` and
            ``original_forward``.
        origin_dtype: dtype the module and its positional inputs are cast to
            while the forward pass runs.
        *inputs: Positional arguments for ``original_forward``. Each argument
            that has a ``.to`` method is cast to ``origin_dtype``; anything
            else (e.g. ``None``) is passed through unchanged.
        **kwargs: Keyword arguments forwarded to ``original_forward`` untouched.

    Returns:
        Whatever ``cls.original_forward`` returns.
    """
    # Remember the storage dtype so it can be restored afterwards.
    weight_dtype = cls.weight.dtype
    try:
        cls.to(origin_dtype)
        cast_inputs = [
            arg.to(origin_dtype) if hasattr(arg, "to") else arg
            for arg in inputs
        ]
        return cls.original_forward(*cast_inputs, **kwargs)
    finally:
        # Always go back to the storage dtype, otherwise a failed forward
        # would leave the module in the wider dtype.
        cls.to(weight_dtype)
def replace_parameters_by_name(module, name_keywords, device):
    """Swap matching parameters for plain tensors on ``device``, recursively.

    Every parameter registered directly on ``module`` whose name contains at
    least one entry of ``name_keywords`` is removed and re-attached under the
    same attribute name as its underlying ``.data`` tensor moved to
    ``device``. The same treatment is then applied to every child module.

    Args:
        module: Root ``nn.Module`` to process in place.
        name_keywords: Iterable of substrings matched against parameter names.
        device: Target device for the replacement tensors.
    """
    # Snapshot first: attributes are deleted and re-added while looping.
    own_params = list(module.named_parameters(recurse=False))
    for param_name, param in own_params:
        matched = False
        for keyword in name_keywords:
            if keyword in param_name:
                matched = True
                break
        if not matched or not isinstance(param, nn.Parameter):
            continue
        raw = param.data
        delattr(module, param_name)
        setattr(module, param_name, raw.to(device=device))

    for _, child in module.named_children():
        replace_parameters_by_name(child, name_keywords, device)
| def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens']): | |
| for name, module in model.named_modules(): | |
| flag = False | |
| for _exclude_module_name in exclude_module_name: | |
| if _exclude_module_name in name: | |
| flag = True | |
| if flag: | |
| continue | |
| for param_name, param in module.named_parameters(): | |
| flag = False | |
| for _exclude_module_name in exclude_module_name: | |
| if _exclude_module_name in param_name: | |
| flag = True | |
| if flag: | |
| continue | |
| param.data = param.data.to(torch.float8_e4m3fn) | |
def convert_weight_dtype_wrapper(module, origin_dtype):
    """Patch sub-modules so their forward pass runs in ``origin_dtype``.

    Every sub-module of ``module`` (the root itself and anything whose name
    contains ``"embed_tokens"`` are skipped) that has a non-``None``
    ``weight`` attribute gets:

    * ``original_forward``: its unpatched ``forward``;
    * ``forward``: a wrapper that delegates to ``autocast_model_forward``.

    The function is safe to call more than once on the same model: an already
    stored ``original_forward`` is kept, so the wrapper never ends up wrapping
    itself (which would recurse forever on the next forward call).

    Args:
        module: Root ``nn.Module`` whose sub-modules are patched in place.
        origin_dtype: dtype handed to ``autocast_model_forward`` on each call.
    """
    def _make_forward(target):
        # A factory binds ``target`` per sub-module without adding a keyword
        # to the wrapper's signature that could collide with forward kwargs.
        def _forward(*inputs, **kwargs):
            return autocast_model_forward(target, origin_dtype, *inputs, **kwargs)
        return _forward

    for name, submodule in module.named_modules():
        if name == "" or "embed_tokens" in name:
            continue
        if getattr(submodule, "weight", None) is None:
            continue
        if not hasattr(submodule, "original_forward"):
            submodule.original_forward = submodule.forward
        submodule.forward = _make_forward(submodule)