Skip to content

Instantly share code, notes, and snippets.

@samuelcolvin
Created December 12, 2024 18:28
Show Gist options
  • Select an option

  • Save samuelcolvin/3d2b490d140defae08286bf1bdfab9a1 to your computer and use it in GitHub Desktop.

Select an option

Save samuelcolvin/3d2b490d140defae08286bf1bdfab9a1 to your computer and use it in GitHub Desktop.

Revisions

  1. samuelcolvin created this gist Dec 12, 2024.
    482 changes: 482 additions & 0 deletions ruff_format_broken.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,482 @@
    import json
    from dataclasses import dataclass
    from typing import Annotated, Any, Callable, Union

    import pydantic_core
    import pytest
    from inline_snapshot import snapshot
    from pydantic import BaseModel, Field
    from pydantic_core import PydanticSerializationError

    from pydantic_ai import Agent, RunContext, Tool, UserError
    from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
    from pydantic_ai.models.function import AgentInfo, FunctionModel
    from pydantic_ai.models.test import TestModel
    from pydantic_ai.tools import ToolDefinition


    def test_tool_no_ctx():
    agent = Agent(TestModel())

    with pytest.raises(UserError) as exc_info:

    @agent.tool # pyright: ignore[reportArgumentType]
    def invalid_tool(x: int) -> str: # pragma: no cover
    return 'Hello'

    assert str(exc_info.value) == snapshot(
    'Error generating schema for test_tool_no_ctx.<locals>.invalid_tool:\n'
    ' First parameter of tools that take context must be annotated with RunContext[...]'
    )


    def test_tool_plain_with_ctx():
    agent = Agent(TestModel())

    with pytest.raises(UserError) as exc_info:

    @agent.tool_plain
    async def invalid_tool(ctx: RunContext[None]) -> str: # pragma: no cover
    return 'Hello'

    assert str(exc_info.value) == snapshot(
    'Error generating schema for test_tool_plain_with_ctx.<locals>.invalid_tool:\n'
    ' RunContext annotations can only be used with tools that take context'
    )


    def test_tool_ctx_second():
    agent = Agent(TestModel())

    with pytest.raises(UserError) as exc_info:

    @agent.tool # pyright: ignore[reportArgumentType]
    def invalid_tool(x: int, ctx: RunContext[None]) -> str: # pragma: no cover
    return 'Hello'

    assert str(exc_info.value) == snapshot(
    'Error generating schema for test_tool_ctx_second.<locals>.invalid_tool:\n'
    ' First parameter of tools that take context must be annotated with RunContext[...]\n'
    ' RunContext annotations can only be used as the first argument'
    )


    async def google_style_docstring(foo: int, bar: str) -> str: # pragma: no cover
    """Do foobar stuff, a lot.
    Args:
    foo: The foo thing.
    bar: The bar thing.
    """
    return f'{foo} {bar}'


    async def get_json_schema(_messages: list[Message], info: AgentInfo) -> ModelAnyResponse:
    assert len(info.function_tools) == 1
    r = info.function_tools[0]
    return ModelTextResponse(pydantic_core.to_json(r).decode())


    def test_docstring_google(set_event_loop: None):
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(google_style_docstring)

    result = agent.run_sync('Hello')
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
    {
    'name': 'google_style_docstring',
    'description': 'Do foobar stuff, a lot.',
    'parameters_json_schema': {
    'properties': {
    'foo': {'description': 'The foo thing.', 'title': 'Foo', 'type': 'integer'},
    'bar': {'description': 'The bar thing.', 'title': 'Bar', 'type': 'string'},
    },
    'required': ['foo', 'bar'],
    'type': 'object',
    'additionalProperties': False,
    },
    'outer_typed_dict_key': None,
    }
    )
    # description should be the first key
    assert next(iter(json_schema)) == 'description'


    def sphinx_style_docstring(foo: int, /) -> str: # pragma: no cover
    """Sphinx style docstring.
    :param foo: The foo thing.
    :return: The result.
    """
    return str(foo)


    def test_docstring_sphinx(set_event_loop: None):
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(sphinx_style_docstring)

    result = agent.run_sync('Hello')
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
    {
    'name': 'sphinx_style_docstring',
    'description': 'Sphinx style docstring.',
    'parameters_json_schema': {
    'properties': {'foo': {'description': 'The foo thing.', 'title': 'Foo', 'type': 'integer'}},
    'required': ['foo'],
    'type': 'object',
    'additionalProperties': False,
    },
    'outer_typed_dict_key': None,
    }
    )


    def numpy_style_docstring(*, foo: int, bar: str) -> str: # pragma: no cover
    """Numpy style docstring.
    Parameters
    ----------
    foo : int
    The foo thing.
    bar : str
    The bar thing.
    """
    return f'{foo} {bar}'


    def test_docstring_numpy(set_event_loop: None):
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(numpy_style_docstring)

    result = agent.run_sync('Hello')
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
    {
    'name': 'numpy_style_docstring',
    'description': 'Numpy style docstring.',
    'parameters_json_schema': {
    'properties': {
    'foo': {'description': 'The foo thing.', 'title': 'Foo', 'type': 'integer'},
    'bar': {'description': 'The bar thing.', 'title': 'Bar', 'type': 'string'},
    },
    'required': ['foo', 'bar'],
    'type': 'object',
    'additionalProperties': False,
    },
    'outer_typed_dict_key': None,
    }
    )


    def unknown_docstring(**kwargs: int) -> str: # pragma: no cover
    """Unknown style docstring."""
    return str(kwargs)


    def test_docstring_unknown(set_event_loop: None):
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(unknown_docstring)

    result = agent.run_sync('Hello')
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
    {
    'name': 'unknown_docstring',
    'description': 'Unknown style docstring.',
    'parameters_json_schema': {'properties': {}, 'type': 'object', 'additionalProperties': True},
    'outer_typed_dict_key': None,
    }
    )


    # fmt: off
    async def google_style_docstring_no_body(
    foo: int, bar: Annotated[str, Field(description='from fields')]
    ) -> str: # pragma: no cover
    """
    Args:
    foo: The foo thing.
    bar: The bar thing.
    """
    # fmt: on
    return f'{foo} {bar}'


    def test_docstring_google_no_body(set_event_loop: None):
    agent = Agent(FunctionModel(get_json_schema))
    agent.tool_plain(google_style_docstring_no_body)

    result = agent.run_sync('')
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
    {'name': 'google_style_docstring_no_body', 'description': '', 'parameters_json_schema': {
    'properties': {
    'foo': {'description': 'The foo thing.', 'title': 'Foo', 'type': 'integer'},
    'bar': {'description': 'from fields', 'title': 'Bar', 'type': 'string'},
    },
    'required': ['foo', 'bar'],
    'type': 'object',
    'additionalProperties': False,
    }, 'outer_typed_dict_key': None}
    )


    class Foo(BaseModel):
    x: int
    y: str


    def test_takes_just_model(set_event_loop: None):
    agent = Agent()

    @agent.tool_plain
    def takes_just_model(model: Foo) -> str:
    return f'{model.x} {model.y}'

    result = agent.run_sync('', model=FunctionModel(get_json_schema))
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
    {'name': 'takes_just_model', 'description': None, 'parameters_json_schema': {
    'properties': {
    'x': {'title': 'X', 'type': 'integer'},
    'y': {'title': 'Y', 'type': 'string'},
    },
    'required': ['x', 'y'],
    'title': 'Foo',
    'type': 'object',
    }, 'outer_typed_dict_key': None}
    )

    result = agent.run_sync('', model=TestModel())
    assert result.data == snapshot('{"takes_just_model":"0 a"}')


    def test_takes_model_and_int(set_event_loop: None):
    agent = Agent()

    @agent.tool_plain
    def takes_just_model(model: Foo, z: int) -> str:
    return f'{model.x} {model.y} {z}'

    result = agent.run_sync('', model=FunctionModel(get_json_schema))
    json_schema = json.loads(result.data)
    assert json_schema == snapshot(
    {'name': 'takes_just_model', 'description': '', 'parameters_json_schema': {
    '$defs': {
    'Foo': {
    'properties': {
    'x': {'title': 'X', 'type': 'integer'},
    'y': {'title': 'Y', 'type': 'string'},
    },
    'required': ['x', 'y'],
    'title': 'Foo',
    'type': 'object',
    }
    },
    'properties': {
    'model': {'$ref': '#/$defs/Foo'},
    'z': {'title': 'Z', 'type': 'integer'},
    },
    'required': ['model', 'z'],
    'type': 'object',
    'additionalProperties': False,
    }, 'outer_typed_dict_key': None}
    )

    result = agent.run_sync('', model=TestModel())
    assert result.data == snapshot('{"takes_just_model":"0 a 0"}')


    # pyright: reportPrivateUsage=false
    def test_init_tool_plain(set_event_loop: None):
    call_args: list[int] = []

    def plain_tool(x: int) -> int:
    call_args.append(x)
    return x + 1

    agent = Agent('test', tools=[Tool(plain_tool)], retries=7)
    result = agent.run_sync('foobar')
    assert result.data == snapshot('{"plain_tool":1}')
    assert call_args == snapshot([0])
    assert agent._function_tools['plain_tool'].takes_ctx is False
    assert agent._function_tools['plain_tool'].max_retries == 7

    agent_infer = Agent('test', tools=[plain_tool], retries=7)
    result = agent_infer.run_sync('foobar')
    assert result.data == snapshot('{"plain_tool":1}')
    assert call_args == snapshot([0, 0])
    assert agent_infer._function_tools['plain_tool'].takes_ctx is False
    assert agent_infer._function_tools['plain_tool'].max_retries == 7


    def ctx_tool(ctx: RunContext[int], x: int) -> int:
    return x + ctx.deps


    # pyright: reportPrivateUsage=false
    def test_init_tool_ctx(set_event_loop: None):
    agent = Agent('test', tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], deps_type=int, retries=7)
    result = agent.run_sync('foobar', deps=5)
    assert result.data == snapshot('{"ctx_tool":5}')
    assert agent._function_tools['ctx_tool'].takes_ctx is True
    assert agent._function_tools['ctx_tool'].max_retries == 3

    agent_infer = Agent('test', tools=[ctx_tool], deps_type=int)
    result = agent_infer.run_sync('foobar', deps=6)
    assert result.data == snapshot('{"ctx_tool":6}')
    assert agent_infer._function_tools['ctx_tool'].takes_ctx is True


    def test_repeat_tool():
    with pytest.raises(UserError, match="Tool name conflicts with existing tool: 'ctx_tool'"):
    Agent('test', tools=[Tool(ctx_tool), ctx_tool], deps_type=int)


    def test_tool_return_conflict():
    # this is okay
    Agent('test', tools=[ctx_tool], deps_type=int)
    # this is also okay
    Agent('test', tools=[ctx_tool], deps_type=int, result_type=int)
    # this raises an error
    with pytest.raises(UserError, match="Tool name conflicts with result schema name: 'ctx_tool'"):
    Agent('test', tools=[ctx_tool], deps_type=int, result_type=int, result_tool_name='ctx_tool')


    def test_init_ctx_tool_invalid():
    def plain_tool(x: int) -> int: # pragma: no cover
    return x + 1

    m = r'First parameter of tools that take context must be annotated with RunContext\[\.\.\.\]'
    with pytest.raises(UserError, match=m):
    Tool(plain_tool, takes_ctx=True)


    def test_init_plain_tool_invalid():
    with pytest.raises(UserError, match='RunContext annotations can only be used with tools that take context'):
    Tool(ctx_tool, takes_ctx=False)


    def test_return_pydantic_model(set_event_loop: None):
    agent = Agent('test')

    @agent.tool_plain
    def return_pydantic_model(x: int) -> Foo:
    return Foo(x=x, y='a')

    result = agent.run_sync('')
    assert result.data == snapshot('{"return_pydantic_model":{"x":0,"y":"a"}}')


    def test_return_bytes(set_event_loop: None):
    agent = Agent('test')

    @agent.tool_plain
    def return_pydantic_model() -> bytes:
    return '🐈 Hello'.encode()

    result = agent.run_sync('')
    assert result.data == snapshot('{"return_pydantic_model":"🐈 Hello"}')


    def test_return_bytes_invalid(set_event_loop: None):
    agent = Agent('test')

    @agent.tool_plain
    def return_pydantic_model() -> bytes:
    return b'\00 \x81'

    with pytest.raises(PydanticSerializationError, match='invalid utf-8 sequence of 1 bytes from index 2'):
    agent.run_sync('')


    def test_return_unknown(set_event_loop: None):
    agent = Agent('test')

    class Foobar:
    pass

    @agent.tool_plain
    def return_pydantic_model() -> Foobar:
    return Foobar()

    with pytest.raises(PydanticSerializationError, match='Unable to serialize unknown type:'):
    agent.run_sync('')


    def test_dynamic_cls_tool(set_event_loop: None):
    @dataclass
    class MyTool(Tool[int]):
    spam: int

    def __init__(self, spam: int = 0, **kwargs: Any):
    self.spam = spam
    kwargs.update(function=self.tool_function, takes_ctx=False)
    super().__init__(**kwargs)

    def tool_function(self, x: int, y: str) -> str:
    return f'{self.spam} {x} {y}'

    async def prepare_tool_def(self, ctx: RunContext[int]) -> Union[ToolDefinition, None]:
    if ctx.deps != 42:
    return await super().prepare_tool_def(ctx)

    agent = Agent('test', tools=[MyTool(spam=777)], deps_type=int)
    r = agent.run_sync('', deps=1)
    assert r.data == snapshot('{"tool_function":"777 0 a"}')

    r = agent.run_sync('', deps=42)
    assert r.data == snapshot('success (no tool calls)')


    def test_dynamic_plain_tool_decorator(set_event_loop: None):
    agent = Agent('test', deps_type=int)

    async def prepare_tool_def(ctx: RunContext[int], tool_def: ToolDefinition) -> Union[ToolDefinition, None]:
    if ctx.deps != 42:
    return tool_def

    @agent.tool_plain(prepare=prepare_tool_def)
    def foobar(x: int, y: str) -> str:
    return f'{x} {y}'

    r = agent.run_sync('', deps=1)
    assert r.data == snapshot('{"foobar":"0 a"}')

    r = agent.run_sync('', deps=42)
    assert r.data == snapshot('success (no tool calls)')


    def test_dynamic_tool_decorator(set_event_loop: None):
    agent = Agent('test', deps_type=int)

    async def prepare_tool_def(ctx: RunContext[int], tool_def: ToolDefinition) -> Union[ToolDefinition, None]:
    if ctx.deps != 42:
    return tool_def

    @agent.tool(prepare=prepare_tool_def)
    def foobar(ctx: RunContext[int], x: int, y: str) -> str:
    return f'{ctx.deps} {x} {y}'

    r = agent.run_sync('', deps=1)
    assert r.data == snapshot('{"foobar":"1 0 a"}')

    r = agent.run_sync('', deps=42)
    assert r.data == snapshot('success (no tool calls)')


    def test_future_run_context(set_event_loop: None, create_module: Callable[[str], Any]):
    mod = create_module("""
    from __future__ import annotations
    from pydantic_ai import Agent, RunContext
    def ctx_tool(ctx: RunContext[int], x: int) -> int:
    return x + ctx.deps
    agent = Agent('test', tools=[ctx_tool], deps_type=int)
    """)
    result = mod.agent.run_sync('foobar', deps=5)
    assert result.data == snapshot('{"ctx_tool":5}')