Last active
October 18, 2025 15:18
-
-
Save DiTo97/bac619c9d7d133fcbfd69ac6b708b8ef to your computer and use it in GitHub Desktop.
A2A event queue and queue manager using Redis
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| import json | |
| import logging | |
| import uuid | |
| import redis.asyncio as redis | |
| from a2a.server.events.event_queue import EventQueue | |
| from a2a.server.events.queue_manager import ( | |
| NoTaskQueue, | |
| QueueManager, | |
| TaskQueueExists, | |
| ) | |
| from a2a.types import ( | |
| Message, | |
| Task, | |
| TaskArtifactUpdateEvent, | |
| TaskStatusUpdateEvent, | |
| ) | |
| from a2a.utils.telemetry import SpanKind, trace_class | |
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
# Union of the A2A event models this module serializes to and from Redis.
Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
@trace_class(kind=SpanKind.SERVER)
class DistributedEventQueue(EventQueue):
    """Event queue backed by a Redis list and fanned out over Redis Pub/Sub.

    Every instance owns a private Redis list (``queue:<task_id>:<hex uuid>``)
    that ``dequeue_event`` pops from, and listens on a shared Pub/Sub channel
    (``task:<task_id>`` unless another channel is given). ``enqueue_event``
    pushes the event onto this instance's own list and publishes it on the
    channel; every other instance listening on that channel copies it onto its
    own list. Published messages carry the publisher's ``queue_key`` under
    ``"origin"`` so the publishing instance does not receive its own event a
    second time through its subscriber.
    """

    # Serialized "type" tag -> model class, used to rebuild events from JSON.
    _EVENT_TYPES: dict[str, type] = {
        "Message": Message,
        "Task": Task,
        "TaskStatusUpdateEvent": TaskStatusUpdateEvent,
        "TaskArtifactUpdateEvent": TaskArtifactUpdateEvent,
    }

    def __init__(
        self,
        task_id: str,
        redis_client: redis.Redis,
        max_queue_size: int = 1024,
        channel: str | None = None,
    ):
        """Initializes the DistributedEventQueue and starts its channel listener.

        Must be called while an asyncio event loop is running: the listener is
        started with ``asyncio.create_task``.

        Args:
            task_id: Task whose events this queue carries.
            redis_client: Async Redis client used for the list and Pub/Sub.
            max_queue_size: Passed to ``EventQueue.__init__``; the Redis list
                itself is not bounded by it.
            channel: Pub/Sub channel to listen and publish on; defaults to
                ``task:<task_id>``.
        """
        super().__init__(max_queue_size)
        self.task_id = task_id
        self.redis = redis_client
        # Unique per instance, so every consumer gets its own copy of each event.
        self.queue_key = f"queue:{task_id}:{uuid.uuid4().hex}"
        self.channel = channel or f"task:{task_id}"
        self._pubsub = self.redis.pubsub()
        self._children: list[DistributedEventQueue] = []
        # SUBSCRIBE runs asynchronously inside this task: events published
        # before it completes are not delivered to this instance.
        self._subscriber_task: asyncio.Task | None = asyncio.create_task(
            self._subscribe_to_channel()
        )

    async def _subscribe_to_channel(self) -> None:
        """Copy events published by other queues on the channel into this queue's list."""
        await self._pubsub.subscribe(self.channel)
        async for message in self._pubsub.listen():
            # listen() also yields subscribe/unsubscribe confirmations.
            if message["type"] != "message":
                continue
            try:
                payload = json.loads(message["data"])
                # Our own publications were already pushed by enqueue_event.
                if payload.get("origin") == self.queue_key:
                    continue
                # A closed queue accepts no new events, as in enqueue_event.
                if self._is_closed:
                    continue
                await self._enqueue_locally(self._deserialize_event(payload))
            except Exception:
                logger.exception(
                    "Error processing event from channel %s", self.channel
                )

    async def _enqueue_locally(self, event: Event) -> None:
        """Push ``event`` onto this instance's Redis list only (no publish)."""
        await self.redis.lpush(self.queue_key, self._serialize_event(event))

    def _event_payload(self, event: Event) -> dict:
        """Build the JSON-ready ``{"type": ..., "data": ...}`` envelope for ``event``.

        Raises:
            TypeError: If ``event`` is not one of the supported event models;
                failing here avoids enqueuing an item that can never be
                deserialized.
        """
        type_name = type(event).__name__
        if type_name not in self._EVENT_TYPES:
            raise TypeError(f"Unsupported event type: {type_name}")
        # mode="json" converts enums and other rich values to JSON-native ones,
        # so json.dumps cannot choke on them.
        return {"type": type_name, "data": event.model_dump(mode="json")}

    def _serialize_event(self, event: Event) -> str:
        """Serialize an event to its JSON envelope string."""
        return json.dumps(self._event_payload(event))

    def _deserialize_event(self, data: dict) -> Event:
        """Rebuild an event from a decoded ``{"type": ..., "data": ...}`` envelope.

        Extra keys (such as ``"origin"``) are ignored.

        Raises:
            ValueError: If ``data["type"]`` is not a known event type.
        """
        event_type = data["type"]
        model = self._EVENT_TYPES.get(event_type)
        if model is None:
            raise ValueError(f"Unknown event type: {event_type}")
        return model(**data["data"])

    async def enqueue_event(self, event: Event) -> None:
        """Enqueues an event on this queue and publishes it to the other subscribers.

        Events enqueued on a closed queue are dropped with a warning.
        """
        async with self._lock:
            if self._is_closed:
                logger.warning('Queue is closed. Event will not be enqueued.')
                return
            logger.debug('Enqueuing event of type: %s', type(event))
            payload = self._event_payload(event)
            await self.redis.lpush(self.queue_key, json.dumps(payload))
            # Tag with our own key so our subscriber skips this message.
            await self.redis.publish(
                self.channel, json.dumps({**payload, "origin": self.queue_key})
            )

    async def dequeue_event(self, no_wait: bool = False) -> Event:
        """Dequeues the oldest event from this queue's Redis list.

        A closed queue still serves the events left in its list, so a graceful
        ``close()`` can drain.

        Args:
            no_wait: If True, return immediately instead of waiting for an event.

        Raises:
            asyncio.QueueEmpty: If ``no_wait`` is set and the list is empty, or
                the queue is closed and has no events left.
        """
        # The lock only guards the closed check; holding it while blocking below
        # would deadlock enqueue_event() and close() on this same instance.
        async with self._lock:
            if self._is_closed and not await self.redis.llen(self.queue_key):
                raise asyncio.QueueEmpty('Queue is closed.')
        if no_wait:
            raw = await self.redis.rpop(self.queue_key)
            if raw is None:
                raise asyncio.QueueEmpty
            event = self._deserialize_event(json.loads(raw))
            logger.debug('Dequeued event (no_wait=True) of type: %s', type(event))
            return event
        # BRPOP blocks server-side; the 1 s timeout lets us notice a close().
        while True:
            result = await self.redis.brpop(self.queue_key, timeout=1)
            if result:
                _, raw = result
                event = self._deserialize_event(json.loads(raw))
                logger.debug('Dequeued event (waited) of type: %s', type(event))
                return event
            if self._is_closed:
                raise asyncio.QueueEmpty('Queue is closed.')

    def tap(self) -> 'DistributedEventQueue':
        """Create a child queue listening on this queue's channel.

        The child is registered so ``close`` also closes it. Like any new
        instance, it only receives events published after its subscription
        completes.
        """
        logger.debug('Tapping DistributedEventQueue to create a child queue.')
        child_queue = DistributedEventQueue(
            self.task_id, self.redis, self.queue.maxsize, self.channel
        )
        self._children.append(child_queue)
        return child_queue

    async def _stop_subscriber(self) -> None:
        """Cancel the channel listener and release its Pub/Sub connection (idempotent)."""
        task, self._subscriber_task = self._subscriber_task, None
        if task is None:
            return
        task.cancel()
        # return_exceptions=True absorbs the listener's own CancelledError while
        # still propagating a cancellation of the caller.
        await asyncio.gather(task, return_exceptions=True)
        # redis-py >= 5.0.1 names this aclose(); older releases only have close().
        close_pubsub = getattr(self._pubsub, "aclose", None) or self._pubsub.close
        await close_pubsub()

    async def close(self, immediate: bool = False) -> None:
        """Closes the queue and all child queues.

        Args:
            immediate: If True, discard pending events right away. Otherwise
                wait until consumers have drained the list before deleting it.
        """
        logger.debug('Closing DistributedEventQueue.')
        async with self._lock:
            if self._is_closed and not immediate:
                return
            self._is_closed = True
        # Released the lock before waiting: dequeue_event needs it to drain.
        await self._stop_subscriber()
        if immediate:
            await self.redis.delete(self.queue_key)
            for child in self._children:
                await child.close(True)
            return
        while await self.redis.llen(self.queue_key) > 0:
            await asyncio.sleep(0.1)
        await self.redis.delete(self.queue_key)
        await asyncio.gather(*(child.close() for child in self._children))
@trace_class(kind=SpanKind.SERVER)
class DistributedQueueManager(QueueManager):
    """Queue manager that tracks registered task ids in a Redis set.

    Queue objects are not shared between callers. Every DistributedEventQueue
    listens on its task's Pub/Sub channel, so any instance created for a task
    id receives the events enqueued on other instances for that id. The
    manager therefore only records which task ids exist (in the ``task_queues``
    set). ``get``, ``tap`` and ``create_or_tap`` hand out new queue instances,
    and the caller is responsible for closing them.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        """Initializes the DistributedQueueManager.

        Args:
            redis_url: Connection URL for the Redis server holding the registry.
        """
        self.redis = redis.from_url(redis_url)
        self.tasks_key = "task_queues"

    async def add(self, task_id: str, queue: EventQueue) -> None:
        """Registers ``task_id``.

        ``queue`` is not stored. Other processes reach the task through its
        Pub/Sub channel, not through this object.

        Raises:
            TaskQueueExists: If ``task_id`` is already registered.
        """
        # SADD returns the number of members actually added, so check-and-insert
        # is a single atomic command (no race between concurrent adders).
        if not await self.redis.sadd(self.tasks_key, task_id):
            raise TaskQueueExists

    async def get(self, task_id: str) -> EventQueue | None:
        """Returns a new queue for a registered ``task_id``, or None if unregistered."""
        if not await self.redis.sismember(self.tasks_key, task_id):
            return None
        return DistributedEventQueue(task_id, self.redis)

    async def tap(self, task_id: str) -> EventQueue | None:
        """Returns a new subscriber queue for a registered ``task_id``, or None.

        A fresh queue on the task channel is exactly what tapping one would
        produce. Building it directly avoids creating an orphan parent whose
        subscriber would keep filling an unread Redis list.
        """
        if not await self.redis.sismember(self.tasks_key, task_id):
            return None
        return DistributedEventQueue(task_id, self.redis)

    async def close(self, task_id: str) -> None:
        """Unregisters ``task_id``.

        Queue instances handed out earlier are not closed here; their owners
        must close them.

        Raises:
            NoTaskQueue: If ``task_id`` is not registered.
        """
        # SREM returns the number of members removed: atomic check-and-remove.
        if not await self.redis.srem(self.tasks_key, task_id):
            raise NoTaskQueue

    async def create_or_tap(self, task_id: str) -> EventQueue:
        """Registers ``task_id`` if needed and returns a new queue for it.

        Whether the id was newly created or already existed, the result is a
        new subscriber on the task channel. Registration uses one atomic SADD,
        so there is no window in which the id can vanish between the check and
        the tap.
        """
        await self.redis.sadd(self.tasks_key, task_id)
        return DistributedEventQueue(task_id, self.redis)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| from unittest.mock import AsyncMock, MagicMock, patch | |
| import pytest | |
| import redis.asyncio as redis | |
| from a2a_distributed_queue import ( | |
| DistributedEventQueue, | |
| DistributedQueueManager, | |
| ) | |
| from a2a.server.events.event_queue import EventQueue | |
| from a2a.server.events.queue_manager import ( | |
| NoTaskQueue, | |
| TaskQueueExists, | |
| ) | |
| from a2a.types import Message, Task, TaskStatusUpdateEvent | |
# Minimal Task payload; not referenced by the tests in this module.
MINIMAL_TASK = dict(
    id='123',
    context_id='session-xyz',
    status=dict(state='submitted'),
    kind='task',
)

# Agent message payload used to build Message events in the tests.
MESSAGE_PAYLOAD = dict(
    role='agent',
    parts=[dict(text='test message')],
    message_id='111',
)
TEST_REDIS_URL = "redis://localhost:6379"


async def _cleanup_test_keys(client) -> None:
    """Delete only the keys this suite creates, leaving other Redis data intact."""
    keys = [key async for key in client.scan_iter(match="queue:*")]
    keys.append("task_queues")
    await client.delete(*keys)


@pytest.fixture
async def redis_client():
    """Redis client for the tests; clears the suite's keys before and after.

    Targeted deletes replace FLUSHALL, which would wipe every database on the
    (possibly shared) local Redis server.
    """
    client = redis.from_url(TEST_REDIS_URL, decode_responses=True)
    await _cleanup_test_keys(client)
    yield client
    await _cleanup_test_keys(client)
    await client.aclose()


@pytest.fixture
async def distributed_queue_manager(redis_client):
    """DistributedQueueManager on the test Redis; its connection pool is closed afterwards."""
    manager = DistributedQueueManager(redis_url=TEST_REDIS_URL)
    yield manager
    await manager.redis.aclose()


@pytest.fixture
async def distributed_event_queue(redis_client):
    """DistributedEventQueue for "test_task", closed immediately after the test.

    Async because the queue's constructor calls asyncio.create_task, which
    needs a running event loop.
    """
    queue = DistributedEventQueue("test_task", redis_client)
    yield queue
    await queue.close(immediate=True)
class TestDistributedEventQueue:
    """Behavior of a single DistributedEventQueue against a live Redis."""

    @pytest.mark.asyncio
    async def test_init(self, distributed_event_queue):
        """Test that DistributedEventQueue initializes correctly."""
        assert distributed_event_queue.task_id == "test_task"
        assert distributed_event_queue.redis is not None
        assert distributed_event_queue.channel == "task:test_task"
        assert not distributed_event_queue.is_closed()

    @pytest.mark.asyncio
    async def test_enqueue_and_dequeue_event(self, distributed_event_queue):
        """An enqueued event round-trips through Redis with its fields intact."""
        event = Message(**MESSAGE_PAYLOAD)
        await distributed_event_queue.enqueue_event(event)
        dequeued = await distributed_event_queue.dequeue_event(no_wait=True)
        assert isinstance(dequeued, Message)
        assert dequeued.message_id == event.message_id
        assert dequeued.role == event.role
        assert dequeued.parts == event.parts

    @pytest.mark.asyncio
    async def test_dequeue_event_no_wait_empty(self, distributed_event_queue):
        """Test dequeue with no_wait on empty queue raises QueueEmpty."""
        with pytest.raises(asyncio.QueueEmpty):
            await distributed_event_queue.dequeue_event(no_wait=True)

    @pytest.mark.asyncio
    async def test_tap_creates_child(self, distributed_event_queue):
        """tap() registers a child on the same task and channel with its own list."""
        child = distributed_event_queue.tap()
        assert isinstance(child, DistributedEventQueue)
        assert child in distributed_event_queue._children
        assert child.task_id == distributed_event_queue.task_id
        assert child.channel == distributed_event_queue.channel
        assert child.queue_key != distributed_event_queue.queue_key

    @pytest.mark.asyncio
    async def test_tap_child_receives_parent_events(self, distributed_event_queue):
        """An event enqueued on the parent reaches a tapped child via Pub/Sub."""
        child = distributed_event_queue.tap()
        # The child subscribes in a background task; let it complete first.
        await asyncio.sleep(0.1)
        event = Message(**MESSAGE_PAYLOAD)
        await distributed_event_queue.enqueue_event(event)
        received = await asyncio.wait_for(child.dequeue_event(), timeout=5)
        assert received.message_id == event.message_id

    @pytest.mark.asyncio
    async def test_close(self, distributed_event_queue):
        """A closed queue reports closed and drops newly enqueued events."""
        await distributed_event_queue.close()
        assert distributed_event_queue.is_closed()
        await distributed_event_queue.enqueue_event(Message(**MESSAGE_PAYLOAD))
        assert (
            await distributed_event_queue.redis.llen(distributed_event_queue.queue_key)
            == 0
        )

    @pytest.mark.asyncio
    async def test_close_immediate(self, distributed_event_queue):
        """close(immediate=True) discards pending events."""
        event = Message(**MESSAGE_PAYLOAD)
        await distributed_event_queue.enqueue_event(event)
        await distributed_event_queue.close(immediate=True)
        assert distributed_event_queue.is_closed()
        assert (
            await distributed_event_queue.redis.llen(distributed_event_queue.queue_key)
            == 0
        )
        with pytest.raises(asyncio.QueueEmpty):
            await distributed_event_queue.dequeue_event(no_wait=True)
@pytest.fixture
async def opened_queues(distributed_queue_manager):
    """Collect queues created by a test and close them immediately on teardown.

    Each DistributedEventQueue starts a Pub/Sub subscriber task and owns a Redis
    list; closing it cancels the task and deletes the list. Depending on the
    manager fixture makes this teardown run before the manager fixture's own.
    """
    queues = []
    yield queues
    for queue in queues:
        if queue is not None:
            await queue.close(immediate=True)


class TestDistributedQueueManager:
    """Registry behavior of DistributedQueueManager against a live Redis."""

    @pytest.mark.asyncio
    async def test_init(self, distributed_queue_manager):
        """Test DistributedQueueManager initializes correctly."""
        assert distributed_queue_manager.redis is not None
        assert distributed_queue_manager.tasks_key == "task_queues"

    @pytest.mark.asyncio
    async def test_add_new_queue(self, distributed_queue_manager, opened_queues):
        """Test adding a new queue."""
        task_id = "test_task"
        queue = DistributedEventQueue(task_id, distributed_queue_manager.redis)
        opened_queues.append(queue)
        await distributed_queue_manager.add(task_id, queue)
        # Check if task_id is in Redis set
        exists = await distributed_queue_manager.redis.sismember("task_queues", task_id)
        assert exists

    @pytest.mark.asyncio
    async def test_add_existing_queue_raises(self, distributed_queue_manager, opened_queues):
        """Test adding existing queue raises TaskQueueExists."""
        task_id = "test_task"
        queue = DistributedEventQueue(task_id, distributed_queue_manager.redis)
        opened_queues.append(queue)
        await distributed_queue_manager.add(task_id, queue)
        with pytest.raises(TaskQueueExists):
            await distributed_queue_manager.add(task_id, queue)

    @pytest.mark.asyncio
    async def test_get_existing_queue(self, distributed_queue_manager, opened_queues):
        """Test getting an existing queue."""
        task_id = "test_task"
        await distributed_queue_manager.redis.sadd("task_queues", task_id)
        queue = await distributed_queue_manager.get(task_id)
        opened_queues.append(queue)
        assert isinstance(queue, DistributedEventQueue)
        assert queue.task_id == task_id

    @pytest.mark.asyncio
    async def test_get_nonexistent_queue(self, distributed_queue_manager):
        """Test getting nonexistent queue returns None."""
        queue = await distributed_queue_manager.get("nonexistent")
        assert queue is None

    @pytest.mark.asyncio
    async def test_tap_existing_queue(self, distributed_queue_manager, opened_queues):
        """Test tapping an existing queue."""
        task_id = "test_task"
        await distributed_queue_manager.redis.sadd("task_queues", task_id)
        tapped = await distributed_queue_manager.tap(task_id)
        opened_queues.append(tapped)
        assert isinstance(tapped, DistributedEventQueue)

    @pytest.mark.asyncio
    async def test_tap_nonexistent_queue(self, distributed_queue_manager):
        """Test tapping nonexistent queue returns None."""
        tapped = await distributed_queue_manager.tap("nonexistent")
        assert tapped is None

    @pytest.mark.asyncio
    async def test_close_existing_queue(self, distributed_queue_manager):
        """Test closing an existing queue."""
        task_id = "test_task"
        await distributed_queue_manager.redis.sadd("task_queues", task_id)
        await distributed_queue_manager.close(task_id)
        exists = await distributed_queue_manager.redis.sismember("task_queues", task_id)
        assert not exists

    @pytest.mark.asyncio
    async def test_close_nonexistent_queue_raises(self, distributed_queue_manager):
        """Test closing nonexistent queue raises NoTaskQueue."""
        with pytest.raises(NoTaskQueue):
            await distributed_queue_manager.close("nonexistent")

    @pytest.mark.asyncio
    async def test_create_or_tap_new(self, distributed_queue_manager, opened_queues):
        """Test create_or_tap for new task."""
        task_id = "test_task"
        queue = await distributed_queue_manager.create_or_tap(task_id)
        opened_queues.append(queue)
        assert isinstance(queue, DistributedEventQueue)
        exists = await distributed_queue_manager.redis.sismember("task_queues", task_id)
        assert exists

    @pytest.mark.asyncio
    async def test_create_or_tap_existing(self, distributed_queue_manager, opened_queues):
        """Test create_or_tap for existing task taps."""
        task_id = "test_task"
        await distributed_queue_manager.redis.sadd("task_queues", task_id)
        queue = await distributed_queue_manager.create_or_tap(task_id)
        opened_queues.append(queue)
        assert isinstance(queue, DistributedEventQueue)

    @pytest.mark.asyncio
    async def test_concurrency(self, distributed_queue_manager, opened_queues):
        """Test concurrent access."""

        async def add_task(task_id):
            queue = DistributedEventQueue(task_id, distributed_queue_manager.redis)
            opened_queues.append(queue)
            await distributed_queue_manager.add(task_id, queue)
            return task_id

        async def get_task(task_id):
            return await distributed_queue_manager.get(task_id)

        task_ids = [f"task_{i}" for i in range(10)]
        # Add concurrently
        added = await asyncio.gather(*(add_task(tid) for tid in task_ids))
        assert set(added) == set(task_ids)
        # Get concurrently
        queues = await asyncio.gather(*(get_task(tid) for tid in task_ids))
        opened_queues.extend(queues)
        assert all(isinstance(q, DistributedEventQueue) for q in queues)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment