Spaces:
Runtime error
Runtime error
Wonderplex
committed on
Commit
·
c3a4051
1
Parent(s):
c423c55
Feature/select agent env (#45)
Browse files
* changed UI to include scenario and agent info
* ui layout correct; need to fix logics
* half-way through; need to fix record reading and agent pair filtering logic
* fixed deletion of app.py
* debugging gradio change
* before debug
* finished UI features
* added 5 times retry
* finished merging
- app.py +48 -27
- requirements.txt +1 -0
- sotopia_pi_generate.py +3 -3
- utils.py +1 -1
app.py
CHANGED
|
@@ -12,7 +12,7 @@ with open("openai_api.key", "r") as f:
|
|
| 12 |
os.environ["OPENAI_API_KEY"] = f.read().strip()
|
| 13 |
|
| 14 |
DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
|
| 15 |
-
DEFAULT_MODEL_SELECTION = "
|
| 16 |
TEMPERATURE = 0.7
|
| 17 |
TOP_P = 1
|
| 18 |
MAX_TOKENS = 1024
|
|
@@ -100,6 +100,7 @@ def create_bot_agent_dropdown(environment_id, user_agent_id):
|
|
| 100 |
environment, user_agent = environment_dict[environment_id], agent_dict[user_agent_id]
|
| 101 |
|
| 102 |
bot_agent_list = []
|
|
|
|
| 103 |
for neighbor_id in relationship_dict[environment.relationship][user_agent.agent_id]:
|
| 104 |
bot_agent_list.append((agent_dict[neighbor_id].name, neighbor_id))
|
| 105 |
|
|
@@ -109,46 +110,62 @@ def create_environment_info(environment_dropdown):
|
|
| 109 |
_, environment_dict, _, _ = get_sotopia_profiles()
|
| 110 |
environment = environment_dict[environment_dropdown]
|
| 111 |
text = environment.scenario
|
| 112 |
-
return gr.Textbox(label="Scenario
|
| 113 |
|
| 114 |
-
def create_user_info(
|
| 115 |
-
_,
|
| 116 |
-
|
| 117 |
-
text = f"{user_agent.background} {user_agent.personality}
|
| 118 |
return gr.Textbox(label="User Agent Profile", lines=4, value=text)
|
| 119 |
|
| 120 |
-
def create_bot_info(
|
| 121 |
-
_,
|
| 122 |
-
|
| 123 |
-
|
|
|
|
| 124 |
return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
def sotopia_info_accordion(accordion_visible=True):
|
|
|
|
| 127 |
|
| 128 |
-
with gr.Accordion("
|
| 129 |
-
with gr.Column():
|
| 130 |
-
model_name_dropdown = gr.Dropdown(
|
| 131 |
-
choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo"],
|
| 132 |
-
value="cmu-lti/sotopia-pi-mistral-7b-BC_SR",
|
| 133 |
-
interactive=True,
|
| 134 |
-
label="Model Selection"
|
| 135 |
-
)
|
| 136 |
with gr.Row():
|
| 137 |
-
environments, _, _, _ = get_sotopia_profiles()
|
| 138 |
environment_dropdown = gr.Dropdown(
|
| 139 |
choices=environments,
|
| 140 |
label="Scenario Selection",
|
| 141 |
value=environments[0][1] if environments else None,
|
| 142 |
interactive=True,
|
| 143 |
)
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
user_agent_dropdown = create_user_agent_dropdown(environment_dropdown.value)
|
| 146 |
bot_agent_dropdown = create_bot_agent_dropdown(environment_dropdown.value, user_agent_dropdown.value)
|
| 147 |
|
| 148 |
with gr.Row():
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
bot_agent_info_display = create_bot_info(environment_dropdown.value, bot_agent_dropdown.value)
|
| 152 |
|
| 153 |
# Update user dropdown when scenario changes
|
| 154 |
environment_dropdown.change(fn=create_user_agent_dropdown, inputs=[environment_dropdown], outputs=[user_agent_dropdown])
|
|
@@ -157,9 +174,13 @@ def sotopia_info_accordion(accordion_visible=True):
|
|
| 157 |
# Update scenario information when scenario changes
|
| 158 |
environment_dropdown.change(fn=create_environment_info, inputs=[environment_dropdown], outputs=[scenario_info_display])
|
| 159 |
# Update user agent profile when user changes
|
| 160 |
-
user_agent_dropdown.change(fn=create_user_info, inputs=[
|
| 161 |
# Update bot agent profile when bot changes
|
| 162 |
-
bot_agent_dropdown.change(fn=create_bot_info, inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
return model_name_dropdown, environment_dropdown, user_agent_dropdown, bot_agent_dropdown
|
| 165 |
|
|
@@ -192,12 +213,12 @@ def chat_tab():
|
|
| 192 |
user_agent = agent_dict[user_agent_dropdown]
|
| 193 |
bot_agent = agent_dict[bot_agent_dropdown]
|
| 194 |
|
| 195 |
-
import pdb; pdb.set_trace()
|
| 196 |
context = get_context_prompt(bot_agent, user_agent, environment)
|
| 197 |
dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
|
| 198 |
prompt_history = f"{context}\n\n{dialogue_history}"
|
| 199 |
agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
|
| 200 |
-
import pdb; pdb.set_trace()
|
| 201 |
return agent_action.to_natural_language()
|
| 202 |
|
| 203 |
with gr.Column():
|
|
|
|
| 12 |
os.environ["OPENAI_API_KEY"] = f.read().strip()
|
| 13 |
|
| 14 |
DEPLOYED = os.getenv("DEPLOYED", "true").lower() == "true"
|
| 15 |
+
DEFAULT_MODEL_SELECTION = "gpt-3.5-turbo" # "mistralai/Mistral-7B-Instruct-v0.1"
|
| 16 |
TEMPERATURE = 0.7
|
| 17 |
TOP_P = 1
|
| 18 |
MAX_TOKENS = 1024
|
|
|
|
| 100 |
environment, user_agent = environment_dict[environment_id], agent_dict[user_agent_id]
|
| 101 |
|
| 102 |
bot_agent_list = []
|
| 103 |
+
# import pdb; pdb.set_trace()
|
| 104 |
for neighbor_id in relationship_dict[environment.relationship][user_agent.agent_id]:
|
| 105 |
bot_agent_list.append((agent_dict[neighbor_id].name, neighbor_id))
|
| 106 |
|
|
|
|
| 110 |
_, environment_dict, _, _ = get_sotopia_profiles()
|
| 111 |
environment = environment_dict[environment_dropdown]
|
| 112 |
text = environment.scenario
|
| 113 |
+
return gr.Textbox(label="Scenario", lines=1, value=text)
|
| 114 |
|
| 115 |
+
def create_user_info(user_agent_dropdown):
|
| 116 |
+
_, _, agent_dict, _ = get_sotopia_profiles()
|
| 117 |
+
user_agent = agent_dict[user_agent_dropdown]
|
| 118 |
+
text = f"{user_agent.background} {user_agent.personality}"
|
| 119 |
return gr.Textbox(label="User Agent Profile", lines=4, value=text)
|
| 120 |
|
| 121 |
+
def create_bot_info(bot_agent_dropdown):
|
| 122 |
+
_, _, agent_dict, _ = get_sotopia_profiles()
|
| 123 |
+
# import pdb; pdb.set_trace()
|
| 124 |
+
bot_agent = agent_dict[bot_agent_dropdown]
|
| 125 |
+
text = f"{bot_agent.background} {bot_agent.personality}"
|
| 126 |
return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
|
| 127 |
|
| 128 |
+
def create_user_goal(environment_dropdown):
|
| 129 |
+
_, environment_dict, _, _ = get_sotopia_profiles()
|
| 130 |
+
text = environment_dict[environment_dropdown].agent_goals[0]
|
| 131 |
+
return gr.Textbox(label="User Agent Goal", lines=4, value=text)
|
| 132 |
+
|
| 133 |
+
def create_bot_goal(environment_dropdown):
|
| 134 |
+
_, environment_dict, _, _ = get_sotopia_profiles()
|
| 135 |
+
text = environment_dict[environment_dropdown].agent_goals[1]
|
| 136 |
+
return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
|
| 137 |
+
|
| 138 |
def sotopia_info_accordion(accordion_visible=True):
|
| 139 |
+
environments, _, _, _ = get_sotopia_profiles()
|
| 140 |
|
| 141 |
+
with gr.Accordion("Environment Configuration", open=accordion_visible):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
with gr.Row():
|
|
|
|
| 143 |
environment_dropdown = gr.Dropdown(
|
| 144 |
choices=environments,
|
| 145 |
label="Scenario Selection",
|
| 146 |
value=environments[0][1] if environments else None,
|
| 147 |
interactive=True,
|
| 148 |
)
|
| 149 |
+
model_name_dropdown = gr.Dropdown(
|
| 150 |
+
choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo", "gpt-4-turbo"],
|
| 151 |
+
value=DEFAULT_MODEL_SELECTION,
|
| 152 |
+
interactive=True,
|
| 153 |
+
label="Model Selection"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
scenario_info_display = create_environment_info(environment_dropdown.value)
|
| 157 |
+
|
| 158 |
+
with gr.Row():
|
| 159 |
+
bot_goal_display = create_bot_goal(environment_dropdown.value)
|
| 160 |
+
user_goal_display = create_user_goal(environment_dropdown.value)
|
| 161 |
+
|
| 162 |
+
with gr.Row():
|
| 163 |
user_agent_dropdown = create_user_agent_dropdown(environment_dropdown.value)
|
| 164 |
bot_agent_dropdown = create_bot_agent_dropdown(environment_dropdown.value, user_agent_dropdown.value)
|
| 165 |
|
| 166 |
with gr.Row():
|
| 167 |
+
user_agent_info_display = create_user_info(user_agent_dropdown.value)
|
| 168 |
+
bot_agent_info_display = create_bot_info(bot_agent_dropdown.value)
|
|
|
|
| 169 |
|
| 170 |
# Update user dropdown when scenario changes
|
| 171 |
environment_dropdown.change(fn=create_user_agent_dropdown, inputs=[environment_dropdown], outputs=[user_agent_dropdown])
|
|
|
|
| 174 |
# Update scenario information when scenario changes
|
| 175 |
environment_dropdown.change(fn=create_environment_info, inputs=[environment_dropdown], outputs=[scenario_info_display])
|
| 176 |
# Update user agent profile when user changes
|
| 177 |
+
user_agent_dropdown.change(fn=create_user_info, inputs=[user_agent_dropdown], outputs=[user_agent_info_display])
|
| 178 |
# Update bot agent profile when bot changes
|
| 179 |
+
bot_agent_dropdown.change(fn=create_bot_info, inputs=[bot_agent_dropdown], outputs=[bot_agent_info_display])
|
| 180 |
+
# Update user goal when scenario changes
|
| 181 |
+
environment_dropdown.change(fn=create_user_goal, inputs=[environment_dropdown], outputs=[user_goal_display])
|
| 182 |
+
# Update bot goal when scenario changes
|
| 183 |
+
environment_dropdown.change(fn=create_bot_goal, inputs=[environment_dropdown], outputs=[bot_goal_display])
|
| 184 |
|
| 185 |
return model_name_dropdown, environment_dropdown, user_agent_dropdown, bot_agent_dropdown
|
| 186 |
|
|
|
|
| 213 |
user_agent = agent_dict[user_agent_dropdown]
|
| 214 |
bot_agent = agent_dict[bot_agent_dropdown]
|
| 215 |
|
| 216 |
+
# import pdb; pdb.set_trace()
|
| 217 |
context = get_context_prompt(bot_agent, user_agent, environment)
|
| 218 |
dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
|
| 219 |
prompt_history = f"{context}\n\n{dialogue_history}"
|
| 220 |
agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
|
| 221 |
+
# import pdb; pdb.set_trace()
|
| 222 |
return agent_action.to_natural_language()
|
| 223 |
|
| 224 |
with gr.Column():
|
requirements.txt
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
gradio
|
| 2 |
transformers
|
| 3 |
torch
|
|
|
|
| 1 |
+
sotopia
|
| 2 |
gradio
|
| 3 |
transformers
|
| 4 |
torch
|
sotopia_pi_generate.py
CHANGED
|
@@ -113,7 +113,7 @@ def obtain_chain_hf(
|
|
| 113 |
model, tokenizer = prepare_model(model_name)
|
| 114 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
|
| 115 |
hf = HuggingFacePipeline(pipeline=pipe)
|
| 116 |
-
import pdb; pdb.set_trace()
|
| 117 |
chain = LLMChain(llm=hf, prompt=chat_prompt_template)
|
| 118 |
return chain
|
| 119 |
|
|
@@ -124,7 +124,7 @@ def generate(
|
|
| 124 |
output_parser: BaseOutputParser[OutputType],
|
| 125 |
temperature: float = 0.7,
|
| 126 |
) -> tuple[OutputType, str]:
|
| 127 |
-
import pdb; pdb.set_trace()
|
| 128 |
input_variables = re.findall(r"{(.*?)}", template)
|
| 129 |
assert (
|
| 130 |
set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
|
|
@@ -136,7 +136,7 @@ def generate(
|
|
| 136 |
if "format_instructions" not in input_values:
|
| 137 |
input_values["format_instructions"] = output_parser.get_format_instructions()
|
| 138 |
result = chain.predict([], **input_values)
|
| 139 |
-
import pdb; pdb.set_trace()
|
| 140 |
try:
|
| 141 |
parsed_result = output_parser.parse(result)
|
| 142 |
except KeyboardInterrupt:
|
|
|
|
| 113 |
model, tokenizer = prepare_model(model_name)
|
| 114 |
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=max_tokens, temperature=temperature)
|
| 115 |
hf = HuggingFacePipeline(pipeline=pipe)
|
| 116 |
+
# import pdb; pdb.set_trace()
|
| 117 |
chain = LLMChain(llm=hf, prompt=chat_prompt_template)
|
| 118 |
return chain
|
| 119 |
|
|
|
|
| 124 |
output_parser: BaseOutputParser[OutputType],
|
| 125 |
temperature: float = 0.7,
|
| 126 |
) -> tuple[OutputType, str]:
|
| 127 |
+
# import pdb; pdb.set_trace()
|
| 128 |
input_variables = re.findall(r"{(.*?)}", template)
|
| 129 |
assert (
|
| 130 |
set(input_variables) == set(list(input_values.keys()) + ["format_instructions"])
|
|
|
|
| 136 |
if "format_instructions" not in input_values:
|
| 137 |
input_values["format_instructions"] = output_parser.get_format_instructions()
|
| 138 |
result = chain.predict([], **input_values)
|
| 139 |
+
# import pdb; pdb.set_trace()
|
| 140 |
try:
|
| 141 |
parsed_result = output_parser.parse(result)
|
| 142 |
except KeyboardInterrupt:
|
utils.py
CHANGED
|
@@ -74,7 +74,7 @@ def truncate_dialogue_history_to_length(dia_his, surpass_num, tokenizer):
|
|
| 74 |
|
| 75 |
|
| 76 |
def format_bot_message(bot_message) -> str:
|
| 77 |
-
# import pdb; pdb.set_trace()
|
| 78 |
start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
|
| 79 |
if end_idx == -1:
|
| 80 |
bot_message += "'}"
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def format_bot_message(bot_message) -> str:
|
| 77 |
+
# # import pdb; pdb.set_trace()
|
| 78 |
start_idx, end_idx = bot_message.index("{"), bot_message.index("}")
|
| 79 |
if end_idx == -1:
|
| 80 |
bot_message += "'}"
|