Spaces:
Runtime error
Runtime error
load model
Browse files
app.py
CHANGED
|
@@ -13,12 +13,6 @@ from transformers import pipeline
|
|
| 13 |
|
| 14 |
from sentence_transformers import SentenceTransformer, util
|
| 15 |
|
| 16 |
-
classifier_model = pipeline(
|
| 17 |
-
"zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v1"
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 21 |
-
|
| 22 |
# get db info from env vars
|
| 23 |
db_host = os.environ.get("DB_HOST")
|
| 24 |
db_user = os.environ.get("DB_USER")
|
|
@@ -77,6 +71,9 @@ potential_labels = get_potential_labels()
|
|
| 77 |
# Function to handle the classification
|
| 78 |
def classify_email_and_generate_response(representative_email, constituent_email):
|
| 79 |
potential_labels = get_potential_labels()
|
|
|
|
|
|
|
|
|
|
| 80 |
print("classifying email")
|
| 81 |
model_out = classifier_model(constituent_email, potential_labels, multi_label=True)
|
| 82 |
print("classification complete")
|
|
@@ -148,6 +145,7 @@ def get_similar_messages(constituent_email):
|
|
| 148 |
)
|
| 149 |
|
| 150 |
messages_for_category = db_cursor.fetchall()
|
|
|
|
| 151 |
|
| 152 |
all_message_chains = []
|
| 153 |
|
|
|
|
| 13 |
|
| 14 |
from sentence_transformers import SentenceTransformer, util
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
# get db info from env vars
|
| 17 |
db_host = os.environ.get("DB_HOST")
|
| 18 |
db_user = os.environ.get("DB_USER")
|
|
|
|
| 71 |
# Function to handle the classification
|
| 72 |
def classify_email_and_generate_response(representative_email, constituent_email):
|
| 73 |
potential_labels = get_potential_labels()
|
| 74 |
+
classifier_model = pipeline(
|
| 75 |
+
"zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v1"
|
| 76 |
+
)
|
| 77 |
print("classifying email")
|
| 78 |
model_out = classifier_model(constituent_email, potential_labels, multi_label=True)
|
| 79 |
print("classification complete")
|
|
|
|
| 145 |
)
|
| 146 |
|
| 147 |
messages_for_category = db_cursor.fetchall()
|
| 148 |
+
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 149 |
|
| 150 |
all_message_chains = []
|
| 151 |
|