markobinario committed
Commit ad4defa · verified · 1 Parent(s): dc61c2e

Update ai_chatbot.py

Files changed (1)
  1. ai_chatbot.py +12 -119
ai_chatbot.py CHANGED
@@ -1,13 +1,10 @@
 from sentence_transformers import SentenceTransformer
 import numpy as np
 from typing import List, Dict, Tuple
-import mysql.connector
-from mysql.connector import Error
 import re
 
 class AIChatbot:
-    def __init__(self, db_config: Dict[str, str]):
-        self.db_config = db_config
+    def __init__(self):
         # Load the pre-trained model (can use a smaller model for more speed)
         self.model = SentenceTransformer('all-MiniLM-L6-v2')
         # Warm up the model to avoid first-request slowness
@@ -16,47 +13,15 @@ class AIChatbot:
         self.faqs = None
         self.load_faqs()
 
-    def get_db_connection(self):
-        try:
-            connection = mysql.connector.connect(**self.db_config)
-            return connection
-        except Error as e:
-            print(f"Error connecting to database: {e}")
-            return None
-
     def load_faqs(self):
-        """Load active FAQs from database and compute their normalized embeddings"""
-        connection = self.get_db_connection()
-        if connection:
-            try:
-                cursor = connection.cursor(dictionary=True)
-                cursor.execute("SELECT id, question, answer FROM faqs WHERE is_active = 1 ORDER BY sort_order, id")
-                self.faqs = cursor.fetchall()
-                cursor.close()
-
-                if self.faqs:
-                    # Compute and normalize embeddings for all questions
-                    questions = [faq['question'] for faq in self.faqs]
-                    embeddings = self.model.encode(questions, normalize_embeddings=True)
-                    self.faq_embeddings = np.array(embeddings)
-            except Error as e:
-                print(f"Error loading FAQs: {e}")
-            finally:
-                connection.close()
+        """Disable FAQs entirely; operate as a general-conversation bot."""
+        self.faqs = []
+        self.faq_embeddings = None
 
     def save_unanswered_question(self, question):
-        print(f"Saving unanswered question: {question}") # Debug print
-        try:
-            connection = self.get_db_connection()
-            if connection:
-                cursor = connection.cursor()
-                query = "INSERT INTO unanswered_questions (question) VALUES (%s)"
-                cursor.execute(query, (question,))
-                connection.commit()
-                cursor.close()
-                connection.close()
-        except Error as e:
-            print(f"Error saving unanswered question: {e}")
+        """Log unanswered questions to console (can be extended to save to file)"""
+        print(f"Unanswered question logged: {question}")
+        # In a real implementation, you could save this to a file or send to an admin
 
     def _tokenize(self, text: str):
         if not text:
@@ -88,51 +53,7 @@ class AIChatbot:
 
     def find_best_match(self, question: str, threshold: float = 0.7) -> Tuple[str, float]:
         print(f"find_best_match called with: {question}") # Debug print
-
-        # First try to match with FAQs
-        if self.faqs and self.faq_embeddings is not None:
-            # Compute and normalize embedding for the input question
-            question_embedding = self.model.encode([question], normalize_embeddings=True)[0]
-            similarities = np.dot(self.faq_embeddings, question_embedding)
-
-            # Compute keyword overlap with each FAQ question
-            q_tokens = self._tokenize(question)
-            overlap_scores = []
-            for faq in self.faqs:
-                overlap_scores.append(self._overlap_ratio(q_tokens, self._tokenize(faq['question'])))
-
-            similarities = np.array(similarities)
-            overlap_scores = np.array(overlap_scores)
-
-            # Combined score to reduce false positives
-            combined = 0.7 * similarities + 0.3 * overlap_scores
-
-            # Apply WH-word intent consistency penalty
-            q_wh = self._wh_class(question)
-            if q_wh:
-                for i, faq in enumerate(self.faqs):
-                    f_wh = self._wh_class(faq['question'])
-                    if f_wh and f_wh != q_wh:
-                        combined[i] *= 0.6 # penalize mismatched intent significantly
-            best_idx = int(np.argmax(combined))
-            best_semantic = float(similarities[best_idx])
-            best_overlap = float(overlap_scores[best_idx])
-            best_combined = float(combined[best_idx])
-            best_wh = self._wh_class(self.faqs[best_idx]['question'])
-
-            # Acceptance criteria: require good semantic OR strong combined with overlap
-            accept = (
-                best_semantic >= max(0.7, threshold)
-                or (best_combined >= threshold and best_overlap >= 0.3)
-            )
-            # Enforce WH intent match when present
-            if accept and q_wh and best_wh and q_wh != best_wh:
-                accept = False
-
-            if accept:
-                return self.faqs[best_idx]['answer'], best_combined
-
-        # If no FAQ match, provide general conversation response
+        # Always act as a general-conversation bot
         return self._generate_general_response(question)
 
     def _generate_general_response(self, question: str) -> Tuple[str, float]:
@@ -171,37 +92,9 @@
         return "I understand you're asking about something, but I'm specifically designed to help with PSAU-related questions like admissions, courses, and university information. Could you rephrase your question to be more specific about what you'd like to know about Pangasinan State University? I'm here to help with academic guidance and university-related inquiries!", 0.6
 
     def get_suggested_questions(self, question: str, num_suggestions: int = 3) -> List[str]:
-        """Get suggested questions based on the input question"""
-        if not self.faqs or self.faq_embeddings is None:
-            return []
-
-        # Compute and normalize embedding for the input question
-        question_embedding = self.model.encode([question], normalize_embeddings=True)[0]
-
-        # Calculate cosine similarity
-        similarities = np.dot(self.faq_embeddings, question_embedding)
-
-        # Get top N similar questions
-        top_indices = np.argsort(similarities)[-num_suggestions:][::-1]
-        return [self.faqs[idx]['question'] for idx in top_indices if similarities[idx] > 0.3]
+        """No suggestions when FAQs are disabled."""
+        return []
 
     def add_faq(self, question: str, answer: str) -> bool:
-        """Add a new FAQ to the database"""
-        connection = self.get_db_connection()
-        if connection:
-            try:
-                cursor = connection.cursor()
-                query = "INSERT INTO faqs (question, answer) VALUES (%s, %s)"
-                cursor.execute(query, (question, answer))
-                connection.commit()
-                cursor.close()
-
-                # Reload FAQs to update embeddings
-                self.load_faqs()
-                return True
-            except Error as e:
-                print(f"Error adding FAQ: {e}")
-                return False
-            finally:
-                connection.close()
-        return False
+        """No-op when FAQs are disabled."""
+        return False
 
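For context, a minimal usage sketch of the class after this commit. The module import path and the sample question are illustrative assumptions, not part of the diff; the return values follow the method signatures and bodies shown above.

# Usage sketch (assumes sentence-transformers and numpy are installed and that
# this file is importable as the module `ai_chatbot`).
from ai_chatbot import AIChatbot

# The constructor no longer takes a db_config dict.
bot = AIChatbot()

# With FAQs disabled, every query is routed to the general-conversation handler,
# which returns an (answer, confidence) tuple.
answer, confidence = bot.find_best_match("What programs does PSAU offer?")
print(answer, confidence)

# FAQ helpers are now inert: no suggestions, and add_faq reports failure.
print(bot.get_suggested_questions("admission requirements"))  # []
print(bot.add_faq("Sample question?", "Sample answer."))      # False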