Spaces:
Running
Running
| from flask import Flask, render_template, request, jsonify | |
| from api_base import create_api # New import for API factory | |
| from plain_text_reasoning import ( | |
| create_mermaid_diagram as create_plain_diagram, | |
| parse_plain_text_response | |
| ) | |
| from cot_reasoning import ( | |
| VisualizationConfig, | |
| create_mermaid_diagram as create_cot_diagram, | |
| parse_cot_response | |
| ) | |
| from tot_reasoning import ( | |
| create_mermaid_diagram as create_tot_diagram, | |
| parse_tot_response | |
| ) | |
| from l2m_reasoning import ( | |
| create_mermaid_diagram as create_l2m_diagram, | |
| parse_l2m_response | |
| ) | |
| from selfconsistency_reasoning import ( | |
| create_mermaid_diagram as create_scr_diagram, | |
| parse_scr_response | |
| ) | |
| from selfrefine_reasoning import ( | |
| create_mermaid_diagram as create_srf_diagram, | |
| parse_selfrefine_response | |
| ) | |
| from bs_reasoning import ( | |
| create_mermaid_diagram as create_bs_diagram, | |
| parse_bs_response | |
| ) | |
| from configs import config | |
| import logging | |
# Logging setup: INFO level, with timestamp, logger name and level on each record.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The Flask application object used by the view functions in this module.
app = Flask(__name__)
def index():
    """Return the rendered ``index.html`` template (the main page)."""
    page = render_template('index.html')
    return page
def index_direct():
    """Return the rendered ``index.html`` template.

    Intended for requests that address ``index.html`` directly; the route
    registration itself is not visible in this part of the file.
    """
    page = render_template('index.html')
    return page
def index_cn():
    """Return the rendered ``index_cn.html`` template (Chinese main page)."""
    page = render_template('index_cn.html')
    return page
def get_config():
    """Return the initial configuration values as a JSON response."""
    initial_values = config.get_initial_values()
    return jsonify(initial_values)
def get_method_config(method_id):
    """Return the configuration of one reasoning method as JSON.

    Responds with a 404 and an ``error`` field when ``config`` has no
    (truthy) entry for ``method_id``.
    """
    found = config.get_method_config(method_id)
    if not found:
        return jsonify({"error": "Method not found"}), 404
    return jsonify(found)
def get_provider_api_key(provider):
    """Return the default API key configured for ``provider`` as JSON.

    On success the payload is ``{'success': True, 'api_key': ...}``; any
    exception is logged and reported as a 500 with the exception text.
    """
    # NOTE(review): the configured key is sent back to the caller verbatim --
    # confirm this endpoint is not reachable by untrusted clients.
    try:
        default_key = config.general.get_default_api_key(provider)
        payload = {'success': True, 'api_key': default_key}
        return jsonify(payload)
    except Exception as e:
        logger.error(f"Error getting API key for provider {provider}: {str(e)}")
        failure = {'success': False, 'error': str(e)}
        return jsonify(failure), 500
def save_api_key():
    """Save API key for a provider in memory only (no file storage).

    Expects a JSON object with ``provider`` and ``api_key``. Responds with
    ``{'success': True}`` on success, a 400 when the body is missing, is not
    a JSON object, or lacks either field, and a 500 on unexpected errors.
    """
    try:
        # silent=True yields None for a non-JSON or malformed body instead of
        # raising an HTTP exception that the blanket handler below would
        # misreport as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not data:
            return jsonify({
                'success': False,
                'error': 'No data provided'
            }), 400

        provider = data.get('provider')
        api_key = data.get('api_key')
        if not provider or not api_key:
            return jsonify({
                'success': False,
                'error': 'Provider and API key are required'
            }), 400

        # Update API key in config (this updates the in-memory API keys only)
        config.general.provider_api_keys[provider] = api_key
        logger.info(f"Saved API key for provider: {provider} (in memory only)")
        return jsonify({
            'success': True
        })
    except Exception as e:
        logger.error(f"Error saving API key: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
def select_method():
    """Let the model select the most appropriate reasoning method.

    Expects a JSON object with ``api_key``, ``model`` and ``question`` (and
    optionally ``provider``, default ``'anthropic'``). The model is prompted
    to answer with a ``<selected_method>`` tag naming one of the keys of
    ``config.methods``.

    Responses:
        200 -- ``success``, ``selected_method`` and ``raw_response``.
        400 -- missing/invalid body, missing parameters, or a model answer
               that does not name a configured method.
        500 -- the API call failed or an unexpected error occurred.
    """
    import re

    try:
        # silent=True yields None for a non-JSON or malformed body instead of
        # raising an HTTP exception that would be misreported as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not data:
            return jsonify({'success': False, 'error': 'No data provided'}), 400

        # Extract parameters
        api_key = data.get('api_key')
        provider = data.get('provider', 'anthropic')
        model = data.get('model')
        question = data.get('question')
        if not all([api_key, model, question]):
            return jsonify({'success': False, 'error': 'Missing required parameters'}), 400

        # Create the selection prompt. The loop variable is deliberately not
        # called ``config`` so it cannot be confused with the module-level
        # configuration object.
        methods = config.methods
        method_lines = chr(10).join(
            f'- {method_id}: {method_cfg.name}'
            for method_id, method_cfg in methods.items()
        )
        method_ids = ', '.join(methods.keys())
        prompt = f"""Given this question: "{question}"
Please select the most appropriate reasoning method from the following options to solve it:
{method_lines}
Consider the characteristics of each method and the nature of the question.
Output your selection in exactly this format:
<selected_method>method_id</selected_method>
where method_id is strictly one of: {method_ids}.
Do not use the method or words that are not in {method_ids}."""

        # Get model's selection; only the upstream call is treated as an
        # "API call failed" condition.
        try:
            api = create_api(provider, api_key, model)
            response = api.generate_response(prompt, max_tokens=100)
        except Exception as e:
            logger.error(f"Method selection API call failed: {str(e)}")
            return jsonify({
                'success': False,
                'error': f'API call failed: {str(e)}'
            }), 500

        # Extract method ID using basic string parsing
        match = re.search(r'<selected_method>(\w+)</selected_method>', response)
        if match is None or match.group(1) not in methods:
            return jsonify({
                'success': False,
                'error': 'Invalid method selection in response'
            }), 400

        return jsonify({
            'success': True,
            'selected_method': match.group(1),
            'raw_response': response
        })
    except Exception as e:
        logger.error(f"Error in method selection: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500
def process():
    """Process the reasoning request.

    Expects a JSON object with ``api_key`` and ``question``. Optional fields:
    ``provider`` (default ``'anthropic'``), ``model``, ``max_tokens``,
    ``prompt_format``, ``chars_per_line``, ``max_lines`` and
    ``reasoning_method`` (default ``'cot'``).

    Responses:
        200 -- ``success``, ``raw_output`` and ``visualization``; the latter
               is ``None`` when no diagram was produced (visualization
               failures are logged but do not fail the request).
        400 -- missing/invalid body, missing required field, non-integer
               numeric option, or API initialisation failure.
        500 -- the model call failed or an unexpected error occurred.
    """
    try:
        # silent=True yields None for a non-JSON or malformed body instead of
        # raising an HTTP exception that would be misreported as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not data:
            return jsonify({
                'success': False,
                'error': 'No data provided'
            }), 400

        # Required parameters
        api_key = data.get('api_key')
        if not api_key:
            return jsonify({
                'success': False,
                'error': 'API key is required'
            }), 400

        question = data.get('question')
        if not question:
            return jsonify({
                'success': False,
                'error': 'Question is required'
            }), 400

        # Optional parameters with defaults
        provider = data.get('provider', 'anthropic')
        model = data.get('model', config.general.available_models[0])
        prompt_format = data.get('prompt_format')
        reasoning_method = data.get('reasoning_method', 'cot')

        # Bad numeric input is a client error, not a server error.
        try:
            max_tokens = int(data.get('max_tokens', config.general.max_tokens))
            chars_per_line = int(data.get('chars_per_line', config.general.chars_per_line))
            max_lines = int(data.get('max_lines', config.general.max_lines))
        except (TypeError, ValueError):
            return jsonify({
                'success': False,
                'error': 'max_tokens, chars_per_line and max_lines must be integers'
            }), 400

        # Initialize API with factory function
        try:
            api = create_api(provider, api_key, model)
        except Exception as e:
            return jsonify({
                'success': False,
                'error': f'Failed to initialize API: {str(e)}'
            }), 400

        # Get model response
        logger.info(f"Generating response for question using {provider} {model}")
        try:
            raw_response = api.generate_response(
                question,
                max_tokens=max_tokens,
                prompt_format=prompt_format
            )
        except Exception as e:
            return jsonify({
                'success': False,
                'error': f'API call failed: {str(e)}'
            }), 500

        # Create visualization config
        viz_config = VisualizationConfig(
            max_chars_per_line=chars_per_line,
            max_lines=max_lines
        )

        # Visualization is best-effort: a failure here is logged and the raw
        # model output is still returned.
        visualization = None
        try:
            visualization = _build_visualization(
                reasoning_method, raw_response, question, viz_config
            )
        except Exception as viz_error:
            logger.error(f"Visualization generation failed: {str(viz_error)}")
        else:
            if visualization is not None:
                logger.info("Successfully generated visualization")

        # Return successful response
        return jsonify({
            'success': True,
            'raw_output': raw_response,
            'visualization': visualization
        })
    except Exception as e:
        # Log the error and return error response
        logger.error(f"Error processing request: {str(e)}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500


def _build_visualization(reasoning_method, raw_response, question, viz_config):
    """Parse ``raw_response`` for ``reasoning_method`` and build its diagram.

    Returns whatever the method's diagram function returns, or ``None`` for
    the ``'plain'`` method (which is parsed but never visualized) and for
    method ids that have no registered parser.
    """
    if reasoning_method == 'plain':
        # The parser is still invoked so parse errors surface to the caller.
        parse_plain_text_response(raw_response, question)
        return None

    # Method id -> (response parser, diagram builder)
    pipelines = {
        'cot': (parse_cot_response, create_cot_diagram),
        'tot': (parse_tot_response, create_tot_diagram),
        'l2m': (parse_l2m_response, create_l2m_diagram),
        'scr': (parse_scr_response, create_scr_diagram),
        'srf': (parse_selfrefine_response, create_srf_diagram),
        'bs': (parse_bs_response, create_bs_diagram),
    }
    pipeline = pipelines.get(reasoning_method)
    if pipeline is None:
        logger.warning(f"No visualization available for reasoning method: {reasoning_method}")
        return None

    parse_response, create_diagram = pipeline
    return create_diagram(parse_response(raw_response, question), viz_config)
def not_found_error(error):
    """Return the JSON body and status used for 404 errors.

    ``error`` is accepted but not inspected.
    """
    body = {'success': False, 'error': 'Resource not found'}
    return jsonify(body), 404
def internal_error(error):
    """Return the JSON body and status used for 500 errors.

    ``error`` is accepted but not inspected.
    """
    body = {'success': False, 'error': 'Internal server error'}
    return jsonify(body), 500
if __name__ == '__main__':
    # Serve on all interfaces, port 5001; debug mode stays off in production.
    run_options = {'host': '0.0.0.0', 'port': 5001, 'debug': False}
    try:
        app.run(**run_options)
    except Exception as e:
        logger.error(f"Failed to start application: {str(e)}")