Spaces:
Sleeping
Sleeping
Update ai_chatbot.py
Browse files- ai_chatbot.py +12 -119
ai_chatbot.py
CHANGED
|
@@ -1,13 +1,10 @@
|
|
| 1 |
from sentence_transformers import SentenceTransformer
|
| 2 |
import numpy as np
|
| 3 |
from typing import List, Dict, Tuple
|
| 4 |
-
import mysql.connector
|
| 5 |
-
from mysql.connector import Error
|
| 6 |
import re
|
| 7 |
|
| 8 |
class AIChatbot:
|
| 9 |
-
def __init__(self
|
| 10 |
-
self.db_config = db_config
|
| 11 |
# Load the pre-trained model (can use a smaller model for more speed)
|
| 12 |
self.model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 13 |
# Warm up the model to avoid first-request slowness
|
|
@@ -16,47 +13,15 @@ class AIChatbot:
|
|
| 16 |
self.faqs = None
|
| 17 |
self.load_faqs()
|
| 18 |
|
| 19 |
-
def get_db_connection(self):
|
| 20 |
-
try:
|
| 21 |
-
connection = mysql.connector.connect(**self.db_config)
|
| 22 |
-
return connection
|
| 23 |
-
except Error as e:
|
| 24 |
-
print(f"Error connecting to database: {e}")
|
| 25 |
-
return None
|
| 26 |
-
|
| 27 |
def load_faqs(self):
|
| 28 |
-
"""
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
try:
|
| 32 |
-
cursor = connection.cursor(dictionary=True)
|
| 33 |
-
cursor.execute("SELECT id, question, answer FROM faqs WHERE is_active = 1 ORDER BY sort_order, id")
|
| 34 |
-
self.faqs = cursor.fetchall()
|
| 35 |
-
cursor.close()
|
| 36 |
-
|
| 37 |
-
if self.faqs:
|
| 38 |
-
# Compute and normalize embeddings for all questions
|
| 39 |
-
questions = [faq['question'] for faq in self.faqs]
|
| 40 |
-
embeddings = self.model.encode(questions, normalize_embeddings=True)
|
| 41 |
-
self.faq_embeddings = np.array(embeddings)
|
| 42 |
-
except Error as e:
|
| 43 |
-
print(f"Error loading FAQs: {e}")
|
| 44 |
-
finally:
|
| 45 |
-
connection.close()
|
| 46 |
|
| 47 |
def save_unanswered_question(self, question):
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
if connection:
|
| 52 |
-
cursor = connection.cursor()
|
| 53 |
-
query = "INSERT INTO unanswered_questions (question) VALUES (%s)"
|
| 54 |
-
cursor.execute(query, (question,))
|
| 55 |
-
connection.commit()
|
| 56 |
-
cursor.close()
|
| 57 |
-
connection.close()
|
| 58 |
-
except Error as e:
|
| 59 |
-
print(f"Error saving unanswered question: {e}")
|
| 60 |
|
| 61 |
def _tokenize(self, text: str):
|
| 62 |
if not text:
|
|
@@ -88,51 +53,7 @@ class AIChatbot:
|
|
| 88 |
|
| 89 |
def find_best_match(self, question: str, threshold: float = 0.7) -> Tuple[str, float]:
|
| 90 |
print(f"find_best_match called with: {question}") # Debug print
|
| 91 |
-
|
| 92 |
-
# First try to match with FAQs
|
| 93 |
-
if self.faqs and self.faq_embeddings is not None:
|
| 94 |
-
# Compute and normalize embedding for the input question
|
| 95 |
-
question_embedding = self.model.encode([question], normalize_embeddings=True)[0]
|
| 96 |
-
similarities = np.dot(self.faq_embeddings, question_embedding)
|
| 97 |
-
|
| 98 |
-
# Compute keyword overlap with each FAQ question
|
| 99 |
-
q_tokens = self._tokenize(question)
|
| 100 |
-
overlap_scores = []
|
| 101 |
-
for faq in self.faqs:
|
| 102 |
-
overlap_scores.append(self._overlap_ratio(q_tokens, self._tokenize(faq['question'])))
|
| 103 |
-
|
| 104 |
-
similarities = np.array(similarities)
|
| 105 |
-
overlap_scores = np.array(overlap_scores)
|
| 106 |
-
|
| 107 |
-
# Combined score to reduce false positives
|
| 108 |
-
combined = 0.7 * similarities + 0.3 * overlap_scores
|
| 109 |
-
|
| 110 |
-
# Apply WH-word intent consistency penalty
|
| 111 |
-
q_wh = self._wh_class(question)
|
| 112 |
-
if q_wh:
|
| 113 |
-
for i, faq in enumerate(self.faqs):
|
| 114 |
-
f_wh = self._wh_class(faq['question'])
|
| 115 |
-
if f_wh and f_wh != q_wh:
|
| 116 |
-
combined[i] *= 0.6 # penalize mismatched intent significantly
|
| 117 |
-
best_idx = int(np.argmax(combined))
|
| 118 |
-
best_semantic = float(similarities[best_idx])
|
| 119 |
-
best_overlap = float(overlap_scores[best_idx])
|
| 120 |
-
best_combined = float(combined[best_idx])
|
| 121 |
-
best_wh = self._wh_class(self.faqs[best_idx]['question'])
|
| 122 |
-
|
| 123 |
-
# Acceptance criteria: require good semantic OR strong combined with overlap
|
| 124 |
-
accept = (
|
| 125 |
-
best_semantic >= max(0.7, threshold)
|
| 126 |
-
or (best_combined >= threshold and best_overlap >= 0.3)
|
| 127 |
-
)
|
| 128 |
-
# Enforce WH intent match when present
|
| 129 |
-
if accept and q_wh and best_wh and q_wh != best_wh:
|
| 130 |
-
accept = False
|
| 131 |
-
|
| 132 |
-
if accept:
|
| 133 |
-
return self.faqs[best_idx]['answer'], best_combined
|
| 134 |
-
|
| 135 |
-
# If no FAQ match, provide general conversation response
|
| 136 |
return self._generate_general_response(question)
|
| 137 |
|
| 138 |
def _generate_general_response(self, question: str) -> Tuple[str, float]:
|
|
@@ -171,37 +92,9 @@ class AIChatbot:
|
|
| 171 |
return "I understand you're asking about something, but I'm specifically designed to help with PSAU-related questions like admissions, courses, and university information. Could you rephrase your question to be more specific about what you'd like to know about Pangasinan State University? I'm here to help with academic guidance and university-related inquiries!", 0.6
|
| 172 |
|
| 173 |
def get_suggested_questions(self, question: str, num_suggestions: int = 3) -> List[str]:
|
| 174 |
-
"""
|
| 175 |
-
|
| 176 |
-
return []
|
| 177 |
-
|
| 178 |
-
# Compute and normalize embedding for the input question
|
| 179 |
-
question_embedding = self.model.encode([question], normalize_embeddings=True)[0]
|
| 180 |
-
|
| 181 |
-
# Calculate cosine similarity
|
| 182 |
-
similarities = np.dot(self.faq_embeddings, question_embedding)
|
| 183 |
-
|
| 184 |
-
# Get top N similar questions
|
| 185 |
-
top_indices = np.argsort(similarities)[-num_suggestions:][::-1]
|
| 186 |
-
return [self.faqs[idx]['question'] for idx in top_indices if similarities[idx] > 0.3]
|
| 187 |
|
| 188 |
def add_faq(self, question: str, answer: str) -> bool:
|
| 189 |
-
"""
|
| 190 |
-
|
| 191 |
-
if connection:
|
| 192 |
-
try:
|
| 193 |
-
cursor = connection.cursor()
|
| 194 |
-
query = "INSERT INTO faqs (question, answer) VALUES (%s, %s)"
|
| 195 |
-
cursor.execute(query, (question, answer))
|
| 196 |
-
connection.commit()
|
| 197 |
-
cursor.close()
|
| 198 |
-
|
| 199 |
-
# Reload FAQs to update embeddings
|
| 200 |
-
self.load_faqs()
|
| 201 |
-
return True
|
| 202 |
-
except Error as e:
|
| 203 |
-
print(f"Error adding FAQ: {e}")
|
| 204 |
-
return False
|
| 205 |
-
finally:
|
| 206 |
-
connection.close()
|
| 207 |
-
return False
|
|
|
|
| 1 |
from sentence_transformers import SentenceTransformer
|
| 2 |
import numpy as np
|
| 3 |
from typing import List, Dict, Tuple
|
|
|
|
|
|
|
| 4 |
import re
|
| 5 |
|
| 6 |
class AIChatbot:
|
| 7 |
+
def __init__(self):
|
|
|
|
| 8 |
# Load the pre-trained model (can use a smaller model for more speed)
|
| 9 |
self.model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 10 |
# Warm up the model to avoid first-request slowness
|
|
|
|
| 13 |
self.faqs = None
|
| 14 |
self.load_faqs()
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
def load_faqs(self):
|
| 17 |
+
"""Disable FAQs entirely; operate as a general-conversation bot."""
|
| 18 |
+
self.faqs = []
|
| 19 |
+
self.faq_embeddings = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
def save_unanswered_question(self, question):
|
| 22 |
+
"""Log unanswered questions to console (can be extended to save to file)"""
|
| 23 |
+
print(f"Unanswered question logged: {question}")
|
| 24 |
+
# In a real implementation, you could save this to a file or send to an admin
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
def _tokenize(self, text: str):
|
| 27 |
if not text:
|
|
|
|
| 53 |
|
| 54 |
def find_best_match(self, question: str, threshold: float = 0.7) -> Tuple[str, float]:
|
| 55 |
print(f"find_best_match called with: {question}") # Debug print
|
| 56 |
+
# Always act as a general-conversation bot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
return self._generate_general_response(question)
|
| 58 |
|
| 59 |
def _generate_general_response(self, question: str) -> Tuple[str, float]:
|
|
|
|
| 92 |
return "I understand you're asking about something, but I'm specifically designed to help with PSAU-related questions like admissions, courses, and university information. Could you rephrase your question to be more specific about what you'd like to know about Pangasinan State University? I'm here to help with academic guidance and university-related inquiries!", 0.6
|
| 93 |
|
| 94 |
def get_suggested_questions(self, question: str, num_suggestions: int = 3) -> List[str]:
|
| 95 |
+
"""No suggestions when FAQs are disabled."""
|
| 96 |
+
return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
def add_faq(self, question: str, answer: str) -> bool:
|
| 99 |
+
"""No-op when FAQs are disabled."""
|
| 100 |
+
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|