Skip to content

Instantly share code, notes, and snippets.

@Tantas
Last active February 18, 2024 18:35
Show Gist options
  • Select an option

  • Save Tantas/b0e9f8a62807b3f8d5fcddda722c1492 to your computer and use it in GitHub Desktop.

Select an option

Save Tantas/b0e9f8a62807b3f8d5fcddda722c1492 to your computer and use it in GitHub Desktop.

Revisions

  1. Tantas revised this gist Feb 18, 2024. No changes.
  2. Tantas created this gist Feb 18, 2024.
    68 changes: 68 additions & 0 deletions sqlalchemy_pydantic_json.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,68 @@
    from pydantic import BaseModel, TypeAdapter
    from sqlalchemy import JSON, TypeDecorator
    from sqlalchemy.orm import DeclarativeBase
    from sqlalchemy.sql import elements


    class PydanticJson(TypeDecorator):
    impl = JSON
    cache_ok = True

    def __init__(self, model: type[BaseModel]):
    super().__init__(none_as_null=True)
    self.model = model

    def _make_bind_processor(self, string_process, json_serializer):
    if string_process:
    def process(value):
    if value is self.NULL:
    value = None
    elif isinstance(value, elements.Null) or (
    value is None and self.none_as_null
    ):
    return None
    serialized = json_serializer(value)
    return string_process(serialized)
    else:
    def process(value):
    if value is self.NULL:
    value = None
    elif isinstance(value, elements.Null) or (
    value is None and self.none_as_null
    ):
    return None
    return json_serializer(value)
    return process

    def bind_processor(self, dialect):
    string_process = self._str_impl.bind_processor(dialect)
    json_serializer = TypeAdapter(self.model).dump_json
    return self._make_bind_processor(string_process, json_serializer)

    def result_processor(self, dialect, coltype):
    string_process = self._str_impl.result_processor(dialect, coltype)
    json_deserializer = TypeAdapter(self.model).validate_json
    def process(value):
    if value is None:
    return None
    if string_process:
    value = string_process(value)
    return json_deserializer(value)
    return process


    """
    # Usage example.
    from pydantic import BaseModel
    from sqlalchemy.orm import DeclarativeBase
    class Base(DeclarativeBase):
    pass
    class PydanticType(BaseModel):
    field: str
    class Entity(Base):
    id: Mapped[int] = mapped_column(primary_key=True)
    data: Mapped[PydanticType] = mapped_column(PydanticJson(PydanticType))
    """