Spaces:
Runtime error
Runtime error
Fixed typing bugs in styleclip projection
Browse files
styleclip/styleclip_global.py
CHANGED
|
@@ -120,6 +120,7 @@ def get_direction(neutral_class, target_class, beta, di, clip_model=None):
|
|
| 120 |
|
| 121 |
dt = class_weights[:, 1] - class_weights[:, 0]
|
| 122 |
dt = dt / dt.norm()
|
|
|
|
| 123 |
relevance = di @ dt
|
| 124 |
mask = relevance.abs() > beta
|
| 125 |
direction = relevance * mask
|
|
@@ -151,7 +152,7 @@ def style_dict_to_style_tensor(style_dict, reference_generator):
|
|
| 151 |
return style_tensor
|
| 152 |
|
| 153 |
def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
|
| 154 |
-
edit_direction = get_direction(source_class, target_class, beta)
|
| 155 |
|
| 156 |
source_s = style_dict_to_style_tensor(source_latent, reference_generator)
|
| 157 |
|
|
|
|
| 120 |
|
| 121 |
dt = class_weights[:, 1] - class_weights[:, 0]
|
| 122 |
dt = dt / dt.norm()
|
| 123 |
+
dt = dt.type(type(di))
|
| 124 |
relevance = di @ dt
|
| 125 |
mask = relevance.abs() > beta
|
| 126 |
direction = relevance * mask
|
|
|
|
| 152 |
return style_tensor
|
| 153 |
|
| 154 |
def project_code_with_styleclip(source_latent, source_class, target_class, alpha, beta, reference_generator, di, clip_model=None):
|
| 155 |
+
edit_direction = get_direction(source_class, target_class, beta, di, clip_model)
|
| 156 |
|
| 157 |
source_s = style_dict_to_style_tensor(source_latent, reference_generator)
|
| 158 |
|