hsshin98 committed
Commit · dfe1f0b
Parent(s): ed81860
cpu
app.py
CHANGED
@@ -4,7 +4,6 @@ import argparse
 import glob
 import multiprocessing as mp
 import os
-os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
 
 # fmt: off
 import sys
@@ -40,6 +39,7 @@ def setup_cfg(args):
     add_cat_seg_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
+    cfg.MODEL.DEVICE = "cpu"
     cfg.freeze()
     return cfg
 
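The `setup_cfg` change is the whole CPU switch on the detectron2 side: the framework builds and runs the model on whatever device `cfg.MODEL.DEVICE` names, so pinning it to `"cpu"` keeps the demo off CUDA. A minimal sketch of that pattern follows; the helper name `setup_cpu_cfg` is illustrative, and the `add_cat_seg_config` import path is an assumption about the repo layout.

```python
# Sketch only: assumes detectron2 and CAT-Seg's config helper are importable
# (the commit also drops the runtime `pip install` of detectron2 from app.py).
from detectron2.config import get_cfg
from cat_seg import add_cat_seg_config  # assumed import path

def setup_cpu_cfg(config_file: str, opts=None):
    cfg = get_cfg()
    add_cat_seg_config(cfg)           # register CAT-Seg's extra config keys
    cfg.merge_from_file(config_file)
    cfg.merge_from_list(opts or [])
    cfg.MODEL.DEVICE = "cpu"          # detectron2 places the model on this device
    cfg.freeze()
    return cfg
```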
cat_seg/modeling/transformer/cat_seg_predictor.py
CHANGED
@@ -58,7 +58,7 @@ class CATSegPredictor(nn.Module):
         if self.test_class_texts == None:
             self.test_class_texts = self.class_texts
         device = "cuda" if torch.cuda.is_available() else "cpu"
-
+        self.device = device
         self.tokenizer = None
         if clip_pretrained == "ViT-G" or clip_pretrained == "ViT-H":
             # for OpenCLIP models
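This first hunk records the chosen device on the predictor, so methods called later (such as `class_embeddings` in the hunks below) can move tensors with `.to(self.device)` instead of hard-coding `.cuda()`. A minimal sketch of that bookkeeping pattern, with illustrative names not taken from the repo:

```python
import torch
import torch.nn as nn

class DeviceAwareModule(nn.Module):
    """Toy module showing the `self.device` pattern the diff introduces."""
    def __init__(self):
        super().__init__()
        # Pick the device once at construction time and remember it.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def make_weights(self, n: int) -> torch.Tensor:
        # Runs on CPU-only machines, unlike a hard-coded .cuda() call.
        return torch.zeros(n).to(self.device)
```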
@@ -84,12 +84,12 @@ class CATSegPredictor(nn.Module):
             prompt_templates = ['A photo of a {} in the scene',]
         else:
             raise NotImplementedError
-
-        self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
-        self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
 
         self.clip_model = clip_model.float()
         self.clip_preprocess = clip_preprocess
+
+        self.text_features = self.class_embeddings(self.class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
+        self.text_features_test = self.class_embeddings(self.test_class_texts, prompt_templates, clip_model).permute(1, 0, 2).float()
 
         transformer = Aggregator(
             text_guidance_dim=text_guidance_dim,
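This hunk moves the prompt-embedding calls below `self.clip_model = clip_model.float()`. One plausible reading, an assumption rather than anything stated in the commit, is that `nn.Module.float()` casts parameters in place, so text features computed afterwards are encoded in fp32, which plain CPU kernels handle; the snippet below only demonstrates that in-place cast.

```python
import torch
import torch.nn as nn

m = nn.Linear(4, 4).half()   # stand-in for a CLIP model holding fp16 weights
m2 = m.float()               # nn.Module.float() casts the parameters in place ...
assert m2 is m               # ... and returns the very same module
assert m.weight.dtype == torch.float32
```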
@@ -161,9 +161,9 @@ class CATSegPredictor(nn.Module):
             else:
                 texts = [template.format(classname) for template in templates] # format with class
             if self.tokenizer is not None:
-                texts = self.tokenizer(texts).cuda()
+                texts = self.tokenizer(texts).to(self.device)
             else:
-                texts = clip.tokenize(texts).cuda()
+                texts = clip.tokenize(texts).to(self.device)
             class_embeddings = clip_model.encode_text(texts)
             class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
             if len(templates) != class_embeddings.shape[0]:
@@ -171,5 +171,5 @@ class CATSegPredictor(nn.Module):
                 class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
             class_embedding = class_embeddings
             zeroshot_weights.append(class_embedding)
-        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
+        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(self.device)
         return zeroshot_weights
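The last two hunks replace the hard-coded `.cuda()` calls inside `class_embeddings` with `.to(self.device)`, so prompt encoding also works on a CPU-only Space. Below is a condensed, device-agnostic sketch of that routine; it is simplified relative to the repo's method, which additionally handles multi-name classes and per-template averaging.

```python
import torch
import clip  # OpenAI CLIP; the ViT-G/H branch uses an open_clip tokenizer instead

@torch.no_grad()
def class_embeddings(classnames, templates, clip_model, device):
    zeroshot_weights = []
    for classname in classnames:
        texts = [t.format(classname) for t in templates]
        tokens = clip.tokenize(texts).to(device)    # was .cuda() before this commit
        emb = clip_model.encode_text(tokens)
        emb = emb / emb.norm(dim=-1, keepdim=True)  # unit-normalize each prompt embedding
        zeroshot_weights.append(emb)
    # Shape (num_templates, num_classes, dim), kept on the chosen device.
    return torch.stack(zeroshot_weights, dim=1).to(device)
```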