Spaces:
Running
on
Zero
Running
on
Zero
| import importlib | |
| import json | |
| from . import utils | |
class EnhancedCompileModel:
    """Replace one object inside a model patcher with a compiled version of it."""

    def patch(
        self,
        model,
        is_patcher,
        object_to_patch,
        compiler,
        fullgraph,
        dynamic,
        mode,
        options,
        disable,
        backend,
    ):
        """Compile ``object_to_patch`` and register the result as an object patch.

        Args:
            model: When ``is_patcher`` is truthy, an indexable whose first
                element is the patcher to clone. Otherwise an object whose
                ``patcher`` attribute holds the patcher.
            is_patcher: Selects which of the two ``model`` layouts applies.
            object_to_patch: Key handed to the patcher's ``get_model_object``
                and ``add_object_patch``.
            compiler: Dotted ``"module.function"`` path of the compile
                callable, e.g. ``"torch.compile"``.
            fullgraph: Forwarded unchanged to the compile callable.
            dynamic: Forwarded unchanged to the compile callable.
            mode: Forwarded to the compile callable; any falsy value is
                passed as ``None``.
            options: JSON text decoded with ``json.loads`` before being
                forwarded; any falsy value is passed as ``None``.
            disable: Forwarded unchanged to the compile callable.
            backend: Forwarded unchanged to the compile callable.

        Returns:
            A one-element tuple: the cloned patcher when ``is_patcher`` is
            truthy, otherwise ``model`` itself, whose ``patcher`` attribute
            has been reassigned to the cloned patcher.
        """
        utils.patch_optimized_module()
        utils.patch_same_meta()

        # Resolve the compile callable from its dotted path at call time.
        module_name, attr_name = compiler.rsplit(".", 1)
        compile_fn = getattr(importlib.import_module(module_name), attr_name)

        # Falsy values (empty string etc.) mean "not specified".
        mode = mode or None
        if options:
            options = json.loads(options)
        else:
            options = None

        if compiler == "torch.compile" and backend == "inductor" and dynamic:
            # TODO: Fix this
            # File "pytorch/torch/_inductor/fx_passes/post_grad.py", line 643, in same_meta
            # and statically_known_true(sym_eq(val1.size(), val2.size()))
            # AttributeError: 'SymInt' object has no attribute 'size'
            pass

        # Always work on a clone so the incoming patcher object is not patched.
        source_patcher = model[0] if is_patcher else model.patcher
        patcher = source_patcher.clone()

        compiled_object = compile_fn(
            patcher.get_model_object(object_to_patch),
            fullgraph=fullgraph,
            dynamic=dynamic,
            mode=mode,
            options=options,
            disable=disable,
            backend=backend,
        )
        patcher.add_object_patch(object_to_patch, compiled_object)

        if is_patcher:
            return (patcher,)

        # Non-patcher input: the passed-in model is updated in place.
        model.patcher = patcher
        return (model,)