Spaces: Running on Zero
| import torch | |
| import torch.nn.functional as F | |
def project_face_embs(pipeline, face_embs):
    """Project face identity embeddings into the text encoder's prompt space.

    Tokenizes the fixed prompt "photo of a id person" (one copy per face),
    replaces the input embedding of the "id" placeholder token with the
    zero-padded face embedding, and runs the text encoder on the edited
    token embeddings.

    Args:
        pipeline: Object exposing ``tokenizer``, ``text_encoder`` and
            ``device``. ``text_encoder`` must accept ``return_token_embs=True``
            (returning per-token input embeddings of width
            ``text_encoder.config.hidden_size``) and ``input_token_embs=``
            (consuming them). This is presumably a custom CLIP text-model
            wrapper; confirm against the caller.
        face_embs: ``(N, D)`` tensor of face embeddings with
            ``D <= hidden_size``. Originally documented as ``(N, 512)``
            normalized ArcFace embeddings.

    Returns:
        Element ``[0]`` of the text encoder's output for the edited prompts,
        one batch row per face embedding. Presumably ``(N, L, hidden_size)``;
        verify against the encoder.
    """
    tokenizer = pipeline.tokenizer
    text_encoder = pipeline.text_encoder

    # Id of the placeholder token whose embedding gets replaced by the face.
    arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]

    input_ids = tokenizer(
        "photo of a id person",
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.to(pipeline.device)
    # One copy of the prompt per face. The mask below must be built from these
    # batched ids: a (1, L) mask cannot index an (N, L, H) tensor when N > 1.
    batch_ids = input_ids.repeat(face_embs.shape[0], 1)

    token_embs = text_encoder(input_ids=batch_ids, return_token_embs=True)

    # Zero-pad the face embedding up to the encoder width (derived from the
    # input instead of assuming 512), then match the token embeddings'
    # dtype/device, since masked index assignment requires equal dtypes.
    pad_width = text_encoder.config.hidden_size - face_embs.shape[-1]
    face_embs_padded = F.pad(face_embs, (0, pad_width), "constant", 0).to(
        device=token_embs.device, dtype=token_embs.dtype
    )
    token_embs[batch_ids == arcface_token_id] = face_embs_padded

    prompt_embeds = text_encoder(
        input_ids=batch_ids,
        input_token_embs=token_embs,
    )[0]
    return prompt_embeds