Update app.py (get weights using hf_hub_download)
Browse files
app.py
CHANGED
|
@@ -18,6 +18,8 @@ import cv2
|
|
| 18 |
from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
|
| 19 |
from zim.utils import show_mat_anns
|
| 20 |
|
|
|
|
|
|
|
| 21 |
def get_shortest_axis(image):
|
| 22 |
h, w, _ = image.shape
|
| 23 |
return h if h < w else w
|
|
@@ -213,12 +215,17 @@ def get_examples():
|
|
| 213 |
images = os.listdir(assets_dir)
|
| 214 |
return [os.path.join(assets_dir, img) for img in images]
|
| 215 |
|
| 216 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
-
backbone = "vit_l"
|
| 219 |
-
ckpt_p = "ckpts/zim_vit_l_2092"
|
| 220 |
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
| 222 |
if torch.cuda.is_available():
|
| 223 |
model.cuda()
|
| 224 |
|
|
|
|
| 18 |
from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
|
| 19 |
from zim.utils import show_mat_anns
|
| 20 |
|
| 21 |
+
from huggingface_hub import hf_hub_download
|
| 22 |
+
|
| 23 |
def get_shortest_axis(image):
|
| 24 |
h, w, _ = image.shape
|
| 25 |
return h if h < w else w
|
|
|
|
| 215 |
images = os.listdir(assets_dir)
|
| 216 |
return [os.path.join(assets_dir, img) for img in images]
|
| 217 |
|
| 218 |
+
def download_onnx_weights(repo_id="naver-iv/zim-anything-vitl", file_dir="zim_vit_l_2092"):
|
| 219 |
+
hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/encoder.onnx")
|
| 220 |
+
filepath = hf_hub_download(repo_id=repo_id, filename=f"{file_dir}/decoder.onnx")
|
| 221 |
+
|
| 222 |
+
return os.path.dirname(filepath)
|
| 223 |
|
|
|
|
|
|
|
| 224 |
|
| 225 |
+
if __name__ == "__main__":
|
| 226 |
+
backbone = "vit_l"
|
| 227 |
+
model = zim_model_registry[backbone](checkpoint=download_onnx_weights())
|
| 228 |
+
|
| 229 |
if torch.cuda.is_available():
|
| 230 |
model.cuda()
|
| 231 |
|