Skip to content

Instantly share code, notes, and snippets.

@DiTo97
Last active October 18, 2025 15:18
Show Gist options
  • Save DiTo97/bac619c9d7d133fcbfd69ac6b708b8ef to your computer and use it in GitHub Desktop.
Save DiTo97/bac619c9d7d133fcbfd69ac6b708b8ef to your computer and use it in GitHub Desktop.
A2A event queue and queue manager using Redis
import asyncio
import json
import logging
import uuid
import redis.asyncio as redis
from a2a.server.events.event_queue import EventQueue
from a2a.server.events.queue_manager import (
NoTaskQueue,
QueueManager,
TaskQueueExists,
)
from a2a.types import (
Message,
Task,
TaskArtifactUpdateEvent,
TaskStatusUpdateEvent,
)
from a2a.utils.telemetry import SpanKind, trace_class
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Union of the A2A payload types that this module serializes and queues.
Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent
@trace_class(kind=SpanKind.SERVER)
class DistributedEventQueue(EventQueue):
    """Event queue stored in a Redis list and fanned out over Redis Pub/Sub.

    Every instance owns a private Redis list (``queue_key``) that holds its
    pending events, and subscribes to a channel shared by all queues of the
    same task. Enqueuing stores the event in the instance's own list and
    publishes it on the channel; every *other* queue subscribed to that
    channel copies the event into its own list.

    Notes:
        * The constructor starts the subscriber with ``asyncio.create_task``,
          so instances must be created while an event loop is running.
        * Redis Pub/Sub does not replay messages: events published before a
          queue's subscription is established are not delivered to it.
        * ``max_queue_size`` is forwarded to the base class only; the Redis
          list itself is not bounded by this class.
    """

    def __init__(
        self,
        task_id: str,
        redis_client: redis.Redis,
        max_queue_size: int = 1024,
        channel: str | None = None,
    ):
        """Initializes the DistributedEventQueue.

        Args:
            task_id: Identifier of the task whose events this queue carries.
            redis_client: Async Redis client used for the list and Pub/Sub.
            max_queue_size: Size passed to the base ``EventQueue``.
            channel: Pub/Sub channel name; defaults to ``task:<task_id>``.
        """
        super().__init__(max_queue_size)
        self.task_id = task_id
        self.redis = redis_client
        # A random suffix gives every queue (parent or tap) its own list.
        self.queue_key = f"queue:{task_id}:{uuid.uuid4().hex}"
        self.channel = channel or f"task:{task_id}"
        self._pubsub = self.redis.pubsub()
        self._children: list[DistributedEventQueue] = []
        # Start listening for events published by the other queues of the task.
        self._subscriber_task: asyncio.Task | None = asyncio.create_task(
            self._subscribe_to_channel()
        )

    async def _subscribe_to_channel(self):
        """Copy events published by other queues into this queue's list."""
        await self._pubsub.subscribe(self.channel)
        async for message in self._pubsub.listen():
            if message["type"] != "message":
                continue
            try:
                envelope = json.loads(message["data"])
                # enqueue_event() already stored our own publications; storing
                # them again here would deliver every event twice.
                if envelope.get("origin") == self.queue_key:
                    continue
                event = self._deserialize_event(envelope)
                await self._enqueue_locally(event)
            except Exception as e:
                # One malformed message must not stop the subscription.
                logger.exception("Error processing event from channel: %s", e)

    async def _enqueue_locally(self, event: Event):
        """Push an event onto this queue's own Redis list."""
        await self.redis.lpush(self.queue_key, self._serialize_event(event))

    def _event_envelope(self, event: Event) -> dict:
        """Build the JSON-compatible ``{"type", "data"}`` envelope for an event."""
        if hasattr(event, "model_dump"):
            # mode="json" asks pydantic for JSON-compatible values so that
            # json.dumps does not fail on richer field types.
            data = event.model_dump(mode="json")
        else:
            data = str(event)
        return {"type": type(event).__name__, "data": data}

    def _serialize_event(self, event: Event) -> str:
        """Serialize an event to its JSON envelope."""
        return json.dumps(self._event_envelope(event))

    def _deserialize_event(self, data: dict) -> Event:
        """Rebuild an event from a decoded envelope.

        Raises:
            ValueError: If the envelope names an unsupported event type.
        """
        event_type = data["type"]
        event_data = data["data"]
        if event_type == "Message":
            return Message(**event_data)
        if event_type == "Task":
            return Task(**event_data)
        if event_type == "TaskStatusUpdateEvent":
            return TaskStatusUpdateEvent(**event_data)
        if event_type == "TaskArtifactUpdateEvent":
            return TaskArtifactUpdateEvent(**event_data)
        raise ValueError(f"Unknown event type: {event_type}")

    def _decode(self, raw) -> Event:
        """Turn a raw list entry (JSON text) back into an event."""
        return self._deserialize_event(json.loads(raw))

    async def enqueue_event(self, event: Event) -> None:
        """Stores an event in this queue and publishes it to the task channel.

        Does nothing (apart from logging a warning) when the queue is closed.
        """
        async with self._lock:
            if self._is_closed:
                logger.warning('Queue is closed. Event will not be enqueued.')
                return
            logger.debug('Enqueuing event of type: %s', type(event))
            envelope = self._event_envelope(event)
            await self.redis.lpush(self.queue_key, json.dumps(envelope))
            # Tag the publication with our key so our own subscriber skips it.
            await self.redis.publish(
                self.channel, json.dumps({**envelope, "origin": self.queue_key})
            )

    async def dequeue_event(self, no_wait: bool = False) -> Event:
        """Removes and returns the oldest event from this queue's list.

        The lock is deliberately not held while waiting: holding it would
        block ``enqueue_event`` and ``close`` on this queue for as long as the
        consumer waits. Events still stored when the queue is closed can be
        dequeued, which lets a graceful ``close`` drain.

        Args:
            no_wait: Return immediately instead of waiting for an event.

        Raises:
            asyncio.QueueEmpty: If ``no_wait`` is set and the list is empty,
                or if the queue is closed and has no stored events left.
        """
        if no_wait:
            raw = await self.redis.rpop(self.queue_key)
            if raw is None:
                if self._is_closed:
                    raise asyncio.QueueEmpty('Queue is closed.')
                raise asyncio.QueueEmpty
            event = self._decode(raw)
            logger.debug('Dequeued event (no_wait=True) of type: %s', type(event))
            return event
        # Block in one-second slices so a close() is noticed between attempts.
        while True:
            result = await self.redis.brpop(self.queue_key, timeout=1)
            if result:
                _, raw = result
                event = self._decode(raw)
                logger.debug('Dequeued event (waited) of type: %s', type(event))
                return event
            if self._is_closed:
                raise asyncio.QueueEmpty('Queue is closed.')

    def tap(self) -> 'DistributedEventQueue':
        """Creates a child queue subscribed to the same channel as this queue."""
        logger.debug('Tapping DistributedEventQueue to create a child queue.')
        child_queue = DistributedEventQueue(
            self.task_id, self.redis, self.queue.maxsize, self.channel
        )
        self._children.append(child_queue)
        return child_queue

    async def _stop_subscriber(self) -> None:
        """Cancel the subscriber task and release the Pub/Sub connection."""
        task, self._subscriber_task = self._subscriber_task, None
        if task is not None:
            task.cancel()
            # gather(return_exceptions=True) absorbs the subscriber's own
            # cancellation while still propagating a cancellation of close().
            await asyncio.gather(task, return_exceptions=True)
        # Newer redis-py exposes aclose(); older releases only have close().
        closer = getattr(self._pubsub, "aclose", None) or self._pubsub.close
        try:
            await closer()
        except Exception:
            logger.exception('Failed to close Pub/Sub for channel %s', self.channel)

    async def close(self, immediate: bool = False) -> None:
        """Closes the queue and all child queues.

        Args:
            immediate: When true, stored events are discarded and children are
                closed immediately. Otherwise the call waits until the stored
                events have been dequeued before deleting the list.
        """
        logger.debug('Closing DistributedEventQueue.')
        async with self._lock:
            if self._is_closed and not immediate:
                return
            self._is_closed = True
        # Stop receiving first so nothing is stored after the list is deleted
        # and a graceful close is not kept waiting by late arrivals.
        await self._stop_subscriber()
        if immediate:
            await self.redis.delete(self.queue_key)
            for child in self._children:
                await child.close(True)
            return
        # The lock is released here so consumers can keep dequeuing.
        while await self.redis.llen(self.queue_key) > 0:
            await asyncio.sleep(0.1)
        await self.redis.delete(self.queue_key)
        await asyncio.gather(*(child.close() for child in self._children))
@trace_class(kind=SpanKind.SERVER)
class DistributedQueueManager(QueueManager):
    """Queue manager that records known task IDs in a Redis set.

    The Redis set ``tasks_key`` is the registry of task IDs. Queue objects
    themselves cannot be shared through Redis, so each manager instance also
    keeps the queues it created or was given in a local dictionary and reuses
    them instead of building a new queue (and a new subscriber) per call.

    NOTE: closing a task removes it from the Redis set and closes the queue
    held by *this* manager only; queues held by other manager instances are
    not notified.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379"):
        """Initializes the DistributedQueueManager.

        Args:
            redis_url: URL of the Redis server used for the task registry.
        """
        self.redis = redis.from_url(redis_url)
        self.tasks_key = "task_queues"
        # Queues owned by this manager instance, keyed by task ID.
        self._local_queues: dict[str, EventQueue] = {}

    async def add(self, task_id: str, queue: EventQueue) -> None:
        """Registers ``queue`` as the event queue for ``task_id``.

        Raises:
            TaskQueueExists: If the task ID is already registered.
        """
        # SADD reports how many members were new, which makes the
        # "register only if absent" step a single atomic command.
        added = await self.redis.sadd(self.tasks_key, task_id)
        if not added:
            raise TaskQueueExists
        self._local_queues[task_id] = queue

    async def get(self, task_id: str) -> EventQueue | None:
        """Retrieves the event queue for a task ID.

        Returns:
            The queue held by this manager for the task, creating and caching
            one on first use, or ``None`` when the task ID is not registered.
        """
        if not await self.redis.sismember(self.tasks_key, task_id):
            return None
        queue = self._local_queues.get(task_id)
        if queue is None:
            queue = DistributedEventQueue(task_id, self.redis)
            self._local_queues[task_id] = queue
        return queue

    def _tap_registered(self, task_id: str) -> EventQueue:
        """Return a new consumer queue for a task known to be registered."""
        parent = self._local_queues.get(task_id)
        if parent is not None:
            return parent.tap()
        # No local parent: a fresh queue on the task channel already acts as
        # a tap, so no intermediate queue is created and left unconsumed.
        return DistributedEventQueue(task_id, self.redis)

    async def tap(self, task_id: str) -> EventQueue | None:
        """Creates a child queue for a task ID, or returns ``None`` if unknown."""
        if not await self.redis.sismember(self.tasks_key, task_id):
            return None
        return self._tap_registered(task_id)

    async def close(self, task_id: str) -> None:
        """Unregisters a task ID and closes the queue this manager holds for it.

        Raises:
            NoTaskQueue: If the task ID is not registered.
        """
        # SREM reports how many members were removed (atomic check + remove).
        removed = await self.redis.srem(self.tasks_key, task_id)
        if not removed:
            raise NoTaskQueue
        queue = self._local_queues.pop(task_id, None)
        if queue is not None:
            await queue.close()

    async def create_or_tap(self, task_id: str) -> EventQueue:
        """Creates the queue for a new task ID, or taps the existing task."""
        added = await self.redis.sadd(self.tasks_key, task_id)
        if added:
            queue = DistributedEventQueue(task_id, self.redis)
            self._local_queues[task_id] = queue
            return queue
        return self._tap_registered(task_id)
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import redis.asyncio as redis
from a2a_distributed_queue import (
DistributedEventQueue,
DistributedQueueManager,
)
from a2a.server.events.event_queue import EventQueue
from a2a.server.events.queue_manager import (
NoTaskQueue,
TaskQueueExists,
)
from a2a.types import Message, Task, TaskStatusUpdateEvent
# Smallest payload used in this file to describe a task.
MINIMAL_TASK = dict(
    id='123',
    context_id='session-xyz',
    status=dict(state='submitted'),
    kind='task',
)
# Payload for a single-part agent text message.
MESSAGE_PAYLOAD = dict(
    role='agent',
    parts=[dict(text='test message')],
    message_id='111',
)
@pytest.fixture
async def redis_client():
    """Yield a Redis client for a test, then clean up after it.

    Expects a Redis server reachable at ``redis://localhost:6379``.
    """
    client = redis.from_url("redis://localhost:6379", decode_responses=True)
    yield client
    # flushdb() clears only the database this client is connected to;
    # flushall() would wipe every database on the server.
    await client.flushdb()
    # Release the connection pool. Newer redis-py exposes aclose(); older
    # releases only have close().
    closer = getattr(client, "aclose", None) or client.close
    await closer()
@pytest.fixture
def distributed_queue_manager(redis_client):
    """Provide a DistributedQueueManager pointed at the local Redis URL.

    ``redis_client`` is requested but not used in the body.
    """
    manager = DistributedQueueManager(redis_url="redis://localhost:6379")
    return manager
@pytest.fixture
async def distributed_event_queue(redis_client):
    """Yield a DistributedEventQueue for "test_task" and close it afterwards.

    The fixture is a coroutine because the queue constructor calls
    ``asyncio.create_task``, which requires a running event loop. It is
    declared with ``@pytest.fixture`` like the async ``redis_client`` fixture.
    """
    queue = DistributedEventQueue("test_task", redis_client)
    yield queue
    # Stop the subscriber task and delete the queue's Redis key, even when
    # the test already closed the queue itself.
    await queue.close(immediate=True)
class TestDistributedEventQueue:
    """Behavioural checks for DistributedEventQueue."""

    @pytest.mark.asyncio
    async def test_init(self, distributed_event_queue):
        """A fresh queue exposes its task id, client, channel and open state."""
        queue = distributed_event_queue
        assert queue.task_id == "test_task"
        assert queue.redis is not None
        assert queue.channel == "task:test_task"
        assert not queue.is_closed()

    @pytest.mark.asyncio
    async def test_enqueue_and_dequeue_event(self, distributed_event_queue):
        """An enqueued message comes back with the same role and parts."""
        sent = Message(**MESSAGE_PAYLOAD)
        await distributed_event_queue.enqueue_event(sent)
        received = await distributed_event_queue.dequeue_event(no_wait=True)
        assert received.role == sent.role
        assert received.parts == sent.parts

    @pytest.mark.asyncio
    async def test_dequeue_event_no_wait_empty(self, distributed_event_queue):
        """A non-blocking dequeue on an empty queue raises QueueEmpty."""
        with pytest.raises(asyncio.QueueEmpty):
            await distributed_event_queue.dequeue_event(no_wait=True)

    @pytest.mark.asyncio
    async def test_tap_creates_child(self, distributed_event_queue):
        """tap() returns a DistributedEventQueue tracked by its parent."""
        tapped = distributed_event_queue.tap()
        assert isinstance(tapped, DistributedEventQueue)
        assert tapped in distributed_event_queue._children

    @pytest.mark.asyncio
    async def test_close(self, distributed_event_queue):
        """close() marks the queue as closed."""
        queue = distributed_event_queue
        await queue.close()
        assert queue.is_closed()

    @pytest.mark.asyncio
    async def test_close_immediate(self, distributed_event_queue):
        """close(immediate=True) closes a queue that still holds an event."""
        queue = distributed_event_queue
        await queue.enqueue_event(Message(**MESSAGE_PAYLOAD))
        await queue.close(immediate=True)
        assert queue.is_closed()
class TestDistributedQueueManager:
    """Behavioural checks for DistributedQueueManager."""

    @pytest.mark.asyncio
    async def test_init(self, distributed_queue_manager):
        """The manager exposes a Redis client and the registry key."""
        manager = distributed_queue_manager
        assert manager.redis is not None
        assert manager.tasks_key == "task_queues"

    @pytest.mark.asyncio
    async def test_add_new_queue(self, distributed_queue_manager):
        """add() records the task id in the Redis set."""
        manager = distributed_queue_manager
        tid = "test_task"
        await manager.add(tid, DistributedEventQueue(tid, manager.redis))
        # The task id must now be a member of the registry set.
        registered = await manager.redis.sismember("task_queues", tid)
        assert registered

    @pytest.mark.asyncio
    async def test_add_existing_queue_raises(self, distributed_queue_manager):
        """Adding the same task id twice raises TaskQueueExists."""
        manager = distributed_queue_manager
        tid = "test_task"
        queue = DistributedEventQueue(tid, manager.redis)
        await manager.add(tid, queue)
        with pytest.raises(TaskQueueExists):
            await manager.add(tid, queue)

    @pytest.mark.asyncio
    async def test_get_existing_queue(self, distributed_queue_manager):
        """get() returns a DistributedEventQueue for a registered task id."""
        manager = distributed_queue_manager
        tid = "test_task"
        await manager.redis.sadd("task_queues", tid)
        found = await manager.get(tid)
        assert isinstance(found, DistributedEventQueue)
        assert found.task_id == tid

    @pytest.mark.asyncio
    async def test_get_nonexistent_queue(self, distributed_queue_manager):
        """get() returns None for an unregistered task id."""
        found = await distributed_queue_manager.get("nonexistent")
        assert found is None

    @pytest.mark.asyncio
    async def test_tap_existing_queue(self, distributed_queue_manager):
        """tap() returns a DistributedEventQueue for a registered task id."""
        manager = distributed_queue_manager
        tid = "test_task"
        await manager.redis.sadd("task_queues", tid)
        child = await manager.tap(tid)
        assert isinstance(child, DistributedEventQueue)

    @pytest.mark.asyncio
    async def test_tap_nonexistent_queue(self, distributed_queue_manager):
        """tap() returns None for an unregistered task id."""
        child = await distributed_queue_manager.tap("nonexistent")
        assert child is None

    @pytest.mark.asyncio
    async def test_close_existing_queue(self, distributed_queue_manager):
        """close() removes the task id from the Redis set."""
        manager = distributed_queue_manager
        tid = "test_task"
        await manager.redis.sadd("task_queues", tid)
        await manager.close(tid)
        still_registered = await manager.redis.sismember("task_queues", tid)
        assert not still_registered

    @pytest.mark.asyncio
    async def test_close_nonexistent_queue_raises(self, distributed_queue_manager):
        """close() raises NoTaskQueue for an unregistered task id."""
        with pytest.raises(NoTaskQueue):
            await distributed_queue_manager.close("nonexistent")

    @pytest.mark.asyncio
    async def test_create_or_tap_new(self, distributed_queue_manager):
        """create_or_tap() registers and returns a queue for a new task id."""
        manager = distributed_queue_manager
        tid = "test_task"
        created = await manager.create_or_tap(tid)
        assert isinstance(created, DistributedEventQueue)
        registered = await manager.redis.sismember("task_queues", tid)
        assert registered

    @pytest.mark.asyncio
    async def test_create_or_tap_existing(self, distributed_queue_manager):
        """create_or_tap() returns a queue for an already registered task id."""
        manager = distributed_queue_manager
        tid = "test_task"
        await manager.redis.sadd("task_queues", tid)
        tapped = await manager.create_or_tap(tid)
        assert isinstance(tapped, DistributedEventQueue)

    @pytest.mark.asyncio
    async def test_concurrency(self, distributed_queue_manager):
        """Ten tasks can be added and then fetched concurrently."""
        manager = distributed_queue_manager

        async def register(tid):
            await manager.add(tid, DistributedEventQueue(tid, manager.redis))
            return tid

        async def fetch(tid):
            return await manager.get(tid)

        task_ids = [f"task_{i}" for i in range(10)]
        # Register every task id concurrently.
        added = await asyncio.gather(*map(register, task_ids))
        assert set(added) == set(task_ids)
        # Fetch every queue concurrently.
        queues = await asyncio.gather(*map(fetch, task_ids))
        assert all(isinstance(q, DistributedEventQueue) for q in queues)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment