from typing import TYPE_CHECKING, Optional

import sqlalchemy as sa
from numpy import pi
from sqlalchemy import type_coerce
from sqlalchemy.dialects.postgresql import ARRAY, ENUM
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Mapped, deferred, mapped_column, relationship
from sqlalchemy.types import String, TypeDecorator

from railbird import config
from railbird.datatypes import gql

from . import query_builder_types as qbt
from .base import Base, _qb

if TYPE_CHECKING:
    from .shot import ShotModel

# The default angle for determining if a shot is Left, Right, or Straight.
DEFAULT_DIRECTION_ANGLE_THRESHOLD = 10
# Thresholds on ``cue_angle_after_object`` used by
# ``CueObjectFeatures.spin_type_by``: values at or above the draw angle are
# classified DRAW, values at or below the follow angle are FOLLOW, and
# everything in between is CENTER.
DEFAULT_DRAW_ANGLE = 100
DEFAULT_FOLLOW_ANGLE = 70

# Postgres ENUM types (schema "railbird") built from the GraphQL enums.
SpinTypeEnum = ENUM(
    gql.SpinTypeEnum,
    name="spin_type_enum",
    schema="railbird",
)
PocketEnum = ENUM(
    gql.PocketEnum,
    name="pocket_enum",
    schema="railbird",
)
WallTypeEnum = ENUM(
    gql.WallTypeEnum,
    name="wall_enum",
    schema="railbird",
)
ShotDirectionEnum = ENUM(
    gql.ShotDirectionEnum,
    name="direction_enum",
    schema="railbird",
)


class EnumType(TypeDecorator):
    """String-backed type decorator bound to an enum type.

    Instances are callable: ``decorated("DRAW")`` returns the value wrapped
    with ``type_coerce`` so SQL expressions (e.g. the CASE branches in
    ``CueObjectFeatures.spin_type_by``) are typed as ``enum_class``.

    Args:
        enum_class: Either a SQLAlchemy enum type built from a Python enum
            (such as ``SpinTypeEnum``) or a callable Python enum class.
    """

    impl = String
    cache_ok = True

    def __init__(self, enum_class):
        super().__init__()
        self.enum_class = enum_class

    def make_value(self, name):
        """Return ``name`` as a SQL value coerced to ``self.enum_class``."""
        return type_coerce(name, self.enum_class)

    # Shorthand so an instance can be called like a constructor.
    __call__ = make_value

    def process_result_value(self, value, dialect):
        """Convert a database value into the Python enum member (``None`` passes through)."""
        if value is None:
            return None
        # ``enum_class`` is normally a SQLAlchemy ENUM *type instance*, which is
        # not callable; use the Python enum it was built from (its
        # ``enum_class`` attribute) and fall back to the object itself when it
        # is already a Python enum class.
        # NOTE(review): this looks up by enum *value*; confirm that matches
        # what is stored for the gql enums.
        python_enum = getattr(self.enum_class, "enum_class", None) or self.enum_class
        return python_enum(value)


DecoratedSpinType = EnumType(SpinTypeEnum)


class CueObjectFeatures(Base):
    """Features that are defined when the cue ball collides with an object ball."""

    __tablename__ = "cue_object_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    cue_object_distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),
    )
    cue_object_angle: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=6, scale=3),
        index=True,
        info=_qb(),
    )
    cue_angle_after_object: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=6, scale=3),
        nullable=True,
        index=True,
        info=_qb(),
    )
    spin_type: Mapped[gql.SpinTypeEnum] = mapped_column(
        SpinTypeEnum,
        nullable=True,
        index=True,
        info=_qb(
            {
                "others": [
                    {
                        "name": "spin_type_counts",
                        "selectable_constructor": (
                            qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum(
                                gql.SpinTypeEnum
                            )
                        ),
                    }
                ]
            }
        ),
    )
    cue_speed_after_object: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        nullable=True,
        index=True,
        info=_qb(),
    )
    cue_ball_speed: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),
    )
    shot_direction: Mapped[gql.ShotDirectionEnum] = mapped_column(
        ShotDirectionEnum,
        index=True,
        info=_qb(
            {
                "others": [
                    {
                        "name": "shot_direction_counts",
                        "selectable_constructor": (
                            qbt.QueryBuilderEnumCountsSelectable.constructor_for_enum(
                                gql.ShotDirectionEnum
                            )
                        ),
                    }
                ]
            }
        ),
    )
    shot: Mapped["ShotModel"] = relationship(
        "ShotModel", back_populates="cue_object_features"
    )

    # Alias of the module-level instance (``spin_type_by`` uses the module
    # global). Kept so ``CueObjectFeatures.DecoratedSpinType`` still resolves,
    # without constructing a second, unused EnumType.
    DecoratedSpinType = DecoratedSpinType

    @classmethod
    def spin_type_by(
        cls,
        follow_angle_threshold=DEFAULT_FOLLOW_ANGLE,
        draw_angle_threshold=DEFAULT_DRAW_ANGLE,
    ):
        """Build a CASE expression classifying spin from ``cue_angle_after_object``.

        DRAW when the angle is >= ``draw_angle_threshold`` (checked first),
        FOLLOW when it is <= ``follow_angle_threshold``, otherwise CENTER.
        """
        return sa.case(
            (
                cls.cue_angle_after_object >= draw_angle_threshold,
                DecoratedSpinType("DRAW"),
            ),
            (
                cls.cue_angle_after_object <= follow_angle_threshold,
                DecoratedSpinType("FOLLOW"),
            ),
            else_=DecoratedSpinType("CENTER"),
        )

    @classmethod
    def is_straight_by(
        cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """SQL predicate: ``cue_object_angle`` is at most ``angle_threshold``."""
        return cls.cue_object_angle <= angle_threshold

    @classmethod
    def is_left_by(
        cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """SQL predicate: direction is LEFT and the angle exceeds ``angle_threshold``."""
        return sa.and_(
            cls.shot_direction == gql.ShotDirectionEnum.LEFT,
            cls.cue_object_angle > angle_threshold,
        )

    @classmethod
    def is_right_by(
        cls, angle_threshold=DEFAULT_DIRECTION_ANGLE_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """SQL predicate: direction is RIGHT and the angle exceeds ``angle_threshold``."""
        return sa.and_(
            cls.shot_direction == gql.ShotDirectionEnum.RIGHT,
            cls.cue_object_angle > angle_threshold,
        )

    @classmethod
    def __declare_last__(cls):
        """Attach deferred, query-builder-visible direction flags after mapping."""
        cls.is_straight = deferred(cls.is_straight_by(), info=_qb())
        cls.is_left = deferred(cls.is_left_by(), info=_qb())
        cls.is_right = deferred(cls.is_right_by(), info=_qb())
        # XXX: The override below is disabled so that ``spin_type`` resolves to
        # the stored column instead of the computed expression. Feel free to
        # restore it; it could not be made to work with the query builder (qb).
        # cls.spin_type = deferred(cls.spin_type_by())


class PocketingIntentionFeatures(Base):
    """Per-shot features describing the intended pocket and the outcome."""

    __tablename__ = "pocketing_intention_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    target_pocket_angle: Mapped[Optional[float]] = mapped_column(
        sa.FLOAT, nullable=True, index=True, info=_qb()
    )
    target_pocket_angle_direction: Mapped[Optional[gql.ShotDirectionEnum]] = (
        mapped_column(ShotDirectionEnum, nullable=True, index=True, info=_qb())
    )
    backcut: Mapped[Optional[bool]] = mapped_column(
        sa.BOOLEAN, nullable=True, index=True, info=_qb()
    )
    target_pocket_distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),
    )
    make: Mapped[bool] = mapped_column(
        sa.BOOLEAN,
        info={
            "query_builder": {
                "others": [
                    {
                        "name": "make_percentage",
                        "selectable_constructor": qbt.QueryBuilderBoolProportionSelectable,
                    }
                ]
            }
        },
        index=True,
        nullable=True,
    )
    intended_pocket_type: Mapped[gql.PocketEnum] = mapped_column(
        PocketEnum,
        index=True,
        info=_qb(),
    )
    difficulty: Mapped[float] = mapped_column(
        sa.FLOAT,
        info={
            "query_builder": {
                "others": [
                    {
                        "name": "average_difficulty",
                        "selectable_constructor": qbt.QueryBuilderAverageSelectable,
                    }
                ]
            }
        },
        index=True,
        nullable=True,
    )
    # Records which code revision computed ``difficulty`` (defaults to the
    # running build's ``config.git_commit_hash``).
    difficulty_git_commit: Mapped[str] = mapped_column(
        sa.VARCHAR(length=150),
        nullable=True,
        default=config.git_commit_hash,
    )
    shot: Mapped["ShotModel"] = relationship(
        "ShotModel", back_populates="pocketing_intention_features"
    )


class ErrorFeatures(Base):
    """Per-shot error and warning text."""

    __tablename__ = "error_features"

    # Annotation-only columns: SQLAlchemy infers a string, non-nullable column.
    errors: Mapped[str]
    warnings: Mapped[str]
    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    shot: Mapped["ShotModel"] = relationship(
        "ShotModel", back_populates="error_features"
    )


class BankFeatures(Base):
    """Per-shot features for bank shots."""

    __tablename__ = "bank_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    walls_hit = mapped_column(
        ARRAY(WallTypeEnum),
        index=True,
        info=_qb(
            {
                "name": "bank_walls_hit",
                "filter_constructor": qbt.QueryBuilderRangeFilter,
            }
        ),
    )
    bank_angle: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3), index=True, info=_qb()
    )
    distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(dict(name="bank_distance")),
    )
    shot: Mapped["ShotModel"] = relationship(
        "ShotModel", back_populates="bank_features"
    )


class KickFeatures(Base):
    """Per-shot features for kick shots."""

    __tablename__ = "kick_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    # Not exposed to the query builder (no ``info=_qb(...)``), unlike
    # ``BankFeatures.walls_hit``.
    walls_hit = mapped_column(ARRAY(WallTypeEnum), index=True)
    angle: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(dict(name="kick_angle")),
    )
    distance: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(dict(name="kick_distance")),
    )
    shot: Mapped["ShotModel"] = relationship(
        "ShotModel", back_populates="kick_features"
    )


# Degrees of miss angle beyond which a miss counts as an over/under cut.
DEFAULT_OVER_UNDER_CUT_THRESHOLD = 5


class MissFeatures(Base):
    """Per-shot features describing a missed shot."""

    __tablename__ = "miss_features"

    shot_id: Mapped[int] = mapped_column(
        sa.BIGINT,
        sa.ForeignKey("shot.id", onupdate="CASCADE", ondelete="CASCADE"),
        primary_key=True,
        nullable=False,
    )
    miss_angle: Mapped[float] = mapped_column(
        sa.DECIMAL(precision=7, scale=3),
        index=True,
        info=_qb(),
    )

    @declared_attr
    def miss_angle_in_degrees(cls) -> Mapped[float]:
        # Multiplies by 180/pi, i.e. assumes ``miss_angle`` is stored in radians.
        return deferred(cls.miss_angle * (180.0 / pi), info=_qb())  # type: ignore

    shot: Mapped["ShotModel"] = relationship(
        "ShotModel", back_populates="miss_features"
    )

    @classmethod
    def is_undercut_by(
        cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """SQL predicate: miss angle (degrees) is <= ``-angle_threshold``."""
        return cls.miss_angle_in_degrees <= (-1 * angle_threshold)

    @classmethod
    def is_overcut_by(
        cls, angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD
    ) -> sa.ColumnElement[bool]:
        """SQL predicate: miss angle (degrees) is >= ``angle_threshold``."""
        return cls.miss_angle_in_degrees >= angle_threshold

    @classmethod
    def is_miss_in_direction_by(
        cls,
        angle_threshold=DEFAULT_OVER_UNDER_CUT_THRESHOLD,
        direction=gql.ShotDirectionEnum.LEFT,
    ):
        """Correlated scalar subquery: did the miss go toward ``direction``?

        Joins to ``CueObjectFeatures`` on ``shot_id`` and is true when either
        the shot's direction equals ``direction`` and the miss angle is
        >= ``angle_threshold`` degrees, or the shot's direction differs and the
        miss angle is <= ``-angle_threshold`` degrees.
        """
        return (
            sa.select(
                sa.or_(
                    sa.and_(
                        CueObjectFeatures.shot_direction == direction,
                        cls.miss_angle_in_degrees >= angle_threshold,
                    ),
                    sa.and_(
                        CueObjectFeatures.shot_direction != direction,
                        cls.miss_angle_in_degrees <= -angle_threshold,
                    ),
                )
            )
            .where(CueObjectFeatures.shot_id == cls.shot_id)
            .scalar_subquery()
        )

    @classmethod
    def __declare_last__(cls):
        """Attach deferred, query-builder-visible miss flags after mapping."""
        cls.is_overcut = deferred(cls.is_overcut_by(), info=_qb())
        cls.is_undercut = deferred(cls.is_undercut_by(), info=_qb())
        cls.is_left_miss = deferred(
            cls.is_miss_in_direction_by(
                direction=gql.ShotDirectionEnum.LEFT,
            ),
            info=_qb(),
        )
        cls.is_right_miss = deferred(
            cls.is_miss_in_direction_by(
                direction=gql.ShotDirectionEnum.RIGHT,
            ),
            info=_qb(),
        )