Skip to content

Instantly share code, notes, and snippets.

@colonelpanic8
Created October 23, 2024 01:00
Show Gist options
  • Save colonelpanic8/34c2c8fdba0de185003b4637e9e1cadc to your computer and use it in GitHub Desktop.
Save colonelpanic8/34c2c8fdba0de185003b4637e9e1cadc to your computer and use it in GitHub Desktop.

Revisions

  1. colonelpanic8 created this gist Oct 23, 2024.
    434 changes: 434 additions & 0 deletions new_features.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,434 @@
    from typing import TYPE_CHECKING, Optional

    import sqlalchemy as sa
    from numpy import pi
    from sqlalchemy import type_coerce
    from sqlalchemy.dialects.postgresql import ARRAY, ENUM
    from sqlalchemy.ext.declarative import declared_attr
    from sqlalchemy.orm import Mapped, deferred, mapped_column, relationship
    from sqlalchemy.types import String, TypeDecorator

    from railbird import config
    from railbird.datatypes import gql

    from . import query_builder_types as qbt
    from .base import Base, _qb

    if TYPE_CHECKING:
    from .shot import ShotModel

    # The default angle for determining if a shot is Left, Right, or Straight.
    DEFAULT_DIRECTION_ANGLE_THRESHOLD = 10
    DEFAULT_DRAW_ANGLE = 100
    DEFAULT_FOLLOW_ANGLE = 70

    SpinTypeEnum = ENUM(
    gql.SpinTypeEnum,
    name="spin_type_enum",
    schema="railbird",
    )

    PocketEnum = ENUM(
    gql.PocketEnum,
    name="pocket_enum",
    schema="railbird",
    )

    WallTypeEnum = ENUM(
    gql.WallTypeEnum,
    name="wall_enum",
    schema="railbird",
    )

    ShotDirectionEnum = ENUM(
    gql.ShotDirectionEnum,
    name="direction_enum",
    schema="railbird",
    )


    class EnumType(TypeDecorator):
    impl = String
    cache_ok = True

    def __init__(self, enum_class):
    super().__init__()
    self.enum_class = enum_class

    def make_value(self, name):
    return type_coerce(name, self.enum_class)

    __call__ = make_value

    def process_result_value(self, value, dialect):
    if value is not None:
    return self.enum_class(value)
    return None


    DecoratedSpinType = EnumType(SpinTypeEnum)


    class CueObjectFeatures(Base):
    """Features that are defined when the cue ball collides with an object ball."""

    __tablename__ = "cue_object_features"

    shot_id: Mapped[int] = mapped_column(
    sa.BIGINT,
    sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
    primary_key=True,
    nullable=False,
    )
    cue_object_distance: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=7, scale=3),
    index=True,
    info=_qb(),
    )
    cue_object_angle: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=6, scale=3),
    index=True,
    info=_qb(),
    )
    cue_angle_after_object: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=6, scale=3),
    nullable=True,
    index=True,
    info=_qb(),
    )
    spin_type: Mapped[gql.SpinTypeEnum] = mapped_column(
    SpinTypeEnum,
    nullable=True,
    index=True,
    info=_qb(
    {
    "others": [
    {
    "name": "spin_type_counts",
    "selectable_constructor": (
    qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum(
    gql.SpinTypeEnum
    )
    ),
    }
    ]
    }
    ),
    )
    cue_speed_after_object: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=7, scale=3),
    nullable=True,
    index=True,
    info=_qb(),
    )
    cue_ball_speed: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=7, scale=3),
    index=True,
    info=_qb(),
    )
    shot_direction: Mapped[gql.ShotDirectionEnum] = mapped_column(
    ShotDirectionEnum,
    index=True,
    info=_qb(
    {
    "others": [
    {
    "name": "shot_direction_counts",
    "selectable_constructor": (
    qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum(
    gql.ShotDirectionEnum
    )
    ),
    }
    ]
    }
    ),
    )
    shot: Mapped["ShotModel"] = relationship(
    "ShotModel", back_populates="cue_object_features"
    )

    DecoratedSpinType = EnumType(SpinTypeEnum)

    @classmethod
    def spin_type_by(
    cls,
    follow_angle_threshold=DEFAULT_FOLLOW_ANGLE,
    draw_angle_threshold=DEFAULT_DRAW_ANGLE,
    ):
    return sa.case(
    (
    cls.cue_angle_after_object >= draw_angle_threshold,
    DecoratedSpinType("DRAW"),
    ),
    (
    cls.cue_angle_after_object <= follow_angle_threshold,
    DecoratedSpinType("FOLLOW"),
    ),
    else_=DecoratedSpinType("CENTER"),
    )

    @classmethod
    def is_straight_by(
    cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
    return cls.cue_object_angle <= angle_threshold

    @classmethod
    def is_left_by(
    cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
    return sa.and_(
    cls.shot_direction == gql.ShotDirectionEnum.LEFT,
    cls.cue_object_angle > angle_threshold,
    )

    @classmethod
    def is_right_by(
    cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
    return sa.and_(
    cls.shot_direction == gql.ShotDirectionEnum.RIGHT,
    cls.cue_object_angle > angle_threshold,
    )

    @classmethod
    def __declare_last__(cls):
    cls.is_straight = deferred(cls.is_straight_by(), info=_qb())
    cls.is_left = deferred(cls.is_left_by(), info=_qb())
    cls.is_right = deferred(cls.is_right_by(), info=_qb())
    # XXX: Commenting this out to get spin type to resolve the column.
    # Someone can revert this at some point but I couldnt figure it out using qb.
    # cls.spin_type = deferred(cls.spin_type_by())


    class PocketingIntentionFeatures(Base):
    __tablename__ = "pocketing_intention_features"

    shot_id: Mapped[int] = mapped_column(
    sa.BIGINT,
    sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
    primary_key=True,
    nullable=False,
    )

    target_pocket_angle: Mapped[Optional[float]] = mapped_column(
    sa.FLOAT, nullable=True, index=True, info=_qb()
    )

    target_pocket_angle_direction: Mapped[Optional[gql.ShotDirectionEnum]] = (
    mapped_column(ShotDirectionEnum, nullable=True, index=True, info=_qb())
    )

    backcut: Mapped[Optional[bool]] = mapped_column(
    sa.BOOLEAN, nullable=True, index=True, info=_qb()
    )

    target_pocket_distance: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=7, scale=3),
    index=True,
    info=_qb(),
    )

    make: Mapped[bool] = mapped_column(
    sa.BOOLEAN,
    info={
    "query_builder": {
    "others": [
    {
    "name": "make_percentage",
    "selectable_constructor": qbt.QueryBuilderBoolProportionSelectable,
    }
    ]
    }
    },
    index=True,
    nullable=True,
    )
    intended_pocket_type: Mapped[gql.PocketEnum] = mapped_column(
    PocketEnum,
    index=True,
    info=_qb(),
    )

    difficulty: Mapped[float] = mapped_column(
    sa.FLOAT,
    info={
    "query_builder": {
    "others": [
    {
    "name": "average_difficulty",
    "selectable_constructor": qbt.QueryBuilderAverageSelectable,
    }
    ]
    }
    },
    index=True,
    nullable=True,
    )

    difficulty_git_commit: Mapped[str] = mapped_column(
    sa.VARCHAR(length=150),
    nullable=True,
    default=config.git_commit_hash,
    )

    shot: Mapped["ShotModel"] = relationship(
    "ShotModel", back_populates="pocketing_intention_features"
    )


    class ErrorFeatures(Base):
    __tablename__ = "error_features"

    errors: Mapped[str]
    warnings: Mapped[str]

    shot_id: Mapped[int] = mapped_column(
    sa.BIGINT,
    sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
    primary_key=True,
    nullable=False,
    )

    shot: Mapped["ShotModel"] = relationship(
    "ShotModel", back_populates="error_features"
    )


    class BankFeatures(Base):
    __tablename__ = "bank_features"

    shot_id: Mapped[int] = mapped_column(
    sa.BIGINT,
    sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
    primary_key=True,
    nullable=False,
    )

    walls_hit = mapped_column(
    ARRAY(WallTypeEnum),
    index=True,
    info=_qb(
    {
    "name": "bank_walls_hit",
    "filter_constructor": qbt.QueryBuilderRangeFilter,
    }
    ),
    )

    bank_angle: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=7, scale=3), index=True, info=_qb()
    )
    distance: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=7, scale=3),
    index=True,
    info=_qb(dict(name="bank_distance")),
    )

    shot: Mapped["ShotModel"] = relationship(
    "ShotModel", back_populates="bank_features"
    )


    class KickFeatures(Base):
    __tablename__ = "kick_features"

    shot_id: Mapped[int] = mapped_column(
    sa.BIGINT,
    sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
    primary_key=True,
    nullable=False,
    )
    walls_hit = mapped_column(ARRAY(WallTypeEnum), index=True)
    angle: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=7, scale=3),
    index=True,
    info=_qb(dict(name="kick_angle")),
    )
    distance: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=7, scale=3),
    index=True,
    info=_qb(dict(name="kick_distance")),
    )
    shot: Mapped["ShotModel"] = relationship(
    "ShotModel", back_populates="kick_features"
    )


    DEFAULT_OVER_UNDER_CUT_THRESHOLD = 5


    class MissFeatures(Base):
    __tablename__ = "miss_features"
    shot_id: Mapped[int] = mapped_column(
    sa.BIGINT,
    sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
    primary_key=True,
    nullable=False,
    )
    miss_angle: Mapped[float] = mapped_column(
    sa.DECIMAL(precision=7, scale=3),
    index=True,
    info=_qb(),
    )

    @declared_attr
    def miss_angle_in_degrees(cls) -> Mapped[float]:
    return deferred(cls.miss_angle * (180.0 / pi), info=_qb()) # type: ignore

    shot: Mapped["ShotModel"] = relationship(
    "ShotModel", back_populates="miss_features"
    )

    @classmethod
    def is_undercut_by(
    cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD
    ) -> sa.ColumnElement[bool]:
    return cls.miss_angle_in_degrees <= (-1 * angle_threshold)

    @classmethod
    def is_overcut_by(
    cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD
    ) -> sa.ColumnElement[bool]:
    return cls.miss_angle_in_degrees >= angle_threshold

    @classmethod
    def is_miss_in_direction_by(
    cls,
    angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD,
    direction=gql.ShotDirectionEnum.LEFT,
    ):
    return (
    sa.select(
    sa.or_(
    sa.and_(
    CueObjectFeatures.shot_direction == direction,
    cls.miss_angle_in_degrees >= angle_threshold,
    ),
    sa.and_(
    CueObjectFeatures.shot_direction != direction,
    cls.miss_angle_in_degrees <= -angle_threshold,
    ),
    )
    )
    .where(CueObjectFeatures.shot_id == cls.shot_id)
    .scalar_subquery()
    )

    @classmethod
    def __declare_last__(cls):
    cls.is_overcut = deferred(cls.is_overcut_by(), info=_qb())
    cls.is_undercut = deferred(cls.is_undercut_by(), info=_qb())
    cls.is_left_miss = deferred(
    cls.is_miss_in_direction_by(
    direction=gql.ShotDirectionEnum.LEFT,
    ),
    info=_qb(),
    )
    cls.is_right_miss = deferred(
    cls.is_miss_in_direction_by(
    direction=gql.ShotDirectionEnum.RIGHT,
    ),
    info=_qb(),
    )