Created
          October 23, 2024 01:00 
        
      - 
      
- 
        Save colonelpanic8/34c2c8fdba0de185003b4637e9e1cadc to your computer and use it in GitHub Desktop. 
Revisions
- 
        colonelpanic8 created this gist Oct 23, 2024 .There are no files selected for viewingThis file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,434 @@ from typing import TYPE_CHECKING, Optional import sqlalchemy as sa from numpy import pi from sqlalchemy import type_coerce from sqlalchemy.dialects.postgresql import ARRAY, ENUM from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import Mapped, deferred, mapped_column, relationship from sqlalchemy.types import String, TypeDecorator from railbird import config from railbird.datatypes import gql from . import query_builder_types as qbt from .base import Base, _qb if TYPE_CHECKING: from .shot import ShotModel # The default angle for determining if a shot is Left, Right, or Straight. DEFAULT_DIRECTION_ANGLE_THRESHOLD = 10 DEFAULT_DRAW_ANGLE = 100 DEFAULT_FOLLOW_ANGLE = 70 SpinTypeEnum = ENUM( gql.SpinTypeEnum, name="spin_type_enum", schema="railbird", ) PocketEnum = ENUM( gql.PocketEnum, name="pocket_enum", schema="railbird", ) WallTypeEnum = ENUM( gql.WallTypeEnum, name="wall_enum", schema="railbird", ) ShotDirectionEnum = ENUM( gql.ShotDirectionEnum, name="direction_enum", schema="railbird", ) class EnumType(TypeDecorator): impl = String cache_ok = True def __init__(self, enum_class): super().__init__() self.enum_class = enum_class def make_value(self, name): return type_coerce(name, self.enum_class) __call__ = make_value def process_result_value(self, value, dialect): if value is not None: return self.enum_class(value) return None DecoratedSpinType = EnumType(SpinTypeEnum) class CueObjectFeatures(Base): """Features that are defined when the cue ball collides with an object ball.""" __tablename__ = "cue_object_features" shot_id: Mapped[int] = mapped_column( sa.BIGINT, sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"), primary_key=True, nullable=False, ) cue_object_distance: Mapped[float] = mapped_column( sa.DECIMAL(precision=7, scale=3), index=True, info=_qb(), ) cue_object_angle: Mapped[float] = mapped_column( sa.DECIMAL(precision=6, scale=3), index=True, info=_qb(), ) cue_angle_after_object: Mapped[float] = mapped_column( sa.DECIMAL(precision=6, scale=3), nullable=True, index=True, info=_qb(), ) spin_type: Mapped[gql.SpinTypeEnum] = mapped_column( SpinTypeEnum, nullable=True, index=True, info=_qb( { "others": [ { "name": "spin_type_counts", "selectable_constructor": ( qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum( gql.SpinTypeEnum ) ), } ] } ), ) cue_speed_after_object: Mapped[float] = mapped_column( sa.DECIMAL(precision=7, scale=3), nullable=True, index=True, info=_qb(), ) cue_ball_speed: Mapped[float] = mapped_column( sa.DECIMAL(precision=7, scale=3), index=True, info=_qb(), ) shot_direction: Mapped[gql.ShotDirectionEnum] = mapped_column( ShotDirectionEnum, index=True, info=_qb( { "others": [ { "name": "shot_direction_counts", "selectable_constructor": ( qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum( gql.ShotDirectionEnum ) ), } ] } ), ) shot: Mapped["ShotModel"] = relationship( "ShotModel", back_populates="cue_object_features" ) DecoratedSpinType = EnumType(SpinTypeEnum) @classmethod def spin_type_by( cls, follow_angle_threshold=DEFAULT_FOLLOW_ANGLE, draw_angle_threshold=DEFAULT_DRAW_ANGLE, ): return sa.case( ( cls.cue_angle_after_object >= draw_angle_threshold, DecoratedSpinType("DRAW"), ), ( cls.cue_angle_after_object <= follow_angle_threshold, DecoratedSpinType("FOLLOW"), ), else_=DecoratedSpinType("CENTER"), ) @classmethod def is_straight_by( cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD ) -> sa.ColumnElement[bool]: return cls.cue_object_angle <= angle_threshold @classmethod def is_left_by( cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD ) -> sa.ColumnElement[bool]: return sa.and_( cls.shot_direction == gql.ShotDirectionEnum.LEFT, cls.cue_object_angle > angle_threshold, ) @classmethod def is_right_by( cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD ) -> sa.ColumnElement[bool]: return sa.and_( cls.shot_direction == gql.ShotDirectionEnum.RIGHT, cls.cue_object_angle > angle_threshold, ) @classmethod def __declare_last__(cls): cls.is_straight = deferred(cls.is_straight_by(), info=_qb()) cls.is_left = deferred(cls.is_left_by(), info=_qb()) cls.is_right = deferred(cls.is_right_by(), info=_qb()) # XXX: Commenting this out to get spin type to resolve the column. # Someone can revert this at some point but I couldnt figure it out using qb. # cls.spin_type = deferred(cls.spin_type_by()) class PocketingIntentionFeatures(Base): __tablename__ = "pocketing_intention_features" shot_id: Mapped[int] = mapped_column( sa.BIGINT, sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"), primary_key=True, nullable=False, ) target_pocket_angle: Mapped[Optional[float]] = mapped_column( sa.FLOAT, nullable=True, index=True, info=_qb() ) target_pocket_angle_direction: Mapped[Optional[gql.ShotDirectionEnum]] = ( mapped_column(ShotDirectionEnum, nullable=True, index=True, info=_qb()) ) backcut: Mapped[Optional[bool]] = mapped_column( sa.BOOLEAN, nullable=True, index=True, info=_qb() ) target_pocket_distance: Mapped[float] = mapped_column( sa.DECIMAL(precision=7, scale=3), index=True, info=_qb(), ) make: Mapped[bool] = mapped_column( sa.BOOLEAN, info={ "query_builder": { "others": [ { "name": "make_percentage", "selectable_constructor": qbt.QueryBuilderBoolProportionSelectable, } ] } }, index=True, nullable=True, ) intended_pocket_type: Mapped[gql.PocketEnum] = mapped_column( PocketEnum, index=True, info=_qb(), ) difficulty: Mapped[float] = mapped_column( sa.FLOAT, info={ "query_builder": { "others": [ { "name": "average_difficulty", "selectable_constructor": qbt.QueryBuilderAverageSelectable, } ] } }, index=True, nullable=True, ) difficulty_git_commit: Mapped[str] = mapped_column( sa.VARCHAR(length=150), nullable=True, default=config.git_commit_hash, ) shot: Mapped["ShotModel"] = relationship( "ShotModel", back_populates="pocketing_intention_features" ) class ErrorFeatures(Base): __tablename__ = "error_features" errors: Mapped[str] warnings: Mapped[str] shot_id: Mapped[int] = mapped_column( sa.BIGINT, sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"), primary_key=True, nullable=False, ) shot: Mapped["ShotModel"] = relationship( "ShotModel", back_populates="error_features" ) class BankFeatures(Base): __tablename__ = "bank_features" shot_id: Mapped[int] = mapped_column( sa.BIGINT, sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"), primary_key=True, nullable=False, ) walls_hit = mapped_column( ARRAY(WallTypeEnum), index=True, info=_qb( { "name": "bank_walls_hit", "filter_constructor": qbt.QueryBuilderRangeFilter, } ), ) bank_angle: Mapped[float] = mapped_column( sa.DECIMAL(precision=7, scale=3), index=True, info=_qb() ) distance: Mapped[float] = mapped_column( sa.DECIMAL(precision=7, scale=3), index=True, info=_qb(dict(name="bank_distance")), ) shot: Mapped["ShotModel"] = relationship( "ShotModel", back_populates="bank_features" ) class KickFeatures(Base): __tablename__ = "kick_features" shot_id: Mapped[int] = mapped_column( sa.BIGINT, sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"), primary_key=True, nullable=False, ) walls_hit = mapped_column(ARRAY(WallTypeEnum), index=True) angle: Mapped[float] = mapped_column( sa.DECIMAL(precision=7, scale=3), index=True, info=_qb(dict(name="kick_angle")), ) distance: Mapped[float] = mapped_column( sa.DECIMAL(precision=7, scale=3), index=True, info=_qb(dict(name="kick_distance")), ) shot: Mapped["ShotModel"] = relationship( "ShotModel", back_populates="kick_features" ) DEFAULT_OVER_UNDER_CUT_THRESHOLD = 5 class MissFeatures(Base): __tablename__ = "miss_features" shot_id: Mapped[int] = mapped_column( sa.BIGINT, sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"), primary_key=True, nullable=False, ) miss_angle: Mapped[float] = mapped_column( sa.DECIMAL(precision=7, scale=3), index=True, info=_qb(), ) @declared_attr def miss_angle_in_degrees(cls) -> Mapped[float]: return deferred(cls.miss_angle * (180.0 / pi), info=_qb()) # type: ignore shot: Mapped["ShotModel"] = relationship( "ShotModel", back_populates="miss_features" ) @classmethod def is_undercut_by( cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD ) -> sa.ColumnElement[bool]: return cls.miss_angle_in_degrees <= (-1 * angle_threshold) @classmethod def is_overcut_by( cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD ) -> sa.ColumnElement[bool]: return cls.miss_angle_in_degrees >= angle_threshold @classmethod def is_miss_in_direction_by( cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD, direction=gql.ShotDirectionEnum.LEFT, ): return ( sa.select( sa.or_( sa.and_( CueObjectFeatures.shot_direction == direction, cls.miss_angle_in_degrees >= angle_threshold, ), sa.and_( CueObjectFeatures.shot_direction != direction, cls.miss_angle_in_degrees <= -angle_threshold, ), ) ) .where(CueObjectFeatures.shot_id == cls.shot_id) .scalar_subquery() ) @classmethod def __declare_last__(cls): cls.is_overcut = deferred(cls.is_overcut_by(), info=_qb()) cls.is_undercut = deferred(cls.is_undercut_by(), info=_qb()) cls.is_left_miss = deferred( cls.is_miss_in_direction_by( direction=gql.ShotDirectionEnum.LEFT, ), info=_qb(), ) cls.is_right_miss = deferred( cls.is_miss_in_direction_by( direction=gql.ShotDirectionEnum.RIGHT, ), info=_qb(), )