import asyncio import json import logging import uuid import redis.asyncio as redis from a2a.server.events.event_queue import EventQueue from a2a.server.events.queue_manager import ( NoTaskQueue, QueueManager, TaskQueueExists, ) from a2a.types import ( Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, ) from a2a.utils.telemetry import SpanKind, trace_class logger = logging.getLogger(__name__) Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent @trace_class(kind=SpanKind.SERVER) class DistributedEventQueue(EventQueue): """DistributedEventQueue uses Redis for storing and distributing events. This extends EventQueue to use Redis Lists for queuing events and Redis Pub/Sub for distributing events to child queues in a distributed setup. """ def __init__( self, task_id: str, redis_client: redis.Redis, max_queue_size: int = 1024, channel: str | None = None, ): """Initializes the DistributedEventQueue.""" super().__init__(max_queue_size) self.task_id = task_id self.redis = redis_client self.queue_key = f"queue:{task_id}:{uuid.uuid4().hex}" self.channel = channel or f"task:{task_id}" self._pubsub = self.redis.pubsub() self._subscriber_task: asyncio.Task | None = None self._children: list[DistributedEventQueue] = [] # Start subscribing to the channel for child events if self.channel: self._subscriber_task = asyncio.create_task(self._subscribe_to_channel()) async def _subscribe_to_channel(self): """Subscribe to the channel and handle incoming events.""" await self._pubsub.subscribe(self.channel) async for message in self._pubsub.listen(): if message["type"] == "message": try: event_data = json.loads(message["data"]) event = self._deserialize_event(event_data) await self._enqueue_locally(event) except Exception as e: logger.error(f"Error processing event from channel: {e}") async def _enqueue_locally(self, event: Event): """Enqueue an event locally to the Redis list.""" event_data = self._serialize_event(event) await self.redis.lpush(self.queue_key, event_data) def _serialize_event(self, event: Event) -> str: """Serialize an event to JSON.""" return json.dumps({ "type": type(event).__name__, "data": event.model_dump() if hasattr(event, "model_dump") else str(event), }) def _deserialize_event(self, data: dict) -> Event: """Deserialize an event from JSON.""" event_type = data["type"] event_data = data["data"] # Assuming events are Pydantic models or simple types if event_type == "Message": return Message(**event_data) elif event_type == "Task": return Task(**event_data) elif event_type == "TaskStatusUpdateEvent": return TaskStatusUpdateEvent(**event_data) elif event_type == "TaskArtifactUpdateEvent": return TaskArtifactUpdateEvent(**event_data) else: raise ValueError(f"Unknown event type: {event_type}") async def enqueue_event(self, event: Event) -> None: """Enqueues an event to this queue and all its children via Redis.""" async with self._lock: if self._is_closed: logger.warning('Queue is closed. Event will not be enqueued.') return logger.debug('Enqueuing event of type: %s', type(event)) # Enqueue locally await self._enqueue_locally(event) # Publish to channel for children event_data = self._serialize_event(event) await self.redis.publish(self.channel, event_data) async def dequeue_event(self, no_wait: bool = False) -> Event: """Dequeues an event from the Redis list.""" async with self._lock: if self._is_closed: raise asyncio.QueueEmpty('Queue is closed.') if no_wait: result = await self.redis.rpop(self.queue_key) if result is None: raise asyncio.QueueEmpty event_data = json.loads(result) event = self._deserialize_event(event_data) logger.debug('Dequeued event (no_wait=True) of type: %s', type(event)) return event # For waiting, we need to poll since Redis doesn't have async blocking pop in asyncio while True: result = await self.redis.brpop(self.queue_key, timeout=1) if result: _, event_str = result event_data = json.loads(event_str) event = self._deserialize_event(event_data) logger.debug('Dequeued event (waited) of type: %s', type(event)) return event if self._is_closed: raise asyncio.QueueEmpty('Queue is closed.') def tap(self) -> 'DistributedEventQueue': """Taps the event queue to create a new child queue that receives all future events.""" logger.debug('Tapping DistributedEventQueue to create a child queue.') child_queue = DistributedEventQueue( self.task_id, self.redis, self.queue.maxsize, self.channel ) self._children.append(child_queue) return child_queue async def close(self, immediate: bool = False) -> None: """Closes the queue and all child queues.""" logger.debug('Closing DistributedEventQueue.') async with self._lock: if self._is_closed and not immediate: return if not self._is_closed: self._is_closed = True if immediate: await self.redis.delete(self.queue_key) for child in self._children: await child.close(True) if self._subscriber_task: self._subscriber_task.cancel() else: # Wait for queue to drain while await self.redis.llen(self.queue_key) > 0: await asyncio.sleep(0.1) await self.redis.delete(self.queue_key) await asyncio.gather(*(child.close() for child in self._children)) if self._subscriber_task: self._subscriber_task.cancel() @trace_class(kind=SpanKind.SERVER) class DistributedQueueManager(QueueManager): """DistributedQueueManager uses Redis to manage distributed event queues. This implements the `QueueManager` interface using Redis for distributed storage. """ def __init__(self, redis_url: str = "redis://localhost:6379"): """Initializes the DistributedQueueManager.""" self.redis = redis.from_url(redis_url) self.tasks_key = "task_queues" async def add(self, task_id: str, queue: EventQueue) -> None: """Adds a new event queue for a task ID.""" exists = await self.redis.sismember(self.tasks_key, task_id) if exists: raise TaskQueueExists await self.redis.sadd(self.tasks_key, task_id) # Assuming queue is DistributedEventQueue, but for compatibility, we can create one # For now, we'll assume it's passed, but to respect interface, perhaps create internally # But the interface takes EventQueue, so we'll adapt async def get(self, task_id: str) -> EventQueue | None: """Retrieves the event queue for a task ID.""" exists = await self.redis.sismember(self.tasks_key, task_id) if not exists: return None return DistributedEventQueue(task_id, self.redis) async def tap(self, task_id: str) -> EventQueue | None: """Taps the event queue for a task ID to create a child queue.""" queue = await self.get(task_id) if queue is None: return None return queue.tap() async def close(self, task_id: str) -> None: """Closes and removes the event queue for a task ID.""" exists = await self.redis.sismember(self.tasks_key, task_id) if not exists: raise NoTaskQueue await self.redis.srem(self.tasks_key, task_id) # Close the queue if we can get it, but since it's distributed, perhaps notify via pubsub or something # For simplicity, assume clients handle closing their queues async def create_or_tap(self, task_id: str) -> EventQueue: """Creates a new event queue for a task ID if one doesn't exist, otherwise taps the existing one.""" exists = await self.redis.sismember(self.tasks_key, task_id) if not exists: await self.redis.sadd(self.tasks_key, task_id) return DistributedEventQueue(task_id, self.redis) else: queue = await self.get(task_id) return queue.tap()