Spaces:
Running
Running
Update ai_chatbot.py
Browse files- ai_chatbot.py +152 -110
ai_chatbot.py
CHANGED
|
@@ -1,61 +1,62 @@
|
|
| 1 |
from sentence_transformers import SentenceTransformer
|
| 2 |
import numpy as np
|
| 3 |
from typing import List, Dict, Tuple
|
|
|
|
|
|
|
| 4 |
import re
|
| 5 |
-
import random
|
| 6 |
|
| 7 |
class AIChatbot:
|
| 8 |
-
def __init__(self):
|
|
|
|
| 9 |
# Load the pre-trained model (can use a smaller model for more speed)
|
| 10 |
self.model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 11 |
# Warm up the model to avoid first-request slowness
|
| 12 |
_ = self.model.encode(["Hello, world!"])
|
| 13 |
-
self.
|
| 14 |
-
self.
|
| 15 |
-
self.
|
| 16 |
|
| 17 |
-
def
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
|
|
|
| 43 |
|
| 44 |
-
def
|
| 45 |
-
"
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
"
|
| 57 |
-
]
|
| 58 |
-
return random.choice(general_responses)
|
| 59 |
|
| 60 |
def _tokenize(self, text: str):
|
| 61 |
if not text:
|
|
@@ -85,81 +86,122 @@ class AIChatbot:
|
|
| 85 |
return key
|
| 86 |
return ''
|
| 87 |
|
| 88 |
-
def find_best_match(self,
|
| 89 |
-
print(f"find_best_match called with: {
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
similarities = np.dot(self.conversation_embeddings, message_embedding)
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
| 109 |
|
| 110 |
-
#
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
combined[i] *= 0.6 # penalize mismatched intent significantly
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
def
|
| 136 |
-
"""Get suggested
|
| 137 |
-
if not self.
|
| 138 |
return []
|
| 139 |
|
| 140 |
-
# Compute and normalize embedding for the input
|
| 141 |
-
|
| 142 |
|
| 143 |
# Calculate cosine similarity
|
| 144 |
-
similarities = np.dot(self.
|
| 145 |
|
| 146 |
-
# Get top N similar
|
| 147 |
top_indices = np.argsort(similarities)[-num_suggestions:][::-1]
|
| 148 |
-
return [self.
|
| 149 |
|
| 150 |
-
def
|
| 151 |
-
"""Add a new
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from sentence_transformers import SentenceTransformer
|
| 2 |
import numpy as np
|
| 3 |
from typing import List, Dict, Tuple
|
| 4 |
+
import mysql.connector
|
| 5 |
+
from mysql.connector import Error
|
| 6 |
import re
|
|
|
|
| 7 |
|
| 8 |
class AIChatbot:
|
| 9 |
+
def __init__(self, db_config: Dict[str, str]):
|
| 10 |
+
self.db_config = db_config
|
| 11 |
# Load the pre-trained model (can use a smaller model for more speed)
|
| 12 |
self.model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 13 |
# Warm up the model to avoid first-request slowness
|
| 14 |
_ = self.model.encode(["Hello, world!"])
|
| 15 |
+
self.faq_embeddings = None
|
| 16 |
+
self.faqs = None
|
| 17 |
+
self.load_faqs()
|
| 18 |
|
| 19 |
+
def get_db_connection(self):
|
| 20 |
+
try:
|
| 21 |
+
connection = mysql.connector.connect(**self.db_config)
|
| 22 |
+
return connection
|
| 23 |
+
except Error as e:
|
| 24 |
+
print(f"Error connecting to database: {e}")
|
| 25 |
+
return None
|
| 26 |
+
|
| 27 |
+
def load_faqs(self):
|
| 28 |
+
"""Load active FAQs from database and compute their normalized embeddings"""
|
| 29 |
+
connection = self.get_db_connection()
|
| 30 |
+
if connection:
|
| 31 |
+
try:
|
| 32 |
+
cursor = connection.cursor(dictionary=True)
|
| 33 |
+
cursor.execute("SELECT id, question, answer FROM faqs WHERE is_active = 1 ORDER BY sort_order, id")
|
| 34 |
+
self.faqs = cursor.fetchall()
|
| 35 |
+
cursor.close()
|
| 36 |
+
|
| 37 |
+
if self.faqs:
|
| 38 |
+
# Compute and normalize embeddings for all questions
|
| 39 |
+
questions = [faq['question'] for faq in self.faqs]
|
| 40 |
+
embeddings = self.model.encode(questions, normalize_embeddings=True)
|
| 41 |
+
self.faq_embeddings = np.array(embeddings)
|
| 42 |
+
except Error as e:
|
| 43 |
+
print(f"Error loading FAQs: {e}")
|
| 44 |
+
finally:
|
| 45 |
+
connection.close()
|
| 46 |
|
| 47 |
+
def save_unanswered_question(self, question):
|
| 48 |
+
print(f"Saving unanswered question: {question}") # Debug print
|
| 49 |
+
try:
|
| 50 |
+
connection = self.get_db_connection()
|
| 51 |
+
if connection:
|
| 52 |
+
cursor = connection.cursor()
|
| 53 |
+
query = "INSERT INTO unanswered_questions (question) VALUES (%s)"
|
| 54 |
+
cursor.execute(query, (question,))
|
| 55 |
+
connection.commit()
|
| 56 |
+
cursor.close()
|
| 57 |
+
connection.close()
|
| 58 |
+
except Error as e:
|
| 59 |
+
print(f"Error saving unanswered question: {e}")
|
|
|
|
|
|
|
| 60 |
|
| 61 |
def _tokenize(self, text: str):
|
| 62 |
if not text:
|
|
|
|
| 86 |
return key
|
| 87 |
return ''
|
| 88 |
|
| 89 |
+
def find_best_match(self, question: str, threshold: float = 0.7) -> Tuple[str, float]:
|
| 90 |
+
print(f"find_best_match called with: {question}") # Debug print
|
| 91 |
+
|
| 92 |
+
# First try to match with FAQs
|
| 93 |
+
if self.faqs and self.faq_embeddings is not None:
|
| 94 |
+
# Compute and normalize embedding for the input question
|
| 95 |
+
question_embedding = self.model.encode([question], normalize_embeddings=True)[0]
|
| 96 |
+
similarities = np.dot(self.faq_embeddings, question_embedding)
|
| 97 |
+
|
| 98 |
+
# Compute keyword overlap with each FAQ question
|
| 99 |
+
q_tokens = self._tokenize(question)
|
| 100 |
+
overlap_scores = []
|
| 101 |
+
for faq in self.faqs:
|
| 102 |
+
overlap_scores.append(self._overlap_ratio(q_tokens, self._tokenize(faq['question'])))
|
| 103 |
|
| 104 |
+
similarities = np.array(similarities)
|
| 105 |
+
overlap_scores = np.array(overlap_scores)
|
|
|
|
| 106 |
|
| 107 |
+
# Combined score to reduce false positives
|
| 108 |
+
combined = 0.7 * similarities + 0.3 * overlap_scores
|
| 109 |
+
|
| 110 |
+
# Apply WH-word intent consistency penalty
|
| 111 |
+
q_wh = self._wh_class(question)
|
| 112 |
+
if q_wh:
|
| 113 |
+
for i, faq in enumerate(self.faqs):
|
| 114 |
+
f_wh = self._wh_class(faq['question'])
|
| 115 |
+
if f_wh and f_wh != q_wh:
|
| 116 |
+
combined[i] *= 0.6 # penalize mismatched intent significantly
|
| 117 |
+
best_idx = int(np.argmax(combined))
|
| 118 |
+
best_semantic = float(similarities[best_idx])
|
| 119 |
+
best_overlap = float(overlap_scores[best_idx])
|
| 120 |
+
best_combined = float(combined[best_idx])
|
| 121 |
+
best_wh = self._wh_class(self.faqs[best_idx]['question'])
|
| 122 |
|
| 123 |
+
# Acceptance criteria: require good semantic OR strong combined with overlap
|
| 124 |
+
accept = (
|
| 125 |
+
best_semantic >= max(0.7, threshold)
|
| 126 |
+
or (best_combined >= threshold and best_overlap >= 0.3)
|
| 127 |
+
)
|
| 128 |
+
# Enforce WH intent match when present
|
| 129 |
+
if accept and q_wh and best_wh and q_wh != best_wh:
|
| 130 |
+
accept = False
|
| 131 |
|
| 132 |
+
if accept:
|
| 133 |
+
return self.faqs[best_idx]['answer'], best_combined
|
| 134 |
|
| 135 |
+
# If no FAQ match, provide general conversation response
|
| 136 |
+
return self._generate_general_response(question)
|
| 137 |
+
|
| 138 |
+
def _generate_general_response(self, question: str) -> Tuple[str, float]:
|
| 139 |
+
"""Generate general conversation responses for non-FAQ questions"""
|
| 140 |
+
question_lower = question.lower().strip()
|
|
|
|
| 141 |
|
| 142 |
+
# Greeting responses
|
| 143 |
+
if any(greeting in question_lower for greeting in ['hello', 'hi', 'hey', 'good morning', 'good afternoon', 'good evening']):
|
| 144 |
+
return "Hello! I'm the PSAU AI assistant. I'm here to help you with questions about university admissions, courses, and general information about Pangasinan State University. How can I assist you today?", 0.8
|
| 145 |
+
|
| 146 |
+
# Thank you responses
|
| 147 |
+
if any(thanks in question_lower for thanks in ['thank you', 'thanks', 'thank', 'appreciate']):
|
| 148 |
+
return "You're very welcome! I'm happy to help. Is there anything else you'd like to know about PSAU or university admissions?", 0.9
|
| 149 |
+
|
| 150 |
+
# Goodbye responses
|
| 151 |
+
if any(goodbye in question_lower for goodbye in ['bye', 'goodbye', 'see you', 'farewell']):
|
| 152 |
+
return "Goodbye! It was nice chatting with you. Feel free to come back anytime if you have more questions about PSAU. Good luck with your academic journey!", 0.9
|
| 153 |
+
|
| 154 |
+
# How are you responses
|
| 155 |
+
if any(how in question_lower for how in ['how are you', 'how do you do', 'how is it going']):
|
| 156 |
+
return "I'm doing great, thank you for asking! I'm here and ready to help you with any questions about PSAU admissions, courses, or university life. What would you like to know?", 0.8
|
| 157 |
+
|
| 158 |
+
# What can you do responses
|
| 159 |
+
if any(what in question_lower for what in ['what can you do', 'what do you do', 'what are your capabilities']):
|
| 160 |
+
return "I can help you with:\n• University admission requirements and procedures\n• Course information and recommendations\n• General questions about PSAU\n• Academic guidance and support\n• Information about campus life\n\nWhat specific information are you looking for?", 0.9
|
| 161 |
+
|
| 162 |
+
# About PSAU responses
|
| 163 |
+
if any(about in question_lower for about in ['about psa', 'about psu', 'about pangasinan state', 'tell me about']):
|
| 164 |
+
return "Pangasinan State University (PSAU) is a premier state university in the Philippines offering quality education across various fields. We provide undergraduate and graduate programs in areas like Computer Science, Business, Education, Nursing, and more. We're committed to academic excellence and student success. What would you like to know more about?", 0.8
|
| 165 |
+
|
| 166 |
+
# Help responses
|
| 167 |
+
if any(help in question_lower for help in ['help', 'assist', 'support']):
|
| 168 |
+
return "I'm here to help! I can assist you with:\n• Admission requirements and deadlines\n• Course information and recommendations\n• Academic programs and majors\n• Campus facilities and services\n• General university information\n\nJust ask me any question and I'll do my best to help you!", 0.9
|
| 169 |
+
|
| 170 |
+
# Default general response
|
| 171 |
+
return "I understand you're asking about something, but I'm specifically designed to help with PSAU-related questions like admissions, courses, and university information. Could you rephrase your question to be more specific about what you'd like to know about Pangasinan State University? I'm here to help with academic guidance and university-related inquiries!", 0.6
|
| 172 |
|
| 173 |
+
def get_suggested_questions(self, question: str, num_suggestions: int = 3) -> List[str]:
|
| 174 |
+
"""Get suggested questions based on the input question"""
|
| 175 |
+
if not self.faqs or self.faq_embeddings is None:
|
| 176 |
return []
|
| 177 |
|
| 178 |
+
# Compute and normalize embedding for the input question
|
| 179 |
+
question_embedding = self.model.encode([question], normalize_embeddings=True)[0]
|
| 180 |
|
| 181 |
# Calculate cosine similarity
|
| 182 |
+
similarities = np.dot(self.faq_embeddings, question_embedding)
|
| 183 |
|
| 184 |
+
# Get top N similar questions
|
| 185 |
top_indices = np.argsort(similarities)[-num_suggestions:][::-1]
|
| 186 |
+
return [self.faqs[idx]['question'] for idx in top_indices if similarities[idx] > 0.3]
|
| 187 |
|
| 188 |
+
def add_faq(self, question: str, answer: str) -> bool:
|
| 189 |
+
"""Add a new FAQ to the database"""
|
| 190 |
+
connection = self.get_db_connection()
|
| 191 |
+
if connection:
|
| 192 |
+
try:
|
| 193 |
+
cursor = connection.cursor()
|
| 194 |
+
query = "INSERT INTO faqs (question, answer) VALUES (%s, %s)"
|
| 195 |
+
cursor.execute(query, (question, answer))
|
| 196 |
+
connection.commit()
|
| 197 |
+
cursor.close()
|
| 198 |
+
|
| 199 |
+
# Reload FAQs to update embeddings
|
| 200 |
+
self.load_faqs()
|
| 201 |
+
return True
|
| 202 |
+
except Error as e:
|
| 203 |
+
print(f"Error adding FAQ: {e}")
|
| 204 |
+
return False
|
| 205 |
+
finally:
|
| 206 |
+
connection.close()
|
| 207 |
+
return False
|