Skip to content

Instantly share code, notes, and snippets.

@geeksambhu
Last active June 26, 2025 20:00
Show Gist options
  • Save geeksambhu/c22ffd60a9528379b825dd5016f615fe to your computer and use it in GitHub Desktop.
Save geeksambhu/c22ffd60a9528379b825dd5016f615fe to your computer and use it in GitHub Desktop.
springbatch.py
from __future__ import annotations
from typing import Any, Dict, Iterable, Type, TypeVar, Protocol
from boto3.dynamodb.conditions import Attr, Key
from pydantic import BaseModel, ValidationError
# Type variable for the Pydantic model type a DynamoRepo works with; it is
# referenced in DynamoRepo's annotations (``model_cls``, ``select`` results).
# NOTE(review): DynamoRepo does not subclass Generic[T], so static type
# checkers cannot bind T per repository instance — confirm whether intended.
T = TypeVar("T", bound=BaseModel)
class TableLike(Protocol):
    """Structural interface for the table object that DynamoRepo wraps.

    Only the method names matter: every operation accepts arbitrary keyword
    arguments and its result is left untyped. For a static type checker, any
    object exposing these six methods satisfies the protocol (presumably a
    boto3 ``Table`` resource in production, or a hand-written test double).
    """

    # Single-item operations.
    def get_item(self, **kwargs: Any) -> Any: ...

    def put_item(self, **kwargs: Any) -> Any: ...

    def update_item(self, **kwargs: Any) -> Any: ...

    def delete_item(self, **kwargs: Any) -> Any: ...

    # Multi-item reads.
    def query(self, **kwargs: Any) -> Any: ...

    def scan(self, **kwargs: Any) -> Any: ...
class DynamoRepo:
    """
    A clean, injected DynamoDB repository with Pydantic validation.

    - insert/upsert use put_item (entire object)
    - update uses SET expressions (partial)

    ``scan``/``query``/``update_item`` are called with ``boto3.dynamodb.conditions``
    objects (``Attr``/``Key``), so ``table`` is presumably a boto3 resource
    ``Table`` rather than a low-level client.
    """

    def __init__(
        self,
        table: TableLike,
        model_cls: Type[T],
        *,
        pk_attr: str,
        sk_attr: str | None = None,
    ):
        """
        Args:
            table: Injected table object used for every read and write.
            model_cls: Pydantic model class items are validated against.
            pk_attr: Name of the partition-key attribute.
            sk_attr: Name of the sort-key attribute, or None for tables whose
                primary key is the partition key alone.
        """
        self.table = table
        self.model_cls = model_cls
        self.pk_attr = pk_attr
        self.sk_attr = sk_attr

    def _key(self, d: Dict[str, Any]) -> Dict[str, Any]:
        """Extract the primary-key attributes from item dict *d* (KeyError if absent)."""
        key = {self.pk_attr: d[self.pk_attr]}
        if self.sk_attr:
            key[self.sk_attr] = d[self.sk_attr]
        return key

    def _key_for(self, pk: Any, sk: Any | None) -> Dict[str, Any]:
        """Build the primary-key dict from explicit partition/sort key values.

        Raises:
            ValueError: The table has a sort key but *sk* is None. DynamoDB
                never accepts NULL key values, so fail before the request.
        """
        key = {self.pk_attr: pk}
        if self.sk_attr:
            if sk is None:
                raise ValueError(f"Sort key '{self.sk_attr}' is required")
            key[self.sk_attr] = sk
        return key

    @staticmethod
    def _dump(model: BaseModel) -> dict:
        """Return *model* as a plain dict on both Pydantic v1 and v2."""
        # Pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``;
        # fall back to ``.dict()`` so v1 models keep working.
        model_dump = getattr(model, "model_dump", None)
        return model_dump() if callable(model_dump) else model.dict()

    def _validate(self, payload: dict | BaseModel) -> dict:
        """Validate *payload* against ``model_cls`` and return it as a dict.

        Instances of ``model_cls`` are trusted as already validated. Any other
        model is re-validated from its field values, so a foreign model class
        cannot bypass this repository's schema.

        Raises:
            ValueError: The payload does not satisfy ``model_cls``.
        """
        if isinstance(payload, self.model_cls):
            return self._dump(payload)
        if isinstance(payload, BaseModel):
            payload = self._dump(payload)
        try:
            model = self.model_cls(**payload)
        except ValidationError as e:
            raise ValueError(f"Validation failed: {e}") from None
        return self._dump(model)

    def _pages(
        self, operation: Any, params: Dict[str, Any], first: Any = None
    ) -> Iterable[Dict[str, Any]]:
        """Yield every response page of a paginated ``scan``/``query``.

        DynamoDB returns at most 1 MB per call and signals more data with
        ``LastEvaluatedKey``; that key is fed back as ``ExclusiveStartKey``
        until it is absent. *first* is an already-fetched first page, if any.
        """
        resp = operation(**params) if first is None else first
        while True:
            yield resp
            last_key = resp.get("LastEvaluatedKey")
            if not last_key:
                return
            resp = operation(**params, ExclusiveStartKey=last_key)

    def insert(self, data: dict | T) -> None:
        """Validate *data* and write it as a whole item (put_item overwrites)."""
        item = self._validate(data)
        self.table.put_item(Item=item)

    upsert = insert

    def delete(self, pk: Any, sk: Any | None = None) -> None:
        """Delete the item with the given key.

        Raises:
            ValueError: The table has a sort key and *sk* is None.
        """
        self.table.delete_item(Key=self._key_for(pk, sk))

    def update(self, pk: Any, sk: Any | None = None, **changes: Any) -> None:
        """Partially update an existing item with a ``SET`` expression.

        The existing item is merged with *changes* and validated as a whole.
        The validated (model-coerced) values of the changed fields are
        written, not the raw arguments. A no-op when *changes* is empty.

        Raises:
            ValueError: *changes* touches a key attribute, the sort key is
                missing, or the merged item fails validation.
            KeyError: No item exists under the given key.
        """
        if not changes:
            return
        key = self._key_for(pk, sk)
        key_fields = {name for name in (self.pk_attr, self.sk_attr) if name}
        touched = key_fields.intersection(changes)
        if touched:
            # DynamoDB rejects SET on primary-key attributes; fail clearly.
            raise ValueError(f"Cannot update key attribute(s): {sorted(touched)}")
        # Validate merged model to avoid invalid writes
        existing = self.select(pk, sk)
        if existing is None:
            raise KeyError("Item not found")
        merged = {**self._dump(existing), **changes}
        validated = self._validate(merged)
        # "#upd"/":upd" placeholders avoid colliding with the "#n"/":v"
        # names boto3 generates for the condition expression below.
        names: Dict[str, str] = {}
        values: Dict[str, Any] = {}
        assignments = []
        for i, (field, raw) in enumerate(changes.items()):
            names[f"#upd{i}"] = field
            values[f":upd{i}"] = validated.get(field, raw)
            assignments.append(f"#upd{i} = :upd{i}")
        self.table.update_item(
            Key=key,
            UpdateExpression="SET " + ", ".join(assignments),
            ExpressionAttributeNames=names,
            ExpressionAttributeValues=values,
            # If the item is deleted between select() and here, make the
            # service reject the write instead of creating a partial item.
            ConditionExpression=Attr(self.pk_attr).exists(),
        )

    def select(self, pk: Any, sk: Any | None = None) -> T | None:
        """Fetch one item by key as a ``model_cls`` instance, or None if absent.

        Raises:
            ValueError: The table has a sort key and *sk* is None.
        """
        resp = self.table.get_item(Key=self._key_for(pk, sk))
        item = resp.get("Item")
        return None if item is None else self.model_cls(**item)

    def count(self, attr: str, value: Any) -> int:
        """Count items whose *attr* equals *value*.

        This is a full-table scan; every result page is summed, so the count
        is correct even when the table exceeds one scan page.
        """
        params = {"FilterExpression": Attr(attr).eq(value), "Select": "COUNT"}
        return sum(page["Count"] for page in self._pages(self.table.scan, params))

    def select_partition(self, pk: Any) -> Iterable[T]:
        """Return all items in partition *pk* as ``model_cls`` instances.

        The first query is issued immediately. Further pages are fetched
        lazily as the returned iterator is consumed.
        """
        params = {"KeyConditionExpression": Key(self.pk_attr).eq(pk)}
        first = self.table.query(**params)
        return (
            self.model_cls(**raw)
            for page in self._pages(self.table.query, params, first)
            for raw in page.get("Items", [])
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment