import gc

import PIL.Image
import torch
from controlnet_aux import LineartDetector


class Preprocessor:
    """Lazily load and run one controlnet_aux annotator at a time.

    Only a single detector is held by this object; switching to a different
    detector drops the reference to the previous one before loading the next.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None  # currently loaded detector, or None before load()
        self.name = ""  # name passed to the last successful load(), "" if none

    def load(self, name: str) -> None:
        """Load the detector registered under *name*; no-op if already loaded.

        Args:
            name: Detector name. Only ``"Lineart"`` is currently supported.

        Raises:
            ValueError: If *name* is not a supported detector. The currently
                loaded detector (if any) is left untouched in that case.
        """
        if name == self.name:
            return
        # Validate before touching state so a bad name keeps the current model.
        if name != "Lineart":
            raise ValueError(f"Unsupported preprocessor: {name!r}")

        # Drop our reference to the old model *before* loading the new one so
        # its memory can be reclaimed first instead of both being held at peak.
        # gc.collect() runs before empty_cache() so that tensors freed by the
        # collection are actually returned by the CUDA caching allocator.
        self.model = None
        self.name = ""
        gc.collect()
        torch.cuda.empty_cache()

        self.model = LineartDetector.from_pretrained(self.MODEL_ID)
        # Only record the name once loading succeeded, keeping state consistent
        # if from_pretrained raises.
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the loaded detector on *image*, forwarding extra keyword args.

        Raises:
            RuntimeError: If no detector has been loaded via load().
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first.")
        return self.model(image, **kwargs)