Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Final test for YouTube question classification and tool selection
"""

from question_classifier import QuestionClassifier

# Substrings that mark a question as YouTube-related. 'youtu.be' is listed
# separately because the short-link host does not contain the word 'youtube'.
_YOUTUBE_MARKERS = ('youtube', 'youtu.be')


def _mentions_youtube(question):
    """Return True if the question text contains a YouTube marker (case-insensitive)."""
    lowered = question.lower()
    return any(marker in lowered for marker in _YOUTUBE_MARKERS)


def test_classification():
    """Test that our classification improvements for YouTube questions are working.

    Runs each hard-coded question through ``QuestionClassifier.classify_question``
    and prints a per-case report followed by a summary line. A case passes when
    the returned ``'primary_agent'`` equals the expected agent and the tool
    check succeeds:

    * if a tool is expected, it must appear in ``'tools_needed'``;
    * otherwise ``'analyze_youtube_video'`` must not be selected unless the
      question itself mentions YouTube.

    Nothing is asserted or returned; the outcome is reported on stdout only.
    """
    # Initialize classifier
    classifier = QuestionClassifier()

    # Test cases
    test_cases = [
        {
            'question': 'In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video',
        },
        {
            'question': 'Tell me about the video at youtu.be/dQw4w9WgXcQ',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video',
        },
        {
            'question': 'What does Teal\'c say in the YouTube video youtube.com/watch?v=XYZ123?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_youtube_video',
        },
        {
            'question': 'How many birds appear in this image?',
            'expected_agent': 'multimedia',
            'expected_tool': 'analyze_image_with_gemini',
        },
        {
            'question': 'When was the first Star Wars movie released?',
            'expected_agent': 'research',
            'expected_tool': None,
        },
    ]

    print("🧪 Testing Question Classification for YouTube Questions")
    print("=" * 70)

    passed = 0
    for number, case in enumerate(test_cases, start=1):
        question = case['question']
        expected_tool = case['expected_tool']
        is_youtube = _mentions_youtube(question)

        print(f"\nTest {number}: {question[:80]}...")

        # Classify the question
        classification = classifier.classify_question(question)
        # Read the tool list once; a result without 'tools_needed' is treated
        # as "no tools selected".
        tools = classification.get('tools_needed', [])

        # Check primary agent type
        agent_correct = classification['primary_agent'] == case['expected_agent']

        # Check if expected tool is in tools list
        if expected_tool:
            tool_correct = expected_tool in tools
        else:
            # If no specific tool expected, just make sure analyze_youtube_video
            # isn't incorrectly selected for non-YouTube questions
            tool_correct = 'analyze_youtube_video' not in tools or is_youtube

        # Print results
        print(f"Expected agent: {case['expected_agent']}")
        print(f"Actual agent: {classification['primary_agent']}")
        print(f"Agent match: {'✅' if agent_correct else '❌'}")
        print(f"Expected tool: {expected_tool}")
        print(f"Selected tools: {tools}")
        print(f"Tool match: {'✅' if tool_correct else '❌'}")

        # Check which tool was selected first. This is informational only and
        # does not affect whether the case passes.
        if tools and is_youtube:
            if tools[0] == 'analyze_youtube_video':
                print("✅ analyze_youtube_video correctly prioritized for YouTube question")
            else:
                print("❌ analyze_youtube_video not prioritized for YouTube question")

        # Print overall result
        if agent_correct and tool_correct:
            passed += 1
            print("✅ TEST PASSED")
        else:
            print("❌ TEST FAILED")

    # Print summary
    print("\n" + "=" * 70)
    print(f"Final result: {passed}/{len(test_cases)} tests passed")
    if passed == len(test_cases):
        print("🎉 All tests passed! The classifier is working correctly.")
    else:
        print("⚠️ Some tests failed. Further improvements needed.")


if __name__ == "__main__":
    test_classification()