stop processing history
Browse files- utils/swe_bench.py +0 -55
utils/swe_bench.py
CHANGED
|
@@ -9,49 +9,6 @@ def clean_git_patch(git_patch):
|
|
| 9 |
git_patch = git_patch[git_patch.index('diff'):]
|
| 10 |
return git_patch
|
| 11 |
|
| 12 |
-
def reformat_history(history):
|
| 13 |
-
new_history = []
|
| 14 |
-
cur_turn = []
|
| 15 |
-
for i, (action, observation) in enumerate(history):
|
| 16 |
-
|
| 17 |
-
# Compatibility mode: old format before refractor
|
| 18 |
-
if 'source' not in action:
|
| 19 |
-
return history
|
| 20 |
-
|
| 21 |
-
if i == 0:
|
| 22 |
-
assert action['action'] == 'message'
|
| 23 |
-
assert action['source'] == 'user'
|
| 24 |
-
# skip the initial instruction
|
| 25 |
-
continue
|
| 26 |
-
|
| 27 |
-
if action['source'] == 'agent':
|
| 28 |
-
# cleanup all previous turns
|
| 29 |
-
if len(cur_turn) == 1:
|
| 30 |
-
new_history.append(cur_turn[0])
|
| 31 |
-
elif len(cur_turn) == 2:
|
| 32 |
-
# one action from user, one action from agent
|
| 33 |
-
agent_msg_action, agent_msg_obs = cur_turn[0]
|
| 34 |
-
assert agent_msg_obs['observation'] == 'null'
|
| 35 |
-
user_msg_action, user_msg_obs = cur_turn[1]
|
| 36 |
-
assert user_msg_obs['observation'] == 'null'
|
| 37 |
-
# re-write user message to be a observation message
|
| 38 |
-
user_msg_action_as_obs = {
|
| 39 |
-
'observation': 'message',
|
| 40 |
-
'source': 'user',
|
| 41 |
-
'content': user_msg_action['args']['content'],
|
| 42 |
-
}
|
| 43 |
-
new_history.append((agent_msg_action, user_msg_action_as_obs))
|
| 44 |
-
elif len(cur_turn) == 0:
|
| 45 |
-
pass
|
| 46 |
-
else:
|
| 47 |
-
st.write(f'Unsupported #interactions per iteration: {len(cur_turn)}')
|
| 48 |
-
st.json(cur_turn)
|
| 49 |
-
raise ValueError(f'Unsupported #interactions per iteration: {len(cur_turn)}')
|
| 50 |
-
|
| 51 |
-
# reset new turn
|
| 52 |
-
cur_turn = []
|
| 53 |
-
cur_turn.append((action, observation))
|
| 54 |
-
return new_history
|
| 55 |
|
| 56 |
def _load_report_legacy(instance_id_to_status, report):
|
| 57 |
# instance_id to status
|
|
@@ -103,7 +60,6 @@ def load_df_from_selected_filepaths(select_filepaths):
|
|
| 103 |
# clear out git patch
|
| 104 |
if 'git_patch' in d:
|
| 105 |
d['git_patch'] = clean_git_patch(d['git_patch'])
|
| 106 |
-
d['history'] = reformat_history(d['history'])
|
| 107 |
if d['instance_id'] in instance_id_to_status:
|
| 108 |
d['fine_grained_report'] = dict(instance_id_to_status[d['instance_id']])
|
| 109 |
data.append(d)
|
|
@@ -139,13 +95,6 @@ def agg_stats(df):
|
|
| 139 |
test_result['test_errored'] = bool(entry['report'].get('test_errored', False))
|
| 140 |
test_result['patch_applied'] = bool(entry['report'].get('apply_test_patch_success', False))
|
| 141 |
|
| 142 |
-
# avg,std obs length
|
| 143 |
-
obs_lengths = []
|
| 144 |
-
for _, (_, obs) in enumerate(history):
|
| 145 |
-
if 'content' in obs:
|
| 146 |
-
obs_lengths.append(len(obs['content']))
|
| 147 |
-
obs_lengths = pd.Series(obs_lengths)
|
| 148 |
-
|
| 149 |
metrics = entry.get('metrics', {})
|
| 150 |
cost = metrics.get('accumulated_cost', None)
|
| 151 |
|
|
@@ -154,14 +103,10 @@ def agg_stats(df):
|
|
| 154 |
'instance_id': entry['instance_id'],
|
| 155 |
'agent_class': entry['metadata']['agent_class'],
|
| 156 |
'model_name': entry['metadata']['llm_config']['model'] if 'llm_config' in entry['metadata'] else entry['metadata']['model_name'],
|
| 157 |
-
'n_turns': len(history),
|
| 158 |
**test_result,
|
| 159 |
'agent_stuck_in_loop': agent_stuck_in_loop,
|
| 160 |
'contains_error': contains_error,
|
| 161 |
'cost': cost,
|
| 162 |
-
'obs_len_avg': round(obs_lengths.mean(), 0),
|
| 163 |
-
'obs_len_std': round(obs_lengths.std(), 0),
|
| 164 |
-
'obs_len_max': round(obs_lengths.max(), 0),
|
| 165 |
}
|
| 166 |
if 'swe_instance' in entry:
|
| 167 |
d.update(
|
|
|
|
| 9 |
git_patch = git_patch[git_patch.index('diff'):]
|
| 10 |
return git_patch
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def _load_report_legacy(instance_id_to_status, report):
|
| 14 |
# instance_id to status
|
|
|
|
| 60 |
# clear out git patch
|
| 61 |
if 'git_patch' in d:
|
| 62 |
d['git_patch'] = clean_git_patch(d['git_patch'])
|
|
|
|
| 63 |
if d['instance_id'] in instance_id_to_status:
|
| 64 |
d['fine_grained_report'] = dict(instance_id_to_status[d['instance_id']])
|
| 65 |
data.append(d)
|
|
|
|
| 95 |
test_result['test_errored'] = bool(entry['report'].get('test_errored', False))
|
| 96 |
test_result['patch_applied'] = bool(entry['report'].get('apply_test_patch_success', False))
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
metrics = entry.get('metrics', {})
|
| 99 |
cost = metrics.get('accumulated_cost', None)
|
| 100 |
|
|
|
|
| 103 |
'instance_id': entry['instance_id'],
|
| 104 |
'agent_class': entry['metadata']['agent_class'],
|
| 105 |
'model_name': entry['metadata']['llm_config']['model'] if 'llm_config' in entry['metadata'] else entry['metadata']['model_name'],
|
|
|
|
| 106 |
**test_result,
|
| 107 |
'agent_stuck_in_loop': agent_stuck_in_loop,
|
| 108 |
'contains_error': contains_error,
|
| 109 |
'cost': cost,
|
|
|
|
|
|
|
|
|
|
| 110 |
}
|
| 111 |
if 'swe_instance' in entry:
|
| 112 |
d.update(
|