| import unittest | |
| import tests.context | |
| from autogpt.token_counter import count_message_tokens, count_string_tokens | |
class TestTokenCounter(unittest.TestCase):
    """Unit tests for ``count_message_tokens`` and ``count_string_tokens``."""

    def setUp(self):
        """Build the two-message conversation shared by the message tests."""
        # Rebuilt before every test so that no test can observe changes made
        # to the list by another test.
        self.messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

    def test_count_message_tokens(self):
        """The shared conversation counts as 17 tokens with the default model."""
        self.assertEqual(count_message_tokens(self.messages), 17)

    def test_count_message_tokens_with_name(self):
        """A ``name`` key on a message still yields 17 tokens by default."""
        messages = [
            {"role": "user", "content": "Hello", "name": "John"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        self.assertEqual(count_message_tokens(messages), 17)

    def test_count_message_tokens_empty_input(self):
        """An empty message list counts as 3 tokens."""
        self.assertEqual(count_message_tokens([]), 3)

    def test_count_message_tokens_invalid_model(self):
        """An unrecognised model name is expected to raise NotImplementedError."""
        # NOTE: keep exactly one test with this name. A second ``def`` with the
        # same name inside the class body silently replaces the first, so one
        # of the two expectations would never be executed.
        with self.assertRaises(NotImplementedError):
            count_message_tokens(self.messages, model="invalid_model")

    def test_count_message_tokens_gpt_4(self):
        """The shared conversation counts as 15 tokens for ``gpt-4-0314``."""
        self.assertEqual(
            count_message_tokens(self.messages, model="gpt-4-0314"), 15
        )

    def test_count_string_tokens(self):
        """``"Hello, world!"`` counts as 4 tokens for ``gpt-3.5-turbo-0301``."""
        string = "Hello, world!"
        self.assertEqual(
            count_string_tokens(string, model_name="gpt-3.5-turbo-0301"), 4
        )

    def test_count_string_tokens_empty_input(self):
        """An empty string counts as 0 tokens."""
        self.assertEqual(
            count_string_tokens("", model_name="gpt-3.5-turbo-0301"), 0
        )

    def test_count_string_tokens_gpt_4(self):
        """``"Hello, world!"`` counts as 4 tokens for ``gpt-4-0314``."""
        string = "Hello, world!"
        self.assertEqual(count_string_tokens(string, model_name="gpt-4-0314"), 4)
# Run the test cases defined in this module when the file is executed
# directly as a script rather than imported.
if __name__ == "__main__":
    unittest.main()