Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
·
2f12ebe
1
Parent(s):
b6996c9
load clip encoder from path if possible
Browse files- clip_encoder.py +7 -1
clip_encoder.py
CHANGED
|
@@ -2,7 +2,7 @@ import torch
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import open_clip
|
| 4 |
from einops import rearrange
|
| 5 |
-
|
| 6 |
|
| 7 |
def exists(val):
|
| 8 |
return val is not None
|
|
@@ -11,6 +11,12 @@ class CLIPEncoder(nn.Module):
|
|
| 11 |
|
| 12 |
def __init__(self, model, pretrained):
|
| 13 |
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
self.model = model
|
| 15 |
self.pretrained = pretrained
|
| 16 |
self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)
|
|
|
|
| 2 |
import torch.nn as nn
|
| 3 |
import open_clip
|
| 4 |
from einops import rearrange
|
| 5 |
+
import os
|
| 6 |
|
| 7 |
def exists(val):
|
| 8 |
return val is not None
|
|
|
|
| 11 |
|
| 12 |
def __init__(self, model, pretrained):
|
| 13 |
super().__init__()
|
| 14 |
+
#ViT_H_14_laion2b_s32b_b79k
|
| 15 |
+
fname = "models/" + model.replace("-", "_") + "_" + pretrained + ".pt"
|
| 16 |
+
|
| 17 |
+
if os.path.exists(fname):
|
| 18 |
+
print(fname)
|
| 19 |
+
pretrained = fname
|
| 20 |
self.model = model
|
| 21 |
self.pretrained = pretrained
|
| 22 |
self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)
|