Spaces:
Sleeping
Sleeping
| import asyncio | |
| import pytest | |
| from lerobot_arena_client import RoboticsConsumer | |
class TestRoboticsConsumer:
    """Integration tests for ``RoboticsConsumer``.

    The tests connect real clients to a server expected at ``BASE_URL``.
    The ``consumer``, ``test_room``, ``connected_consumer`` and
    ``producer_consumer_pair`` fixtures are not defined in this file --
    presumably they live in ``conftest.py``; verify there for their exact
    setup/teardown behaviour.
    """

    # Address of the server under test, written down once instead of being
    # repeated as a literal in every test that builds its own client.
    BASE_URL = "http://localhost:8000"

    # Pause that lets a freshly joined producer/consumer pair settle before
    # the producer starts sending.
    SETTLE_DELAY = 0.1

    # Fixed wait for server-side state to become visible through
    # ``get_state_sync`` (there is no callback to poll on for that path).
    PROPAGATION_DELAY = 0.2

    # Upper bound on how long a test waits for an asynchronous callback.
    DELIVERY_TIMEOUT = 2.0

    @staticmethod
    async def _wait_until(predicate, timeout=DELIVERY_TIMEOUT, interval=0.01):
        """Poll ``predicate`` until it is truthy or ``timeout`` seconds pass.

        Replaces "sleep a fixed time, then assert": the wait ends as soon as
        the condition holds, and slow machines get up to ``timeout`` seconds
        instead of failing spuriously.

        Args:
            predicate: Zero-argument callable evaluated on every poll.
            timeout: Maximum number of seconds to keep polling.
            interval: Seconds to sleep between two polls.

        Returns:
            ``True`` if the predicate became truthy in time, else ``False``.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not predicate():
            if loop.time() >= deadline:
                return False
            await asyncio.sleep(interval)
        return True

    async def test_consumer_connection(self, consumer, test_room):
        """Connect to a room, check the reported state, then disconnect."""
        assert not consumer.is_connected()

        success = await consumer.connect(test_room)
        assert success is True
        assert consumer.is_connected()
        assert consumer.room_id == test_room
        assert consumer.role == "consumer"

        await consumer.disconnect()
        assert not consumer.is_connected()

    async def test_consumer_connection_info(self, connected_consumer):
        """Check the fields returned by ``get_connection_info``."""
        consumer, room_id = connected_consumer

        info = consumer.get_connection_info()
        assert info["connected"] is True
        assert info["room_id"] == room_id
        assert info["role"] == "consumer"
        assert info["participant_id"] is not None
        assert info["base_url"] == self.BASE_URL

    async def test_get_state_sync(self, connected_consumer):
        """A room nobody has written to reports an empty state dict."""
        consumer, _ = connected_consumer

        state = await consumer.get_state_sync()
        assert isinstance(state, dict)
        # Initial state should be empty
        assert state == {}

    async def test_consumer_callbacks_setup(self, consumer, test_room):
        """Register every callback and check the connection ones fire."""
        # Names of the callbacks that have been invoked so far.
        fired = set()

        def on_state_sync(state):
            fired.add("state_sync")

        def on_joint_update(joints):
            fired.add("joint_update")

        def on_error(error):
            fired.add("error")

        def on_connected():
            fired.add("connected")

        def on_disconnected():
            fired.add("disconnected")

        # Set all callbacks
        consumer.on_state_sync(on_state_sync)
        consumer.on_joint_update(on_joint_update)
        consumer.on_error(on_error)
        consumer.on_connected(on_connected)
        consumer.on_disconnected(on_disconnected)

        # Connect and test connection callbacks
        await consumer.connect(test_room)
        assert await self._wait_until(
            lambda: "connected" in fired
        ), "on_connected callback was not invoked"

        await consumer.disconnect()
        assert await self._wait_until(
            lambda: "disconnected" in fired
        ), "on_disconnected callback was not invoked"

    async def test_multiple_consumers(self, test_room):
        """Two independent consumers can join the same room."""
        consumer1 = RoboticsConsumer(self.BASE_URL)
        consumer2 = RoboticsConsumer(self.BASE_URL)

        try:
            # Both consumers should be able to connect
            success1 = await consumer1.connect(test_room)
            success2 = await consumer2.connect(test_room)

            assert success1 is True
            assert success2 is True
            assert consumer1.is_connected()
            assert consumer2.is_connected()
        finally:
            # Nested so that a failure while closing the first client cannot
            # leave the second one connected.
            try:
                if consumer1.is_connected():
                    await consumer1.disconnect()
            finally:
                if consumer2.is_connected():
                    await consumer2.disconnect()

    async def test_consumer_receive_state_sync(self, producer_consumer_pair):
        """A producer state sync reaches the consumer as joint updates."""
        producer, consumer, _ = producer_consumer_pair

        received_states = []
        received_updates = []

        def on_state_sync(state):
            received_states.append(state)

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_state_sync(on_state_sync)
        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Producer sends state sync (which gets converted to joint updates)
        await producer.send_state_sync({"shoulder": 45.0, "elbow": -20.0})

        # The initial state sync during connection might be empty, so the
        # check is on joint updates rather than on ``received_states``.
        assert await self._wait_until(
            lambda: len(received_updates) >= 1
        ), "no joint update received after producer state sync"

    async def test_consumer_receive_joint_updates(self, producer_consumer_pair):
        """A producer joint update is delivered to the consumer as a list."""
        producer, consumer, _ = producer_consumer_pair

        received_updates = []

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Producer sends joint updates
        test_joints = [
            {"name": "shoulder", "value": 45.0},
            {"name": "elbow", "value": -20.0},
        ]
        await producer.send_joint_update(test_joints)

        assert await self._wait_until(
            lambda: len(received_updates) >= 1
        ), "no joint update received from producer"

        received_joints = received_updates[-1]
        assert isinstance(received_joints, list)
        assert len(received_joints) == 2

    async def test_consumer_multiple_updates(self, producer_consumer_pair):
        """Several consecutive producer updates reach the consumer."""
        producer, consumer, _ = producer_consumer_pair

        received_updates = []

        def on_joint_update(joints):
            received_updates.append(joints)

        consumer.on_joint_update(on_joint_update)

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Send multiple updates
        for i in range(5):
            await producer.send_state_sync({
                "joint1": float(i * 10),
                "joint2": float(i * -5),
            })
            await asyncio.sleep(0.05)

        # Five updates are sent; at least three must arrive.
        assert await self._wait_until(
            lambda: len(received_updates) >= 3
        ), f"expected >= 3 joint updates, got {len(received_updates)}"

    async def test_consumer_emergency_stop(self, producer_consumer_pair):
        """A producer emergency stop surfaces through the error callback."""
        producer, consumer, _ = producer_consumer_pair

        received_errors = []

        def on_error(error):
            received_errors.append(error)

        consumer.on_error(on_error)

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Producer sends emergency stop
        await producer.send_emergency_stop("Test emergency stop")

        # Consumer should have received emergency stop as error
        assert await self._wait_until(
            lambda: len(received_errors) >= 1
        ), "no error received after emergency stop"
        assert "emergency stop" in received_errors[-1].lower()

    async def test_custom_participant_id(self, consumer, test_room):
        """A caller-supplied participant ID is reported back unchanged."""
        custom_id = "custom-consumer-456"

        await consumer.connect(test_room, participant_id=custom_id)

        info = consumer.get_connection_info()
        assert info["participant_id"] == custom_id

    async def test_context_manager(self, test_room):
        """Leaving the ``async with`` block disconnects the consumer."""
        async with RoboticsConsumer(self.BASE_URL) as consumer:
            await consumer.connect(test_room)
            assert consumer.is_connected()

            state = await consumer.get_state_sync()
            assert isinstance(state, dict)

        # Should be disconnected after context exit
        assert not consumer.is_connected()

    async def test_get_state_without_connection(self, consumer):
        """``get_state_sync`` raises ``ValueError`` when not connected."""
        assert not consumer.is_connected()

        with pytest.raises(ValueError, match="Must be connected to a room"):
            await consumer.get_state_sync()

    async def test_consumer_reconnection(self, consumer, test_room):
        """A consumer can rejoin the same room after disconnecting."""
        # First connection
        await consumer.connect(test_room)
        assert consumer.is_connected()

        await consumer.disconnect()
        assert not consumer.is_connected()

        # Reconnect to same room
        await consumer.connect(test_room)
        assert consumer.is_connected()
        assert consumer.room_id == test_room

    async def test_consumer_state_after_producer_updates(self, producer_consumer_pair):
        """State fetched by the consumer reflects what the producer sent."""
        producer, consumer, _ = producer_consumer_pair

        # Give some time for connection to stabilize
        await asyncio.sleep(self.SETTLE_DELAY)

        # Producer sends some state updates
        sent_state = {
            "shoulder": 45.0,
            "elbow": -20.0,
            "wrist": 10.0,
        }
        await producer.send_state_sync(sent_state)

        # Wait for state to propagate
        await asyncio.sleep(self.PROPAGATION_DELAY)

        # Consumer should be able to get updated state
        state = await consumer.get_state_sync()
        assert isinstance(state, dict)

        # Asserted unconditionally: an empty state means the update never
        # propagated, which is exactly the failure this test must catch.
        assert set(state) == set(sent_state)