Fix: Update chat_template in tokenizer_config for proper system prompts and prefill rendering
Browse files## Problem
There are two problems with the current version of chat_template.
### 1. System prompt gets dropped when you add assistant's prefill.
To understand the problem fully, let's first see what happens normally when system prompt is specified.
when input messages are
```
messages=[{"role": "system", "content": "Be short. Answer question within 10 words"},
{"role": "user", "content": "talk excessively, do you understand? Answer question using above 10 words"},
{"role": "assistant", "content": "OKay I understand"},
{"role": "user", "content": "Where is San Jose? Answer using more than 10 words"}]
```
serialized instruction is
```
<s>[INST] talk excessively, do you understand? Answer question using above 10 words[/INST] OKay I understand</s>[INST] Be short. Answer question within 10 words\n\nWhere is San Jose? Answer using more than 10 words[/INST]
```
system message is prepended to last turn's user question
However,
when input messages are
```
messages=[{"role": "system", "content": "Be short. Answer question within 10 words"},
{"role": "user", "content": "talk excessively, do you understand? Answer question using above 10 words"},
{"role": "assistant", "content": "OKay I understand"},
{"role": "user", "content": "Where is San Jose?"},
{"role": "assistant", "content": "San Jose"}] # prefill asisstant's response
```
serialized instruction is (use apply_chat_template with tokenize=False)
`<s>[INST] talk excessively, do you understand? Answer question using above 10 words[/INST] OKay I understand</s>[INST] Where is San Jose?[/INST] San Jose</s>`
Current output drops system prompt.
Expected instruction should be
`<s>[INST] talk excessively, do you understand? Answer question using above 10 words[/INST] OKay I understand</s>[INST] Be short. Answer question within 10 words\n\nWhere is San Jose?[/INST] San Jose</s>`
### 2. Prefill is not triggered correctly
Currently, when you try to prefill the assistant, for example, to complete the rest of the output "I'm sorry. I can't help because", the serialized instruction incorrectly contains extra EOS token.
For example,
when input messages are
```
messages=[{"role": "user", "content": "Where is San Jose?"},
{"role": "assistant", "content": "I'm sorry. I can't help because"}]
```
serialized instruction is
`<s>[INST] Where is San Jose?[/INST] I'm sorry. I can't help because</s>`
Current output has </s> attached to end of assistant's response.
Expected instruction should be to have </s> only attached to end of assistant response when it's not the last turn, i,e.
`<s>[INST] Where is San Jose?[/INST] I'm sorry. I can't help because`
This allows for proper prefill.
## Solution
Update chat_template in tokenizer_config.
## Testing
Besides the above 2 test cases, also this one:
```
messages=[{"role": "system", "content": "Be short. Answer question within 10 words"},
{"role": "user", "content": "talk excessively, do you understand? Answer question using above 10 words"},
{"role": "assistant", "content": "OKay I understand"},
{"role": "user", "content": "Where is San Jose?"},
{"role": "assistant", "content": "San Jose is in California"},
{"role": "user", "content": "Where is San Jose?"},
{"role": "assistant", "content": "San Jose "}]
```
expected output of apply_chat_template
`<s>[INST] talk excessively, do you understand? Answer question using above 10 words[/INST] OKay I understand</s>[INST] Where is San Jose?[/INST] San Jose is in California</s>[INST] Be short. Answer question within 10 words\n\nWhere is San Jose?[/INST] San Jose`
- tokenizer_config.json +1 -1
|
@@ -6173,7 +6173,7 @@
|
|
| 6173 |
}
|
| 6174 |
},
|
| 6175 |
"bos_token": "<s>",
|
| 6176 |
-
"chat_template": "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"
|
| 6177 |
"clean_up_tokenization_spaces": false,
|
| 6178 |
"eos_token": "</s>",
|
| 6179 |
"legacy": false,
|
|
|
|
| 6173 |
}
|
| 6174 |
},
|
| 6175 |
"bos_token": "<s>",
|
| 6176 |
+
"chat_template": "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if (loop.last or loop.index == loop.length - 1) and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\n\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + (\"\" if loop.last else eos_token)}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}",
|
| 6177 |
"clean_up_tokenization_spaces": false,
|
| 6178 |
"eos_token": "</s>",
|
| 6179 |
"legacy": false,
|