
@DiTo97
Last active October 18, 2025 15:18

Revisions

  1. DiTo97 revised this gist Oct 18, 2025. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion test_a2a_distributed_queue.py
    @@ -4,7 +4,7 @@
     import pytest
     import redis.asyncio as redis

    -from distributed_queue import (
    +from a2a_distributed_queue import (
         DistributedEventQueue,
         DistributedQueueManager,
     )
  2. DiTo97 revised this gist Oct 18, 2025. 2 changed files with 0 additions and 0 deletions.
    File renamed without changes.
    File renamed without changes.
  3. DiTo97 created this gist Oct 18, 2025.
    226 changes: 226 additions & 0 deletions distributed_queue.py
    @@ -0,0 +1,226 @@
    import asyncio
    import json
    import logging
    import uuid

    import redis.asyncio as redis

    from a2a.server.events.event_queue import EventQueue
    from a2a.server.events.queue_manager import (
        NoTaskQueue,
        QueueManager,
        TaskQueueExists,
    )
    from a2a.types import (
        Message,
        Task,
        TaskArtifactUpdateEvent,
        TaskStatusUpdateEvent,
    )
    from a2a.utils.telemetry import SpanKind, trace_class

    logger = logging.getLogger(__name__)

    Event = Message | Task | TaskStatusUpdateEvent | TaskArtifactUpdateEvent


    @trace_class(kind=SpanKind.SERVER)
    class DistributedEventQueue(EventQueue):
        """DistributedEventQueue uses Redis for storing and distributing events.

        This extends EventQueue to use Redis Lists for queuing events and Redis
        Pub/Sub for distributing events to child queues in a distributed setup.
        """

        def __init__(
            self,
            task_id: str,
            redis_client: redis.Redis,
            max_queue_size: int = 1024,
            channel: str | None = None,
        ):
            """Initializes the DistributedEventQueue."""
            super().__init__(max_queue_size)
            self.task_id = task_id
            self.redis = redis_client
            self.queue_key = f"queue:{task_id}:{uuid.uuid4().hex}"
            self.channel = channel or f"task:{task_id}"
            self._pubsub = self.redis.pubsub()
            self._subscriber_task: asyncio.Task | None = None
            self._children: list[DistributedEventQueue] = []

            # Subscribe to the shared channel so events published by other
            # queue instances for this task are mirrored locally.
            if self.channel:
                self._subscriber_task = asyncio.create_task(self._subscribe_to_channel())

        async def _subscribe_to_channel(self):
            """Subscribe to the channel and mirror events published by other queues."""
            await self._pubsub.subscribe(self.channel)
            async for message in self._pubsub.listen():
                if message["type"] == "message":
                    try:
                        event_data = json.loads(message["data"])
                        if event_data.get("source") == self.queue_key:
                            # Skip events this instance published itself; they
                            # were already enqueued locally in enqueue_event.
                            continue
                        event = self._deserialize_event(event_data)
                        await self._enqueue_locally(event)
                    except Exception as e:
                        logger.error(f"Error processing event from channel: {e}")

        async def _enqueue_locally(self, event: Event):
            """Enqueue an event locally to the Redis list."""
            event_data = self._serialize_event(event)
            await self.redis.lpush(self.queue_key, event_data)

        def _serialize_event(self, event: Event) -> str:
            """Serialize an event to JSON, tagged with the originating queue."""
            return json.dumps({
                "type": type(event).__name__,
                "source": self.queue_key,
                "data": event.model_dump() if hasattr(event, "model_dump") else str(event),
            })

        def _deserialize_event(self, data: dict) -> Event:
            """Deserialize an event from JSON."""
            event_type = data["type"]
            event_data = data["data"]
            # Events are Pydantic models, so they can be rebuilt from the dump
            if event_type == "Message":
                return Message(**event_data)
            elif event_type == "Task":
                return Task(**event_data)
            elif event_type == "TaskStatusUpdateEvent":
                return TaskStatusUpdateEvent(**event_data)
            elif event_type == "TaskArtifactUpdateEvent":
                return TaskArtifactUpdateEvent(**event_data)
            else:
                raise ValueError(f"Unknown event type: {event_type}")
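
        # Illustrative wire format (values hypothetical): a published Message
        # travels over the Pub/Sub channel as, e.g.,
        #   {"type": "Message",
        #    "source": "queue:123:3f2a...",
        #    "data": {"role": "agent", "parts": [{"text": "hi"}], "message_id": "111"}}
        # where "source" lets subscribers drop messages they published themselves.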

        async def enqueue_event(self, event: Event) -> None:
            """Enqueues an event to this queue and all its children via Redis."""
            async with self._lock:
                if self._is_closed:
                    logger.warning('Queue is closed. Event will not be enqueued.')
                    return

                logger.debug('Enqueuing event of type: %s', type(event))

                # Enqueue locally
                await self._enqueue_locally(event)

                # Publish so child and peer queues receive it via Pub/Sub
                event_data = self._serialize_event(event)
                await self.redis.publish(self.channel, event_data)

        async def dequeue_event(self, no_wait: bool = False) -> Event:
            """Dequeues an event from the Redis list."""
            async with self._lock:
                if self._is_closed:
                    raise asyncio.QueueEmpty('Queue is closed.')

                if no_wait:
                    result = await self.redis.rpop(self.queue_key)
                    if result is None:
                        raise asyncio.QueueEmpty
                    event_data = json.loads(result)
                    event = self._deserialize_event(event_data)
                    logger.debug('Dequeued event (no_wait=True) of type: %s', type(event))
                    return event

            # Wait outside the lock so concurrent enqueue_event calls are not
            # starved; BRPOP uses a short timeout so queue closure is noticed.
            while True:
                result = await self.redis.brpop(self.queue_key, timeout=1)
                if result:
                    _, event_str = result
                    event_data = json.loads(event_str)
                    event = self._deserialize_event(event_data)
                    logger.debug('Dequeued event (waited) of type: %s', type(event))
                    return event
                if self._is_closed:
                    raise asyncio.QueueEmpty('Queue is closed.')

        def tap(self) -> 'DistributedEventQueue':
            """Taps the event queue to create a new child queue that receives all future events."""
            logger.debug('Tapping DistributedEventQueue to create a child queue.')
            child_queue = DistributedEventQueue(
                self.task_id, self.redis, self.queue.maxsize, self.channel
            )
            self._children.append(child_queue)
            return child_queue

        async def close(self, immediate: bool = False) -> None:
            """Closes the queue and all child queues."""
            logger.debug('Closing DistributedEventQueue.')
            async with self._lock:
                if self._is_closed and not immediate:
                    return
                if not self._is_closed:
                    self._is_closed = True

                if immediate:
                    await self.redis.delete(self.queue_key)
                    for child in self._children:
                        await child.close(True)
                    if self._subscriber_task:
                        self._subscriber_task.cancel()
                else:
                    # Wait for the backing Redis list to drain
                    while await self.redis.llen(self.queue_key) > 0:
                        await asyncio.sleep(0.1)
                    await self.redis.delete(self.queue_key)
                    await asyncio.gather(*(child.close() for child in self._children))
                    if self._subscriber_task:
                        self._subscriber_task.cancel()

    @trace_class(kind=SpanKind.SERVER)
    class DistributedQueueManager(QueueManager):
        """DistributedQueueManager uses Redis to manage distributed event queues.

        This implements the `QueueManager` interface using Redis for distributed storage.
        """

        def __init__(self, redis_url: str = "redis://localhost:6379"):
            """Initializes the DistributedQueueManager."""
            self.redis = redis.from_url(redis_url)
            self.tasks_key = "task_queues"
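
            # Redis layout (for reference; keys defined here and in
            # DistributedEventQueue.__init__):
            #   task_queues            SET of registered task IDs
            #   queue:<task>:<uuid>    LIST backing one queue instance
            #   task:<task>            Pub/Sub channel shared by a task's queues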

        async def add(self, task_id: str, queue: EventQueue) -> None:
            """Adds a new event queue for a task ID."""
            exists = await self.redis.sismember(self.tasks_key, task_id)
            if exists:
                raise TaskQueueExists
            await self.redis.sadd(self.tasks_key, task_id)
            # The interface accepts any EventQueue, but only the task ID needs
            # registering here: the queue's state already lives in Redis.

        async def get(self, task_id: str) -> EventQueue | None:
            """Retrieves the event queue for a task ID."""
            exists = await self.redis.sismember(self.tasks_key, task_id)
            if not exists:
                return None
            return DistributedEventQueue(task_id, self.redis)

        async def tap(self, task_id: str) -> EventQueue | None:
            """Taps the event queue for a task ID to create a child queue."""
            queue = await self.get(task_id)
            if queue is None:
                return None
            return queue.tap()

        async def close(self, task_id: str) -> None:
            """Closes and removes the event queue for a task ID."""
            exists = await self.redis.sismember(self.tasks_key, task_id)
            if not exists:
                raise NoTaskQueue
            await self.redis.srem(self.tasks_key, task_id)
            # Only the registration is removed here; queue instances are
            # distributed across processes, so each holder closes its own.

        async def create_or_tap(self, task_id: str) -> EventQueue:
            """Creates a new event queue for a task ID if one doesn't exist, otherwise taps the existing one."""
            exists = await self.redis.sismember(self.tasks_key, task_id)
            if not exists:
                await self.redis.sadd(self.tasks_key, task_id)
                return DistributedEventQueue(task_id, self.redis)
            else:
                queue = await self.get(task_id)
                return queue.tap()
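
    A minimal usage sketch (not one of the gist's files; it assumes a Redis
    server at redis://localhost:6379 and the same minimal Message payload the
    tests below use):

        import asyncio

        from a2a.types import Message
        from distributed_queue import DistributedQueueManager

        async def main():
            manager = DistributedQueueManager("redis://localhost:6379")
            queue = await manager.create_or_tap("task-123")  # first call creates

            await queue.enqueue_event(
                Message(role="agent", parts=[{"text": "hello"}], message_id="1")
            )
            event = await queue.dequeue_event(no_wait=True)
            print(type(event).__name__)  # Message

            await queue.close(immediate=True)
            await manager.close("task-123")

        asyncio.run(main())
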
    218 changes: 218 additions & 0 deletions test_distributed_queue.py
    @@ -0,0 +1,218 @@
    import asyncio
    from unittest.mock import AsyncMock, MagicMock, patch

    import pytest
    import redis.asyncio as redis

    from distributed_queue import (
        DistributedEventQueue,
        DistributedQueueManager,
    )
    from a2a.server.events.event_queue import EventQueue
    from a2a.server.events.queue_manager import (
        NoTaskQueue,
        TaskQueueExists,
    )
    from a2a.types import Message, Task, TaskStatusUpdateEvent


    MINIMAL_TASK = {
        'id': '123',
        'context_id': 'session-xyz',
        'status': {'state': 'submitted'},
        'kind': 'task',
    }
    MESSAGE_PAYLOAD = {
        'role': 'agent',
        'parts': [{'text': 'test message'}],
        'message_id': '111',
    }


    @pytest.fixture
    async def redis_client():
        """Fixture to create a Redis client for testing."""
        # Use a dedicated test Redis instance (or fakeredis): the cleanup
        # below flushes every key in the database.
        client = redis.from_url("redis://localhost:6379", decode_responses=True)
        yield client
        # Clean up
        await client.flushall()


    @pytest.fixture
    def distributed_queue_manager(redis_client):
        """Fixture to create a DistributedQueueManager."""
        return DistributedQueueManager(redis_url="redis://localhost:6379")


    @pytest.fixture
    async def distributed_event_queue(redis_client):
        """Fixture to create a DistributedEventQueue (async so the queue's
        subscriber task is created inside a running event loop)."""
        return DistributedEventQueue("test_task", redis_client)


    class TestDistributedEventQueue:
        @pytest.mark.asyncio
        async def test_init(self, distributed_event_queue):
            """Test that DistributedEventQueue initializes correctly."""
            assert distributed_event_queue.task_id == "test_task"
            assert distributed_event_queue.redis is not None
            assert distributed_event_queue.channel == "task:test_task"
            assert not distributed_event_queue.is_closed()

        @pytest.mark.asyncio
        async def test_enqueue_and_dequeue_event(self, distributed_event_queue):
            """Test enqueue and dequeue of an event."""
            event = Message(**MESSAGE_PAYLOAD)
            await distributed_event_queue.enqueue_event(event)

            dequeued = await distributed_event_queue.dequeue_event(no_wait=True)
            assert dequeued.role == event.role
            assert dequeued.parts == event.parts

        @pytest.mark.asyncio
        async def test_dequeue_event_no_wait_empty(self, distributed_event_queue):
            """Test dequeue with no_wait on empty queue raises QueueEmpty."""
            with pytest.raises(asyncio.QueueEmpty):
                await distributed_event_queue.dequeue_event(no_wait=True)

        @pytest.mark.asyncio
        async def test_tap_creates_child(self, distributed_event_queue):
            """Test that tap creates a child queue."""
            child = distributed_event_queue.tap()
            assert isinstance(child, DistributedEventQueue)
            assert child in distributed_event_queue._children

        @pytest.mark.asyncio
        async def test_close(self, distributed_event_queue):
            """Test closing the queue."""
            await distributed_event_queue.close()
            assert distributed_event_queue.is_closed()

        @pytest.mark.asyncio
        async def test_close_immediate(self, distributed_event_queue):
            """Test closing with immediate=True."""
            event = Message(**MESSAGE_PAYLOAD)
            await distributed_event_queue.enqueue_event(event)

            await distributed_event_queue.close(immediate=True)
            assert distributed_event_queue.is_closed()


    class TestDistributedQueueManager:
        @pytest.mark.asyncio
        async def test_init(self, distributed_queue_manager):
            """Test DistributedQueueManager initializes correctly."""
            assert distributed_queue_manager.redis is not None
            assert distributed_queue_manager.tasks_key == "task_queues"

        @pytest.mark.asyncio
        async def test_add_new_queue(self, distributed_queue_manager):
            """Test adding a new queue."""
            task_id = "test_task"
            queue = DistributedEventQueue(task_id, distributed_queue_manager.redis)
            await distributed_queue_manager.add(task_id, queue)

            # Check if task_id is in the Redis set
            exists = await distributed_queue_manager.redis.sismember("task_queues", task_id)
            assert exists

        @pytest.mark.asyncio
        async def test_add_existing_queue_raises(self, distributed_queue_manager):
            """Test adding existing queue raises TaskQueueExists."""
            task_id = "test_task"
            queue = DistributedEventQueue(task_id, distributed_queue_manager.redis)
            await distributed_queue_manager.add(task_id, queue)

            with pytest.raises(TaskQueueExists):
                await distributed_queue_manager.add(task_id, queue)

        @pytest.mark.asyncio
        async def test_get_existing_queue(self, distributed_queue_manager):
            """Test getting an existing queue."""
            task_id = "test_task"
            await distributed_queue_manager.redis.sadd("task_queues", task_id)

            queue = await distributed_queue_manager.get(task_id)
            assert isinstance(queue, DistributedEventQueue)
            assert queue.task_id == task_id

        @pytest.mark.asyncio
        async def test_get_nonexistent_queue(self, distributed_queue_manager):
            """Test getting nonexistent queue returns None."""
            queue = await distributed_queue_manager.get("nonexistent")
            assert queue is None

        @pytest.mark.asyncio
        async def test_tap_existing_queue(self, distributed_queue_manager):
            """Test tapping an existing queue."""
            task_id = "test_task"
            await distributed_queue_manager.redis.sadd("task_queues", task_id)

            tapped = await distributed_queue_manager.tap(task_id)
            assert isinstance(tapped, DistributedEventQueue)

        @pytest.mark.asyncio
        async def test_tap_nonexistent_queue(self, distributed_queue_manager):
            """Test tapping nonexistent queue returns None."""
            tapped = await distributed_queue_manager.tap("nonexistent")
            assert tapped is None

        @pytest.mark.asyncio
        async def test_close_existing_queue(self, distributed_queue_manager):
            """Test closing an existing queue."""
            task_id = "test_task"
            await distributed_queue_manager.redis.sadd("task_queues", task_id)

            await distributed_queue_manager.close(task_id)

            exists = await distributed_queue_manager.redis.sismember("task_queues", task_id)
            assert not exists

        @pytest.mark.asyncio
        async def test_close_nonexistent_queue_raises(self, distributed_queue_manager):
            """Test closing nonexistent queue raises NoTaskQueue."""
            with pytest.raises(NoTaskQueue):
                await distributed_queue_manager.close("nonexistent")

        @pytest.mark.asyncio
        async def test_create_or_tap_new(self, distributed_queue_manager):
            """Test create_or_tap for new task."""
            task_id = "test_task"
            queue = await distributed_queue_manager.create_or_tap(task_id)
            assert isinstance(queue, DistributedEventQueue)

            exists = await distributed_queue_manager.redis.sismember("task_queues", task_id)
            assert exists

        @pytest.mark.asyncio
        async def test_create_or_tap_existing(self, distributed_queue_manager):
            """Test create_or_tap for existing task taps."""
            task_id = "test_task"
            await distributed_queue_manager.redis.sadd("task_queues", task_id)

            queue = await distributed_queue_manager.create_or_tap(task_id)
            assert isinstance(queue, DistributedEventQueue)

        @pytest.mark.asyncio
        async def test_concurrency(self, distributed_queue_manager):
            """Test concurrent access."""

            async def add_task(task_id):
                queue = DistributedEventQueue(task_id, distributed_queue_manager.redis)
                await distributed_queue_manager.add(task_id, queue)
                return task_id

            async def get_task(task_id):
                return await distributed_queue_manager.get(task_id)

            task_ids = [f"task_{i}" for i in range(10)]

            # Add concurrently
            add_tasks = [add_task(tid) for tid in task_ids]
            added = await asyncio.gather(*add_tasks)
            assert set(added) == set(task_ids)

            # Get concurrently
            get_tasks = [get_task(tid) for tid in task_ids]
            queues = await asyncio.gather(*get_tasks)
            assert all(isinstance(q, DistributedEventQueue) for q in queues)
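
    Running the tests is an environment assumption rather than something the
    gist states: they expect a disposable Redis server on localhost:6379 (the
    redis_client fixture flushes the whole database on teardown) and
    pytest-asyncio for the async fixtures and @pytest.mark.asyncio markers:

        pip install pytest pytest-asyncio redis a2a-sdk
        pytest test_distributed_queue.py -q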