"""Tests for the Redis-backed ``DistributedEventQueue`` and ``DistributedQueueManager``.

These tests talk to a live Redis server at ``REDIS_URL`` (overridable with the
``TEST_REDIS_URL`` environment variable). The selected database is emptied
before and after every test that uses the ``redis_client`` fixture, so point
it at a disposable instance.
"""

import asyncio
import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import redis.asyncio as redis

from a2a_distributed_queue import (
    DistributedEventQueue,
    DistributedQueueManager,
)
from a2a.server.events.event_queue import EventQueue
from a2a.server.events.queue_manager import (
    NoTaskQueue,
    TaskQueueExists,
)
from a2a.types import Message, Task, TaskStatusUpdateEvent

# Single source of truth for the server used by every fixture. Env override so
# CI can target a dedicated instance without editing the file.
REDIS_URL = os.environ.get("TEST_REDIS_URL", "redis://localhost:6379")

MINIMAL_TASK = {
    'id': '123',
    'context_id': 'session-xyz',
    'status': {'state': 'submitted'},
    'kind': 'task',
}
MESSAGE_PAYLOAD = {
    'role': 'agent',
    'parts': [{'text': 'test message'}],
    'message_id': '111',
}


# NOTE(review): an async-generator fixture declared with plain
# ``@pytest.fixture`` is only awaited when pytest-asyncio runs in ``auto``
# mode; in the default ``strict`` mode it needs ``@pytest_asyncio.fixture``.
# Confirm the project's pytest-asyncio configuration.
@pytest.fixture
async def redis_client():
    """Yield an asyncio Redis client bound to a freshly emptied database.

    The database is flushed before the test, so leftovers from an aborted
    earlier run cannot leak in. It is flushed again afterwards, and the
    client connection is closed on teardown.
    """
    client = redis.from_url(REDIS_URL, decode_responses=True)
    # flushdb (not flushall) so other logical databases on the server survive.
    await client.flushdb()
    yield client
    try:
        await client.flushdb()
    finally:
        await client.aclose()


@pytest.fixture
def distributed_queue_manager(redis_client):
    """Create a DistributedQueueManager against the test Redis server.

    Depends on ``redis_client`` only so the database is flushed around the
    test; the manager opens its own connection from ``REDIS_URL``.
    """
    return DistributedQueueManager(redis_url=REDIS_URL)


@pytest.fixture
def distributed_event_queue(redis_client):
    """Create a DistributedEventQueue for task ``test_task``."""
    return DistributedEventQueue("test_task", redis_client)


class TestDistributedEventQueue:
    """Behavioural tests for a single DistributedEventQueue."""

    @pytest.mark.asyncio
    async def test_init(self, distributed_event_queue):
        """Test that DistributedEventQueue initializes correctly."""
        assert distributed_event_queue.task_id == "test_task"
        assert distributed_event_queue.redis is not None
        assert distributed_event_queue.channel == "task:test_task"
        assert not distributed_event_queue.is_closed()

    @pytest.mark.asyncio
    async def test_enqueue_and_dequeue_event(self, distributed_event_queue):
        """Test enqueue and dequeue of an event."""
        event = Message(**MESSAGE_PAYLOAD)
        await distributed_event_queue.enqueue_event(event)
        dequeued = await distributed_event_queue.dequeue_event(no_wait=True)
        assert dequeued.role == event.role
        assert dequeued.parts == event.parts

    @pytest.mark.asyncio
    async def test_dequeue_event_no_wait_empty(self, distributed_event_queue):
        """Test dequeue with no_wait on empty queue raises QueueEmpty."""
        with pytest.raises(asyncio.QueueEmpty):
            await distributed_event_queue.dequeue_event(no_wait=True)

    @pytest.mark.asyncio
    async def test_tap_creates_child(self, distributed_event_queue):
        """Test that tap creates a child queue."""
        child = distributed_event_queue.tap()
        assert isinstance(child, DistributedEventQueue)
        assert child in distributed_event_queue._children

    @pytest.mark.asyncio
    async def test_close(self, distributed_event_queue):
        """Test closing the queue."""
        await distributed_event_queue.close()
        assert distributed_event_queue.is_closed()

    @pytest.mark.asyncio
    async def test_close_immediate(self, distributed_event_queue):
        """Test closing with immediate=True."""
        event = Message(**MESSAGE_PAYLOAD)
        await distributed_event_queue.enqueue_event(event)
        await distributed_event_queue.close(immediate=True)
        assert distributed_event_queue.is_closed()


class TestDistributedQueueManager:
    """Tests for DistributedQueueManager's Redis-set bookkeeping."""

    @pytest.mark.asyncio
    async def test_init(self, distributed_queue_manager):
        """Test DistributedQueueManager initializes correctly."""
        assert distributed_queue_manager.redis is not None
        assert distributed_queue_manager.tasks_key == "task_queues"

    @pytest.mark.asyncio
    async def test_add_new_queue(self, distributed_queue_manager):
        """Test adding a new queue."""
        task_id = "test_task"
        queue = DistributedEventQueue(task_id, distributed_queue_manager.redis)
        await distributed_queue_manager.add(task_id, queue)
        # Check if task_id is in Redis set
        exists = await distributed_queue_manager.redis.sismember("task_queues", task_id)
        assert exists

    @pytest.mark.asyncio
    async def test_add_existing_queue_raises(self, distributed_queue_manager):
        """Test adding existing queue raises TaskQueueExists."""
        task_id = "test_task"
        queue = DistributedEventQueue(task_id, distributed_queue_manager.redis)
        await distributed_queue_manager.add(task_id, queue)
        with pytest.raises(TaskQueueExists):
            await distributed_queue_manager.add(task_id, queue)

    @pytest.mark.asyncio
    async def test_get_existing_queue(self, distributed_queue_manager):
        """Test getting an existing queue."""
        task_id = "test_task"
        await distributed_queue_manager.redis.sadd("task_queues", task_id)
        queue = await distributed_queue_manager.get(task_id)
        assert isinstance(queue, DistributedEventQueue)
        assert queue.task_id == task_id

    @pytest.mark.asyncio
    async def test_get_nonexistent_queue(self, distributed_queue_manager):
        """Test getting nonexistent queue returns None."""
        queue = await distributed_queue_manager.get("nonexistent")
        assert queue is None

    @pytest.mark.asyncio
    async def test_tap_existing_queue(self, distributed_queue_manager):
        """Test tapping an existing queue."""
        task_id = "test_task"
        await distributed_queue_manager.redis.sadd("task_queues", task_id)
        tapped = await distributed_queue_manager.tap(task_id)
        assert isinstance(tapped, DistributedEventQueue)

    @pytest.mark.asyncio
    async def test_tap_nonexistent_queue(self, distributed_queue_manager):
        """Test tapping nonexistent queue returns None."""
        tapped = await distributed_queue_manager.tap("nonexistent")
        assert tapped is None

    @pytest.mark.asyncio
    async def test_close_existing_queue(self, distributed_queue_manager):
        """Test closing an existing queue."""
        task_id = "test_task"
        await distributed_queue_manager.redis.sadd("task_queues", task_id)
        await distributed_queue_manager.close(task_id)
        exists = await distributed_queue_manager.redis.sismember("task_queues", task_id)
        assert not exists

    @pytest.mark.asyncio
    async def test_close_nonexistent_queue_raises(self, distributed_queue_manager):
        """Test closing nonexistent queue raises NoTaskQueue."""
        with pytest.raises(NoTaskQueue):
            await distributed_queue_manager.close("nonexistent")

    @pytest.mark.asyncio
    async def test_create_or_tap_new(self, distributed_queue_manager):
        """Test create_or_tap for new task."""
        task_id = "test_task"
        queue = await distributed_queue_manager.create_or_tap(task_id)
        assert isinstance(queue, DistributedEventQueue)
        exists = await distributed_queue_manager.redis.sismember("task_queues", task_id)
        assert exists

    @pytest.mark.asyncio
    async def test_create_or_tap_existing(self, distributed_queue_manager):
        """Test create_or_tap for existing task taps."""
        task_id = "test_task"
        await distributed_queue_manager.redis.sadd("task_queues", task_id)
        queue = await distributed_queue_manager.create_or_tap(task_id)
        assert isinstance(queue, DistributedEventQueue)

    @pytest.mark.asyncio
    async def test_concurrency(self, distributed_queue_manager):
        """Test concurrent add and get across many task ids."""

        async def add_task(task_id):
            queue = DistributedEventQueue(task_id, distributed_queue_manager.redis)
            await distributed_queue_manager.add(task_id, queue)
            return task_id

        async def get_task(task_id):
            return await distributed_queue_manager.get(task_id)

        task_ids = [f"task_{i}" for i in range(10)]

        # Add concurrently
        added = await asyncio.gather(*(add_task(tid) for tid in task_ids))
        assert set(added) == set(task_ids)

        # Get concurrently; gather preserves argument order, so each result
        # must be the queue for the matching task id.
        queues = await asyncio.gather(*(get_task(tid) for tid in task_ids))
        assert all(isinstance(q, DistributedEventQueue) for q in queues)
        assert [q.task_id for q in queues] == task_ids