Julian Bilcke
we are going to hack into finetrainers
9fd1204
raw
history blame
1.44 kB
from contextlib import contextmanager
from typing import List, Union
import torch
from diffusers.hooks import HookRegistry, ModelHook
_CONTROL_CHANNEL_CONCATENATE_HOOK = "FINETRAINERS_CONTROL_CHANNEL_CONCATENATE_HOOK"
class ControlChannelConcatenateHook(ModelHook):
def __init__(self, input_names: List[str], inputs: List[torch.Tensor], dims: List[int]):
self.input_names = input_names
self.inputs = inputs
self.dims = dims
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
for input_name, input_tensor, dim in zip(self.input_names, self.inputs, self.dims):
original_tensor = args[input_name] if isinstance(input_name, int) else kwargs[input_name]
control_tensor = torch.cat([original_tensor, input_tensor], dim=dim)
if isinstance(input_name, int):
args[input_name] = control_tensor
else:
kwargs[input_name] = control_tensor
return args, kwargs
@contextmanager
def control_channel_concat(
module: torch.nn.Module, input_names: List[Union[int, str]], inputs: List[torch.Tensor], dims: List[int]
):
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = ControlChannelConcatenateHook(input_names, inputs, dims)
registry.register_hook(hook, _CONTROL_CHANNEL_CONCATENATE_HOOK)
yield
registry.remove_hook(_CONTROL_CHANNEL_CONCATENATE_HOOK, recurse=False)