Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,6 +19,7 @@ import json
|
|
| 19 |
cache_dir = '/data/KB'
|
| 20 |
os.makedirs(cache_dir, exist_ok=True)
|
| 21 |
|
|
|
|
| 22 |
def weighted_mean_pooling(hidden, attention_mask):
|
| 23 |
attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
|
| 24 |
s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
|
|
@@ -26,6 +27,7 @@ def weighted_mean_pooling(hidden, attention_mask):
|
|
| 26 |
reps = s / d
|
| 27 |
return reps
|
| 28 |
|
|
|
|
| 29 |
@torch.no_grad()
|
| 30 |
def encode(text_or_image_list):
|
| 31 |
global model, tokenizer
|
|
@@ -106,7 +108,7 @@ def add_pdf_gradio(pdf_file_binary, progress=gr.Progress()):
|
|
| 106 |
|
| 107 |
return knowledge_base_name
|
| 108 |
|
| 109 |
-
|
| 110 |
def retrieve_gradio(knowledge_base: str, query: str, topk: int):
|
| 111 |
global model, tokenizer
|
| 112 |
|
|
|
|
| 19 |
cache_dir = '/data/KB'
|
| 20 |
os.makedirs(cache_dir, exist_ok=True)
|
| 21 |
|
| 22 |
+
@spaces.GPU
|
| 23 |
def weighted_mean_pooling(hidden, attention_mask):
|
| 24 |
attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
|
| 25 |
s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
|
|
|
|
| 27 |
reps = s / d
|
| 28 |
return reps
|
| 29 |
|
| 30 |
+
@spaces.GPU
|
| 31 |
@torch.no_grad()
|
| 32 |
def encode(text_or_image_list):
|
| 33 |
global model, tokenizer
|
|
|
|
| 108 |
|
| 109 |
return knowledge_base_name
|
| 110 |
|
| 111 |
+
@spaces.GPU
|
| 112 |
def retrieve_gradio(knowledge_base: str, query: str, topk: int):
|
| 113 |
global model, tokenizer
|
| 114 |
|