Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| import sys | |
| from transformers import pipeline | |
# Root logging setup: records at INFO and above are written to stdout, each
# stamped with timestamp, logger name and level.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
class QueryRewriter:
    """Rewrite user queries with a locally loaded FLAN-T5 model.

    Two prompting styles are available: chain-of-thought (``rewrite_cot``)
    and ReAct-style (``rewrite_react``). Both return a
    ``(rewritten_query, prompt)`` tuple and fall back to the original query
    whenever generation fails or yields no usable text.
    """

    # Every prompt ends with this marker; whatever follows its last
    # occurrence in the model output is treated as the rewritten query.
    _ANSWER_MARKER = "Rewritten query:"

    def __init__(self):
        """Load the generation pipeline on CPU."""
        # FLAN-T5 is an encoder-decoder (seq2seq) model, so it has to be
        # served through the "text2text-generation" task. The
        # "text-generation" task only loads decoder-only (causal LM) models.
        self.model = pipeline(
            "text2text-generation",
            model="google/flan-t5-base",  # Using a smaller model suitable for Spaces
            device=-1,  # Use CPU
        )
        logger.info("Initialized QueryRewriter with flan-t5-base model")

    def generate(self, prompt):
        """Run the model on ``prompt`` and return the generated text.

        Args:
            prompt: Full prompt text handed to the pipeline.

        Returns:
            The ``generated_text`` of the first returned sequence, or
            ``None`` if the pipeline raised or returned an unexpected shape.
        """
        try:
            outputs = self.model(
                prompt,
                max_length=256,
                min_length=32,
                num_return_sequences=1,
            )
            return outputs[0]['generated_text']
        except Exception as e:
            # Best-effort boundary: callers fall back to the original query,
            # so log (with traceback) instead of propagating.
            logger.exception("Error generating response: %s", e)
            return None

    def _rewrite(self, query, prompt, strategy):
        """Generate from ``prompt`` and extract the rewritten query.

        Args:
            query: Original query, returned unchanged on any failure.
            prompt: Prompt already built by the calling strategy.
            strategy: Short strategy label used in log messages.

        Returns:
            ``(rewritten_query, prompt)``; ``rewritten_query`` equals
            ``query`` when nothing usable was generated.
        """
        generated = self.generate(prompt)
        if not isinstance(generated, str):
            logger.error(f"Error in {strategy} rewriting for query: {query}")
            return query, prompt

        # If the model echoes the prompt, keep only the text after the marker;
        # otherwise the split leaves the output untouched.
        final_query = generated.split(self._ANSWER_MARKER)[-1].strip()
        if not final_query:
            # An empty rewrite is useless downstream; keep the original query.
            logger.warning(
                f"Empty {strategy} rewrite for query: {query}; using original query"
            )
            return query, prompt
        return final_query, prompt

    def rewrite_cot(self, query):
        """Rewrite ``query`` using a chain-of-thought style prompt.

        Args:
            query: The user's original query text.

        Returns:
            ``(rewritten_query, prompt)``; the original query is returned as
            ``rewritten_query`` if rewriting fails.
        """
        prompt = (
            "\n"
            "Rewrite the following query using step-by-step reasoning:\n"
            f"Original query: {query}\n"
            "Steps:\n"
            "1. What is the main question being asked?\n"
            "2. What are the key components?\n"
            "3. How can we make it clearer?\n"
            "Rewritten query:\n"
        )
        return self._rewrite(query, prompt, "CoT")

    def rewrite_react(self, query):
        """Rewrite ``query`` using a ReAct (thought/action/observation) prompt.

        Args:
            query: The user's original query text.

        Returns:
            ``(rewritten_query, prompt)``; the original query is returned as
            ``rewritten_query`` if rewriting fails.
        """
        prompt = (
            "\n"
            "Rewrite the following query using a systematic approach:\n"
            f"Original query: {query}\n"
            "Thought: What information are we looking for?\n"
            "Action: Break down the query into key components\n"
            "Observation: Identify the main focus\n"
            "Rewritten query:\n"
        )
        return self._rewrite(query, prompt, "ReAct")