# This is my app.py
import os
import torch
import re
import warnings
import time
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from sentence_transformers import SentenceTransformer, util
import gspread
from duckduckgo_search import DDGS
import base64

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)

# --- Configuration ---
SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw"  # Your Google Sheet ID
HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token from Space Secrets
GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY_BASE64")
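# Note (illustrative): the secret above is the service-account JSON encoded as base64.
# One way to produce it locally, assuming the key file is named service_account.json:
#   python -c "import base64; print(base64.b64encode(open('service_account.json','rb').read()).decode())"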
# Small instruction-tuned model chosen so the Space can run on CPU.
# model_id = "google/gemma-2b"  # Gemma 2B kept as an alternative
model_id = "unsloth/gemma-3-1b-it"
# --- Constants for Prompting and Validation ---
SEARCH_MARKER = "ACTION: SEARCH:"
BUSINESS_LOOKUP_MARKER = "ACTION: LOOKUP_BUSINESS_INFO:"
ANSWER_DIRECTLY_MARKER = "ACTION: ANSWER_DIRECTLY:"
BUSINESS_LOOKUP_VALIDATION_THRESHOLD = 0.6
SEARCH_VALIDATION_THRESHOLD = 0.6
PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD = 0.5
# --- Global variables to load once ---
tokenizer = None
model = None
embedder = None  # Sentence Transformer
data = []  # Google Sheet data
descriptions = []
embeddings = torch.tensor([])  # Google Sheet embeddings

# --- Loading Functions (run once on startup) ---
def load_sentence_transformer():
    """Loads the Sentence Transformer model."""
    print("Loading Sentence Transformer...")
    try:
        embedder_model = SentenceTransformer("all-MiniLM-L6-v2")
        print("Sentence Transformer loaded.")
        return embedder_model
    except Exception as e:
        print(f"Error loading Sentence Transformer: {e}")
        return None
def load_google_sheet_data(sheet_id, service_account_key_base64):
    """Authenticates with a service account and loads data from the Google Sheet."""
    print(f"Attempting to load Google Sheet data from ID: {sheet_id}")
    if not service_account_key_base64:
        print("Warning: GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 secret is not set. Cannot access Google Sheets.")
        return [], [], torch.tensor([])
    try:
        print("Decoding base64 key...")
        key_bytes = base64.b64decode(service_account_key_base64)
        key_dict = json.loads(key_bytes)
        print("Base64 key decoded and parsed.")
        print("Authenticating with service account...")
        from google.oauth2 import service_account
        # Scope the credentials to Google Sheets. Use the read-only scope when only reading;
        # use 'https://www.googleapis.com/auth/spreadsheets' for read/write.
        scopes = ['https://www.googleapis.com/auth/spreadsheets.readonly']
        creds = service_account.Credentials.from_service_account_info(key_dict, scopes=scopes)
        client = gspread.authorize(creds)
        print("Authentication successful.")
        print(f"Opening sheet with key '{sheet_id}'...")
        # IMPORTANT: if your data is not on the first worksheet, replace .sheet1, e.g.:
        # sheet = client.open_by_key(sheet_id).worksheet("Data")
        sheet = client.open_by_key(sheet_id).sheet1
        print(f"Successfully opened Google Sheet with ID: {sheet_id}")
        print("Getting all records from the sheet...")
        sheet_data = sheet.get_all_records()
        print(f"Retrieved {len(sheet_data)} raw records from sheet.")
        if not sheet_data:
            print(f"Warning: No data records found in Google Sheet with ID: {sheet_id}")
            return [], [], torch.tensor([])
        print("Filtering data for 'Service' and 'Description' columns...")
        filtered_data = [row for row in sheet_data if row.get('Service') and row.get('Description')]
        print(f"Filtered down to {len(filtered_data)} records.")
        if not filtered_data:
            print("Warning: Filtered data is empty after checking for 'Service' and 'Description'.")
            # If raw data exists but filtering removed everything, the headers are likely wrong.
            if sheet_data and ('Service' not in sheet_data[0] or 'Description' not in sheet_data[0]):
                print("Error: 'Service' or 'Description' headers are missing or misspelled in the sheet.")
            return [], [], torch.tensor([])
        descriptions = [row["Description"] for row in filtered_data]
        print(f"Loaded {len(descriptions)} entries from Google Sheet for embedding.")
        # Embeddings are encoded later, once the Sentence Transformer is available.
        return filtered_data, descriptions, torch.tensor([])
    except gspread.exceptions.SpreadsheetNotFound:
        print(f"Error: Google Sheet with ID '{sheet_id}' not found.")
        print("Please check the SHEET_ID and ensure the service account has access.")
        return [], [], torch.tensor([])
    except Exception as e:
        print(f"An error occurred while accessing the Google Sheet: {e}")
        return [], [], torch.tensor([])
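# Expected sheet layout (inferred from the columns this file reads; adjust to your sheet):
#
#   Service  | Description               | Price | Available
#   ---------|---------------------------|-------|----------
#   Haircut  | Classic men's haircut ... | $25   | Yes
#
# Only 'Service' and 'Description' are required for retrieval; 'Price' and
# 'Available' are read later when formatting lookup results.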
def load_llm_model(model_id, hf_token):
    """Loads the LLM in full precision (for CPU)."""
    print(f"Loading model {model_id} in full precision...")
    if not hf_token:
        print("Error: HF_TOKEN secret is not set. Cannot load Hugging Face model.")
        return None, None
    try:
        llm_tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
        # Explicitly set a Gemma-style chat template. Each message is wrapped as
        # <start_of_turn>{role}\n{content}<end_of_turn>\n, and with add_generation_prompt=True
        # the template appends <start_of_turn>model\n so the model continues as itself.
        # 'assistant' is treated as an alias for 'model' so the history built in respond()
        # is not silently dropped. The template does not add <bos>; the Gemma tokenizer
        # typically prepends it during encoding.
        llm_tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_of_turn>user\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'system' %}{{ '<start_of_turn>system\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'tool' %}{{ '<start_of_turn>tool\n' + message['content'] + '<end_of_turn>\n' }}{% elif message['role'] == 'model' or message['role'] == 'assistant' %}{{ '<start_of_turn>model\n' + message['content'] + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
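        # Illustrative rendering of the template above for a two-turn history with
        # add_generation_prompt=True (message contents are made up):
        #
        #   <start_of_turn>user
        #   What services do you offer?<end_of_turn>
        #   <start_of_turn>model
        #   We offer haircuts and coloring.<end_of_turn>
        #   <start_of_turn>model
        #
        # The trailing '<start_of_turn>model\n' is the generation prompt the model completes.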
        if llm_tokenizer.pad_token is None:
            llm_tokenizer.pad_token = llm_tokenizer.eos_token
        llm_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=hf_token,
            device_map="auto",  # on a CPU-only Space this maps everything to 'cpu'
        )
        print(f"Model {model_id} loaded in full precision.")
        return llm_model, llm_tokenizer
    except Exception as e:
        print(f"Error loading model {model_id}: {e}")
        print("Please ensure transformers and accelerate are installed, and check your Hugging Face token.")
        return None, None
# --- Load all assets on startup ---
print("Loading assets...")
embedder = load_sentence_transformer()
print(f"Embedder loaded: {embedder is not None}")
data, descriptions, _ = load_google_sheet_data(SHEET_ID, GOOGLE_SERVICE_ACCOUNT_KEY_BASE64)
print(f"Google Sheet data loaded: {len(data)} rows")
print(f"Google Sheet descriptions loaded: {len(descriptions)} items")
if embedder and descriptions:
    print("Encoding Google Sheet descriptions...")
    try:
        embeddings = embedder.encode(descriptions, convert_to_tensor=True)
        print("Encoding complete.")
        print(f"Embeddings shape: {embeddings.shape}")
    except Exception as e:
        print(f"Error during embedding: {e}")
        embeddings = torch.tensor([])  # ensure embeddings is an empty tensor on error
else:
    print("Skipping embedding due to missing embedder or descriptions.")
    embeddings = torch.tensor([])  # ensure embeddings is an empty tensor when skipped
    print(f"Embeddings tensor after skip: {embeddings.shape}")  # prints torch.Size([0])
model, tokenizer = load_llm_model(model_id, HF_TOKEN)
print(f"LLM Model loaded: {model is not None}")
print(f"LLM Tokenizer loaded: {tokenizer is not None}")
# --- Startup summary: report anything that failed to load ---
if not model or not tokenizer or not embedder or embeddings is None or embeddings.numel() == 0 or not data:
    print("\nERROR: Essential components failed to load. The application may not function correctly.")
    if not model: print("- LLM Model failed to load.")
    if not tokenizer: print("- LLM Tokenizer failed to load.")
    if not embedder: print("- Sentence Embedder failed to load.")
    if embeddings is None or embeddings.numel() == 0: print("- Embeddings are empty or None.")
    if not data: print("- Google Sheet Data is empty.")
    # Execution continues; respond() re-checks these components before answering.
else:
    print("\nAll essential components loaded successfully.")
# --- Helper Functions ---
def perform_duckduckgo_search(query, max_results=3):
    """
    Performs a search using DuckDuckGo and returns a list of result dictionaries.
    Includes a short delay to avoid rate limits.
    """
    search_results_list = []
    try:
        time.sleep(1)
        with DDGS() as ddgs:
            for r in ddgs.text(query, max_results=max_results):
                search_results_list.append(r)
    except Exception as e:
        print(f"Error during DuckDuckGo search for '{query}': {e}")
        return []
    return search_results_list
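# Each result from DDGS().text() is a dict; this file relies on its 'title', 'body',
# and 'href' keys when formatting search context, e.g. (illustrative values):
#   {"title": "Example page", "href": "https://example.com", "body": "Snippet text..."}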
def retrieve_business_info(query, data, embeddings, embedder, threshold=0.50):
    """
    Retrieves relevant business information based on query similarity.
    Returns (row_dict, score) if a match above threshold is found,
    otherwise (None, score). Uses the global embedder, data, and embeddings.
    """
    if not data or (embeddings is None or embeddings.numel() == 0) or embedder is None:
        print("Skipping business info retrieval: data, embeddings, or embedder not available.")
        return None, 0.0
    try:
        user_embedding = embedder.encode(query, convert_to_tensor=True)
        cos_scores = util.cos_sim(user_embedding, embeddings)[0]
        best_score = cos_scores.max().item()
        if best_score > threshold:
            best_match_idx = cos_scores.argmax().item()
            return data[best_match_idx], best_score
        else:
            return None, best_score
    except Exception as e:
        print(f"Error during business information retrieval: {e}")
        return None, 0.0
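# Illustrative call: with a row {"Service": "Haircut", "Description": "Classic men's haircut"},
# retrieve_business_info("how much is a haircut?", data, embeddings, embedder)
# returns that row plus its cosine score when the score clears the 0.50 default,
# and (None, score) otherwise.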
# Regex-based query splitting (replaces the earlier spaCy-based implementation)
def split_query(query):
    """Splits a user query into potential sub-queries using regex."""
    # Split on common separators: commas, semicolons, and "and" followed by an interrogative word.
    parts = re.split(r',|;|\band\s+(?:who|what|where|when|why|how|is|are|can|tell me about)\b', query, flags=re.IGNORECASE)
    # Filter out empty strings and strip whitespace.
    parts = [part.strip() for part in parts if part and part.strip()]
    # If splitting didn't produce multiple meaningful parts, return the original query.
    if len(parts) <= 1:
        return [query]
    return parts
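# Example: split_query("Do you offer haircuts; who is the mayor?")
#   -> ["Do you offer haircuts", "who is the mayor?"]
# Note that the matched conjunction itself is consumed, so
# split_query("what are your prices and what time do you open?")
#   -> ["what are your prices", "time do you open?"]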
# --- Pass 1 System Prompt ---
pass1_instructions_action = """You are a helpful assistant for a business. Your primary goal in this first step is to analyze the user's query and decide which actions are needed to answer it.
You have analyzed the user's query and potentially broken it down into parts. For each part, a preliminary check was done to see if it matches known business information. The results of this check are provided below.
{business_check_summary}
Based on the user's query and the results of the business info check for each part, identify if you need to perform actions.
Output one or more actions, each on a new line, in the format:
ACTION: [ACTION_TYPE]: [Argument/Query for the action]
Possible actions:
1. **LOOKUP_BUSINESS_INFO**: If a part of the query asks about the business's services, prices, availability, or individuals mentioned in the business context, *and* the business info check for that part indicates a high relevance ({PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher). The argument should be the specific phrase or name to look up.
2. **SEARCH**: If a part of the query asks for current external information (e.g., current events, real-time data, general facts not in business info), *or* if a part that seems like it could be business info did *not* have a high relevance score in the preliminary check (below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}). The argument should be the precise search query.
3. **ANSWER_DIRECTLY**: If the overall query is a simple greeting or can be answered from your general knowledge without lookup or search, *and* the business info check results indicate low relevance for all parts. The argument should be the direct answer.
**Crucially:**
- **Prioritize LOOKUP_BUSINESS_INFO** for any part of the query where the preliminary business info check score was {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher.
- Use **SEARCH** for parts about external information or where the business info check score was below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}.
- If a part of the query is clearly external (like asking about current events or famous people), use SEARCH for it even if its business info score wasn't zero.
- Do NOT output any other text besides the ACTION lines.
- If the results suggest a direct answer is sufficient, use ANSWER_DIRECTLY.
Now, analyze the following user query, considering the business info check results provided above, and output the required actions:
"""
# --- Pass 2 System Prompt ---
pass2_instructions_synthesize = """You are a helpful assistant for a business. You have been provided with the original user query, relevant Business Information (if found), and results from external searches (if performed).
Your task is to synthesize ALL the provided information to answer the user's original question concisely and accurately.
**Prioritize Business Information** for details about the business, its services, or individuals mentioned within that context.
Use the Search Results for current external information that was requested.
If information for a specific part of the question was not found in either Business Information or Search Results, use your general knowledge if possible, or state that the information could not be found.
Synthesize the information into a natural language response. Do NOT copy and paste raw context or strings like 'Business Information:' or 'SEARCH RESULTS:' or 'ACTION:' or the raw user query.
After your answer, generate a few concise follow-up questions that a user might ask based on the previous turn's conversation and your response. List these questions clearly at the end of your response.
When search results were used to answer the question, list the URLs from the search results you used under a "Sources:" heading at the very end.
"""
# --- Main Inference Function for Gradio ---
def respond(user_input, chat_history):
    """
    Processes user input, performs actions (lookup/search), and generates a response.
    Manages chat history within Gradio state.
    """
    # Check that the core components loaded successfully.
    if model is None or tokenizer is None or embedder is None:
        return "", chat_history + [(user_input, "Sorry, the application failed to load necessary components. Please try again later or contact the administrator.")]
    original_user_input = user_input
    # Initialize action result containers for this turn.
    search_results_dicts = []
    business_lookup_results_formatted = []
    response_pass1_raw = ""
    # --- Pre-Pass 1: programmatic business info check for query parts ---
    query_parts = split_query(original_user_input)
    business_check_results = []
    overall_pre_pass1_score = 0.0
    print("\n--- Processing new user query ---")
    print(f"User: {user_input}")
    print("Performing programmatic business info check on query parts...")
    if query_parts:
        for part in query_parts:
            match, score = retrieve_business_info(part, data, embeddings, embedder, threshold=0.0)
            business_check_results.append({"part": part, "score": score, "match": match})
            print(f"- Part '{part}': Score {score:.4f}")
            overall_pre_pass1_score = max(overall_pre_pass1_score, score)
    else:
        match, score = retrieve_business_info(original_user_input, data, embeddings, embedder, threshold=0.0)
        business_check_results.append({"part": original_user_input, "score": score, "match": match})
        print(f"- Part '{original_user_input}': Score {score:.4f}")
        overall_pre_pass1_score = score
    is_likely_direct_answer = overall_pre_pass1_score < PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD and len(query_parts) <= 2
    # Format the business check summary for the Pass 1 prompt.
    business_check_summary = "Business Info Check Results for Query Parts:\n"
    if business_check_results:
        for result in business_check_results:
            status = "High Relevance" if result['score'] >= PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD else "Low Relevance"
            business_check_summary += f"- Part '{result['part']}': Score {result['score']:.4f} ({status})\n"
    else:
        business_check_summary += "- No parts identified or check skipped.\n"
    business_check_summary += "\n"
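    # Illustrative summary as it reaches the Pass 1 prompt (scores are made up):
    #
    #   Business Info Check Results for Query Parts:
    #   - Part 'what are your prices': Score 0.7213 (High Relevance)
    #   - Part 'who won the game last night': Score 0.1032 (Low Relevance)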
    # --- Pass 1: action identification (skipped if a direct answer is likely) ---
    requested_actions = []
    answer_directly_provided = None
    if is_likely_direct_answer:
        print("Programmatically determined likely direct answer.")
        # An empty argument after the marker means "answer the original query directly".
        response_pass1_raw = "ACTION: ANSWER_DIRECTLY: "
    else:
        pass1_user_message_content = pass1_instructions_action.format(
            business_check_summary=business_check_summary,
            PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD=PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD
        ) + "\n\nUser Query: " + user_input
        temp_chat_history_pass1 = [{"role": "user", "content": pass1_user_message_content}]
        try:
            prompt_pass1 = tokenizer.apply_chat_template(
                temp_chat_history_pass1,
                tokenize=False,
                add_generation_prompt=True
            )
            generation_config_pass1 = GenerationConfig(
                max_new_tokens=200,
                do_sample=False,  # greedy decoding for deterministic action output
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True
            )
            input_ids_pass1 = tokenizer(prompt_pass1, return_tensors="pt").input_ids
            if model and input_ids_pass1.numel() > 0:
                outputs_pass1 = model.generate(
                    input_ids=input_ids_pass1,
                    generation_config=generation_config_pass1,
                )
                prompt_length_pass1 = input_ids_pass1.shape[1]
                if outputs_pass1.shape[1] > prompt_length_pass1:
                    generated_tokens_pass1 = outputs_pass1[0, prompt_length_pass1:]
                    response_pass1_raw = tokenizer.decode(generated_tokens_pass1, skip_special_tokens=True).strip()
                else:
                    response_pass1_raw = ""
            else:
                response_pass1_raw = ""
        except Exception as e:
            print(f"Error during Pass 1 (Action Identification): {e}")
            response_pass1_raw = f"ACTION: ANSWER_DIRECTLY: Error in Pass 1 - {e}"
    # --- Parse the model's requested actions, validating each one ---
    if response_pass1_raw:
        lines = response_pass1_raw.strip().split('\n')
        for line in lines:
            line = line.strip()
            if line.startswith(SEARCH_MARKER):
                query = line[len(SEARCH_MARKER):].strip()
                if query:
                    _, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0)
                    if score < SEARCH_VALIDATION_THRESHOLD:
                        requested_actions.append(("SEARCH", query))
                        print(f"Validated Search Action for '{query}' (Score: {score:.4f})")
                    else:
                        print(f"Rejected Search Action for '{query}' (Score: {score:.4f}) - too similar to business data.")
            elif line.startswith(BUSINESS_LOOKUP_MARKER):
                query = line[len(BUSINESS_LOOKUP_MARKER):].strip()
                if query:
                    match, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0)
                    if score > BUSINESS_LOOKUP_VALIDATION_THRESHOLD:
                        requested_actions.append(("LOOKUP_BUSINESS_INFO", query))
                        print(f"Validated Business Lookup Action for '{query}' (Score: {score:.4f})")
                    else:
                        print(f"Rejected Business Lookup Action for '{query}' (Score: {score:.4f}) - below validation threshold.")
            elif line.startswith(ANSWER_DIRECTLY_MARKER):
                answer = line[len(ANSWER_DIRECTLY_MARKER):].strip()
                answer_directly_provided = answer if answer else original_user_input
                requested_actions = []
                break
    # --- Execute the validated actions (search and lookup) ---
    context_for_pass2 = ""
    if requested_actions:
        print("Executing requested actions...")
        for action_type, query in requested_actions:
            if action_type == "SEARCH":
                print(f"Performing search for: '{query}'")
                results = perform_duckduckgo_search(query)
                if results:
                    search_results_dicts.extend(results)
                    print(f"Found {len(results)} search results.")
                else:
                    print(f"No search results found for '{query}'.")
            elif action_type == "LOOKUP_BUSINESS_INFO":
                print(f"Performing business info lookup for: '{query}'")
                # Use retrieve_business_info's default threshold (0.50) for the real lookup.
                match, score = retrieve_business_info(query, data, embeddings, embedder)
                print(f"Actual lookup score for '{query}': {score:.4f}")
                if match:
                    formatted_match = f"""Service: {match.get('Service', 'N/A')}
Description: {match.get('Description', 'N/A')}
Price: {match.get('Price', 'N/A')}
Available: {match.get('Available', 'N/A')}"""
                    business_lookup_results_formatted.append(formatted_match)
                    print("Found business info match.")
                else:
                    print(f"No business info match found for '{query}' at the default threshold.")
    # --- Prepare context for Pass 2 based on executed actions ---
    if business_lookup_results_formatted:
        context_for_pass2 += "Business Information (Use this for questions about the business):\n"
        context_for_pass2 += "\n---\n".join(business_lookup_results_formatted)
        context_for_pass2 += "\n\n"
    if search_results_dicts:
        context_for_pass2 += "SEARCH RESULTS (Use this for current external information):\n"
        aggregated_search_results_formatted = []
        for result in search_results_dicts:
            aggregated_search_results_formatted.append(f"Title: {result.get('title', 'N/A')}\nSnippet: {result.get('body', 'N/A')}\nURL: {result.get('href', 'N/A')}")
        context_for_pass2 += "\n---\n".join(aggregated_search_results_formatted) + "\n\n"
    if requested_actions and not business_lookup_results_formatted and not search_results_dicts:
        context_for_pass2 = "Note: No relevant information was found in Business Information or via Search for your query."
        print("Note: No results were found for the requested actions.")
    # If ANSWER_DIRECTLY was determined
    if answer_directly_provided is not None:
        print(f"Handling as direct answer: {answer_directly_provided}")
        context_for_pass2 = "Note: This query is a simple request or greeting."
        if answer_directly_provided != original_user_input and answer_directly_provided != "":
            context_for_pass2 += f" Initial suggestion from action step: {answer_directly_provided}"
        search_results_dicts = []
        business_lookup_results_formatted = []
    # If Pass 1 produced neither actions nor a direct answer
    if not requested_actions and answer_directly_provided is None:
        if response_pass1_raw.strip():
            print("Warning: Pass 1 did not result in valid actions or a direct answer.")
            context_for_pass2 = f"Error: Could not determine actions from Pass 1 response: '{response_pass1_raw}'."
        else:
            print("Warning: Pass 1 generated an empty response.")
            context_for_pass2 = "Error: Pass 1 generated an empty response."
    # --- Pass 2: synthesize and respond ---
    final_response = "Sorry, I couldn't generate a response."
    if model is not None and tokenizer is not None:
        pass2_user_message_content = pass2_instructions_synthesize + "\n\nOriginal User Query: " + original_user_input + "\n\n" + context_for_pass2
        # Rebuild the chat history in the format expected by apply_chat_template.
        model_chat_history = []
        for user_msg, bot_msg in chat_history:
            model_chat_history.append({"role": "user", "content": user_msg})
            model_chat_history.append({"role": "assistant", "content": bot_msg})
        model_chat_history.append({"role": "user", "content": pass2_user_message_content})
        try:
            prompt_pass2 = tokenizer.apply_chat_template(
                model_chat_history,
                tokenize=False,
                add_generation_prompt=True
            )
            generation_config_pass2 = GenerationConfig(
                max_new_tokens=1500,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True
            )
            input_ids_pass2 = tokenizer(prompt_pass2, return_tensors="pt").input_ids
            if model and input_ids_pass2.numel() > 0:
                outputs_pass2 = model.generate(
                    input_ids=input_ids_pass2,
                    generation_config=generation_config_pass2,
                )
                prompt_length_pass2 = input_ids_pass2.shape[1]
                if outputs_pass2.shape[1] > prompt_length_pass2:
                    generated_tokens_pass2 = outputs_pass2[0, prompt_length_pass2:]
                    final_response = tokenizer.decode(generated_tokens_pass2, skip_special_tokens=True).strip()
                else:
                    final_response = "..."
            else:
                final_response = "Error: Model or empty input for Pass 2."
        except Exception as gen_error:
            print(f"Error during model generation in Pass 2: {gen_error}")
            final_response = "Error generating response in Pass 2."
        # --- Post-process the Pass 2 response ---
        cleaned_response = final_response
        lines = cleaned_response.split('\n')
        # Drop any leaked prompt/context scaffolding lines.
        cleaned_lines = [line for line in lines if not line.strip().lower().startswith("business information")
                         and not line.strip().lower().startswith("search results")
                         and not line.strip().startswith("---")
                         and not line.strip().lower().startswith("original user query:")
                         and not line.strip().lower().startswith("you are a helpful assistant for a business.")]
        cleaned_response = "\n".join(cleaned_lines).strip()
        # Collect unique source URLs from the search results, preserving order.
        urls_to_list = [result.get('href') for result in search_results_dicts if result.get('href')]
        urls_to_list = list(dict.fromkeys(urls_to_list))
        if search_results_dicts and urls_to_list:
            cleaned_response += "\n\nSources:\n" + "\n".join(urls_to_list)
        final_response = cleaned_response
        if not final_response.strip():
            final_response = "Sorry, I couldn't generate a meaningful response based on the information found."
            print("Warning: Final response was empty after cleaning.")
    else:
        final_response = "Sorry, the core language model is not available."
        print("Error: LLM model or tokenizer not loaded for Pass 2.")
    # --- Update chat history for Gradio ---
    updated_chat_history = chat_history + [(original_user_input, final_response)]
    max_history_pairs = 10
    if len(updated_chat_history) > max_history_pairs:
        updated_chat_history = updated_chat_history[-max_history_pairs:]
    return "", updated_chat_history
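
# --- Gradio UI wiring ---
# The file above defines respond() for Gradio but ends before any UI is built.
# A minimal sketch follows, assuming a plain Chatbot + Textbox layout with
# tuple-style history (matching the (user, bot) pairs respond() returns); the
# component labels are illustrative, not from the original file.
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Business Assistant")
    msg = gr.Textbox(label="Your question", placeholder="Ask about services, prices, or anything else...")
    clear = gr.ClearButton([msg, chatbot])
    # respond() returns ("", updated_history), which clears the textbox and refreshes the chat.
    msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

if __name__ == "__main__":
    demo.launch()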