# SwiftSage Gradio demo app.
# NOTE(review): this file was recovered from a Spaces "Runtime error" page
# scrape; markdown-table residue and flattened indentation were cleaned up.
import json
import logging
import multiprocessing
import os

import gradio as gr
import numpy as np

from data_loader import load_data
from evaluate import evaluate
from main import SwiftSage, run_test, run_benchmark
from utils import (PromptTemplate, api_configs, setup_logging)
def solve_problem(problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, reward_model_id, use_retrieval, start_with_sage):
    """Run the SwiftSage multi-agent pipeline on a single problem.

    The ``*_model_id`` arguments are Together-hosted model identifiers; the
    numeric knobs arrive as strings from the Gradio textboxes.

    Returns:
        tuple[str, str]: the step-by-step reasoning trace and the cleaned
        final answer.
    """
    # Gradio textboxes hand us strings; the solver expects integers.
    iteration_cap = int(max_iterations)
    threshold = int(reward_threshold)

    # One config dict per LLM role; all three are served through Together.
    together_api = api_configs['Together']
    swift_config = {"model_id": swift_model_id, "api_config": together_api}
    reward_config = {"model_id": reward_model_id, "api_config": together_api}
    sage_config = {"model_id": sage_model_id, "api_config": together_api}

    # Location of the prompt templates on disk.
    prompt_template_dir = './prompt_templates'

    # Retrieval augmentation is not implemented yet, so both stay empty.
    dataset = []
    embeddings = []

    s2 = SwiftSage(
        dataset,
        embeddings,
        prompt_template_dir,
        swift_config,
        sage_config,
        reward_config,
        use_retrieval=use_retrieval,
        start_with_sage=start_with_sage,
    )
    reasoning, solution = s2.solve(problem, iteration_cap, threshold)
    # Strip the code-runner prefix from the displayed answer.
    solution = solution.replace("Answer (from running the code):\n ", " ")
    return reasoning, solution
# Gradio UI: model-selection row, collapsible advanced options, the problem
# input, and two read-only output panes wired to solve_problem.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Title is injected as raw HTML so it can be centered.
    gr.HTML("<h1 style='text-align: center;'>SwiftSage: A Multi-Agent Framework for Reasoning</h1>")

    # NOTE(review): the labels below contain what looks like mojibake-mangled
    # emoji ("π", "βοΈ", ...) — confirm against the deployed app before
    # changing the bytes; they are preserved verbatim here.
    with gr.Row():
        swift_model_id = gr.Textbox(label="π Swift Model ID", value="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo")
        reward_model_id = gr.Textbox(label="π€ Feedback Model ID", value="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo")
        sage_model_id = gr.Textbox(label="π Sage Model ID", value="meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo")

    with gr.Accordion(label="βοΈ Advanced Options", open=False):
        with gr.Row():
            with gr.Column():
                max_iterations = gr.Textbox(label="Max Iterations", value="5")
                reward_threshold = gr.Textbox(label="Reward Threshold", value="8")
            # TODO: the sampling knobs below are rendered but not yet passed
            # to the click handler (see the inputs= list further down).
            with gr.Column():
                top_p_swift = gr.Textbox(label="Top-p for Swift", value="0.9")
                temperature_swift = gr.Textbox(label="Temperature for Swift", value="0.7")
            with gr.Column():
                top_p_sage = gr.Textbox(label="Top-p for Sage", value="0.9")
                temperature_sage = gr.Textbox(label="Temperature for Sage", value="0.7")
            with gr.Column():
                top_p_reward = gr.Textbox(label="Top-p for Feedback", value="0.9")
                temperature_reward = gr.Textbox(label="Temperature for Feedback", value="0.7")

    # Hidden toggles: the backend supports these, but they are not exposed yet.
    use_retrieval = gr.Checkbox(label="Use Retrieval Augmentation", value=False, visible=False)
    start_with_sage = gr.Checkbox(label="Start with Sage", value=False, visible=False)

    problem = gr.Textbox(label="Input your problem", value="How many letter r are there in the sentence 'My strawberry is so ridiculously red.'?", lines=2)
    solve_button = gr.Button("π Solve Problem")

    reasoning_output = gr.Textbox(label="Reasoning steps with Code", interactive=False)
    solution_output = gr.Textbox(label="Final answer", interactive=False)

    solve_button.click(
        solve_problem,
        inputs=[problem, max_iterations, reward_threshold, swift_model_id, sage_model_id, reward_model_id, use_retrieval, start_with_sage],
        outputs=[reasoning_output, solution_output],
    )
if __name__ == '__main__':
    # The solver presumably runs generated code in subprocesses; 'spawn' is
    # the safe cross-platform start method. set_start_method raises
    # RuntimeError if a start method was already fixed for this interpreter
    # (e.g. on a hot reload / re-run), so tolerate that case instead of
    # crashing before the UI even launches.
    try:
        multiprocessing.set_start_method('spawn')
    except RuntimeError:
        pass
    # Local-only server; the auto-generated API docs page is disabled.
    demo.launch(share=False, show_api=False)