Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Test client for the ChatGPT Oasis Model Inference API | |
| """ | |
| import requests | |
| import base64 | |
| import json | |
| from PIL import Image | |
| import io | |
| import os | |
# API base URL. Set the OASIS_API_URL environment variable to point the client
# at another deployment; defaults to a local server on port 8000. A trailing
# slash is stripped so f"{BASE_URL}/health" never produces "//health".
BASE_URL = os.environ.get("OASIS_API_URL", "http://localhost:8000").rstrip("/")
def test_health_check(timeout=10.0):
    """Check that the API's /health endpoint answers with HTTP 200.

    Prints the status code and the response body: pretty-printed JSON when
    the body parses as JSON, the raw text otherwise.

    Args:
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive server.

    Returns:
        True if the endpoint returned HTTP 200, False otherwise (including
        when the request itself could not be completed).
    """
    print("Testing health check...")
    try:
        response = requests.get(f"{BASE_URL}/health", timeout=timeout)
    except requests.RequestException as e:
        # Connection refused, DNS failure, timeout, ...
        print(f"Error: {e}")
        return False
    print(f"Status: {response.status_code}")
    try:
        print(f"Response: {json.dumps(response.json(), indent=2)}")
    except ValueError:
        # Non-JSON body (e.g. an HTML error page from a proxy): show it
        # verbatim instead of hiding it behind a decode error.
        print(f"Response (non-JSON): {response.text}")
    return response.status_code == 200
def test_list_models(timeout=10.0):
    """Check that the API's /models endpoint answers with HTTP 200.

    Prints the status code and the response body: pretty-printed JSON when
    the body parses as JSON, the raw text otherwise.

    Args:
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive server.

    Returns:
        True if the endpoint returned HTTP 200, False otherwise (including
        when the request itself could not be completed).
    """
    print("\nTesting models list...")
    try:
        response = requests.get(f"{BASE_URL}/models", timeout=timeout)
    except requests.RequestException as e:
        # Connection refused, DNS failure, timeout, ...
        print(f"Error: {e}")
        return False
    print(f"Status: {response.status_code}")
    try:
        print(f"Response: {json.dumps(response.json(), indent=2)}")
    except ValueError:
        # Non-JSON body (e.g. an HTML error page from a proxy): show it
        # verbatim instead of hiding it behind a decode error.
        print(f"Response (non-JSON): {response.text}")
    return response.status_code == 200
def create_test_image(size=(224, 224), color='red', image_format='JPEG'):
    """Return the encoded bytes of a solid-colour RGB test image.

    Args:
        size: (width, height) in pixels; defaults to 224x224.
        color: Colour spec passed to ``PIL.Image.new`` (e.g. a name or RGB
            tuple); defaults to red.
        image_format: Pillow format name used to encode the image, e.g.
            'JPEG' (default) or 'PNG'.

    Returns:
        The encoded image as ``bytes``.
    """
    img = Image.new('RGB', size, color=color)
    with io.BytesIO() as buffer:
        img.save(buffer, format=image_format)
        # getvalue() returns the entire buffer regardless of the current
        # stream position, so no seek(0) is needed.
        return buffer.getvalue()
def test_base64_inference(model_names=("oasis500m", "vit-l-20"), timeout=300.0):
    """Exercise POST /inference with a base64-encoded JSON payload.

    Encodes the synthetic image from ``create_test_image()`` as base64 and
    sends it once per model, printing the status code, the model the server
    reports having used, and its first prediction. A failure for one model is
    printed and does not stop the remaining models.

    Args:
        model_names: Model identifiers to test, in order.
        timeout: Seconds to wait for each request; generous because model
            inference can be slow.
    """
    print("\nTesting base64 inference...")
    image_base64 = base64.b64encode(create_test_image()).decode("ascii")
    for model_name in model_names:
        print(f"\nTesting {model_name}...")
        try:
            # json= serialises the payload and sets
            # Content-Type: application/json automatically.
            response = requests.post(
                f"{BASE_URL}/inference",
                json={"image": image_base64, "model_name": model_name},
                timeout=timeout,
            )
            print(f"Status: {response.status_code}")
            if response.status_code != 200:
                print(f"Error: {response.text}")
                continue
            result = response.json()
            print(f"Model used: {result['model_used']}")
            # Guard against an empty list so it is reported plainly rather
            # than as an IndexError.
            predictions = result.get('predictions') or []
            top = predictions[0] if predictions else "none returned"
            print(f"Top prediction: {top}")
        except Exception as e:
            # Per-model boundary: report and move on to the next model.
            print(f"Error: {type(e).__name__}: {e}")
def test_file_upload_inference(model_names=("oasis500m", "vit-l-20"), timeout=300.0):
    """Exercise POST /upload_inference with a multipart file upload.

    Uploads the synthetic JPEG from ``create_test_image()`` once per model
    (model name sent as a form field), printing the status code, the model
    the server reports having used, and its first prediction. A failure for
    one model is printed and does not stop the remaining models.

    Args:
        model_names: Model identifiers to test, in order.
        timeout: Seconds to wait for each request; generous because model
            inference can be slow.
    """
    print("\nTesting file upload inference...")
    image_data = create_test_image()
    for model_name in model_names:
        print(f"\nTesting {model_name} with file upload...")
        try:
            response = requests.post(
                f"{BASE_URL}/upload_inference",
                files={'file': ('test_image.jpg', image_data, 'image/jpeg')},
                data={'model_name': model_name},
                timeout=timeout,
            )
            print(f"Status: {response.status_code}")
            if response.status_code != 200:
                print(f"Error: {response.text}")
                continue
            result = response.json()
            print(f"Model used: {result['model_used']}")
            # Guard against an empty list so it is reported plainly rather
            # than as an IndexError.
            predictions = result.get('predictions') or []
            top = predictions[0] if predictions else "none returned"
            print(f"Top prediction: {top}")
        except Exception as e:
            # Per-model boundary: report and move on to the next model.
            print(f"Error: {type(e).__name__}: {e}")
def test_with_real_image(image_path, model_name='oasis500m', timeout=300.0):
    """Upload an image file from disk to POST /upload_inference.

    Prints the model the server reports having used and up to three
    predictions (label and confidence). Returns early, after printing a
    message, if ``image_path`` does not exist.

    Args:
        image_path: Path to the image file to upload.
        model_name: Model identifier sent as the ``model_name`` form field.
        timeout: Seconds to wait for the request; generous because model
            inference can be slow.
    """
    import mimetypes  # local import: only this test needs MIME detection

    if not os.path.exists(image_path):
        print(f"Image file not found: {image_path}")
        return
    print(f"\nTesting with real image: {image_path}")
    # Derive the MIME type from the file extension so PNG/WebP files are not
    # mislabelled as JPEG; fall back to image/jpeg for unknown extensions.
    guessed_type, _ = mimetypes.guess_type(image_path)
    if guessed_type and guessed_type.startswith("image/"):
        mime_type = guessed_type
    else:
        mime_type = 'image/jpeg'
    try:
        with open(image_path, 'rb') as f:
            response = requests.post(
                f"{BASE_URL}/upload_inference",
                files={'file': (os.path.basename(image_path), f, mime_type)},
                data={'model_name': model_name},
                timeout=timeout,
            )
        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            print(f"Error: {response.text}")
            return
        result = response.json()
        print(f"Model used: {result['model_used']}")
        print("Top 3 predictions:")
        for rank, pred in enumerate(result['predictions'][:3], start=1):
            print(f" {rank}. {pred['label']} ({pred['confidence']:.3f})")
    except Exception as e:
        # Test-client boundary: report any failure instead of crashing.
        print(f"Error: {type(e).__name__}: {e}")
def main():
    """Run the full client test sequence against BASE_URL.

    The health check runs first; if it fails, no other endpoint is
    contacted. The model list, base64 inference and file-upload inference
    tests then run, followed by a real-image test on the first of
    test.jpg / sample.jpg / image.jpg found in the working directory.

    Returns:
        1 if the health check failed, otherwise 0. The later tests only print
        their results and do not affect the return value.
    """
    print("ChatGPT Oasis Model Inference API - Test Client")
    print("=" * 50)
    # Check health first: if the server is down, hitting the other
    # endpoints only produces noise.
    if not test_health_check():
        print("Health check failed. Make sure the server is running!")
        return 1
    if not test_list_models():
        # Don't abort: the inference endpoints may still work.
        print("Warning: models list endpoint failed; continuing with inference tests.")
    test_base64_inference()
    test_file_upload_inference()
    # Only the first existing candidate file is tested.
    candidates = ("test.jpg", "sample.jpg", "image.jpg")
    real_image = next((path for path in candidates if os.path.exists(path)), None)
    if real_image is not None:
        test_with_real_image(real_image)
    print("\n" + "=" * 50)
    print("Test completed!")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status; a None return
    # maps to exit code 0.
    raise SystemExit(main())