import base64
import logging
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence

from google.cloud import bigquery
from osprey.engine.query_language import parse_query_to_validated_ast
from osprey.engine.query_language.ast_bigquery_translator import BQTranslator
from osprey.worker.lib.singletons import CONFIG, ENGINE
from osprey.worker.ui_api.osprey.lib.marshal import JsonBodyMarshaller
from osprey.worker.ui_api.osprey.singletons import BIGQUERY
from pydantic import BaseModel

if TYPE_CHECKING:
    from .abilities import QueryFilterAbility

logger = logging.getLogger('bigquery')

# Query timeout in seconds. Kept low on purpose: BigQuery should return well within this
# window, and a query that takes longer almost certainly indicates a mistake in the query.
DEFAULT_BIGQUERY_TIMEOUT = 30


class Ordering(str, Enum):
    ASCENDING = 'ASCENDING'
    DESCENDING = 'DESCENDING'


class PaginatedScanResult(BaseModel):
    action_ids: List[int]
    next_page: Optional[str]


class EntityFilter(BaseModel):
    id: str
    type: str
    feature_filters: Optional[List[str]]

    def to_sql_filter(self) -> str:
        """Convert entity filter to a SQL WHERE clause."""
        feature_to_entity_mapping = ENGINE.instance().get_feature_name_to_entity_type_mapping()
        filters = [
            feature_name
            for feature_name, entity_type in feature_to_entity_mapping.items()
            if entity_type == self.type and (not self.feature_filters or feature_name in self.feature_filters)
        ]
        if not filters:
            return '1=1'
        # Cast all feature columns to STRING to match the @entity_id parameter type.
        conds = [f'CAST({feature_name} AS STRING) = @entity_id' for feature_name in filters]
        return f'({" OR ".join(conds)})'
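
    # Illustrative example (feature names hypothetical, not from the original source): if the
    # engine maps 'UserId' and 'ActorUserId' to entity type 'user', then
    #   EntityFilter(id='123', type='user', feature_filters=None).to_sql_filter()
    # returns:
    #   (CAST(UserId AS STRING) = @entity_id OR CAST(ActorUserId AS STRING) = @entity_id)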


class BaseBQQuery(BaseModel, JsonBodyMarshaller):
    start: datetime
    end: datetime
    query_filter: str
    entity_filter: Optional[EntityFilter]

    class Config:
        arbitrary_types_allowed = True
        underscore_attrs_are_private = True

    @property
    def _client(self) -> bigquery.Client:
        return BIGQUERY.instance().client

    @property
    def _dataset(self) -> str:
        return BIGQUERY.instance().dataset

    @property
    def _table(self) -> str:
        return BIGQUERY.instance().table

    def _build_base_query(
        self,
        select_clause: str,
        where_conds: List[str],
        group_by: Optional[str] = None,
        order_by: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> str:
        """Build a SQL query string from the supplied parts."""
        parts = [f'SELECT {select_clause}', f'FROM `{self._dataset}.{self._table}`']
        if where_conds:
            parts.append(f'WHERE {" AND ".join(where_conds)}')
        if group_by:
            parts.append(f'GROUP BY {group_by}')
        if order_by:
            parts.append(f'ORDER BY {order_by}')
        if limit:
            parts.append(f'LIMIT {limit}')
        return '\n'.join(parts)
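
    # Illustrative output (dataset/table names assumed): with select_clause='COUNT(*) AS count',
    # where_conds=['timestamp >= @start_time AND timestamp < @end_time'], and limit=10, this
    # produces roughly:
    #   SELECT COUNT(*) AS count
    #   FROM `my_dataset.my_table`
    #   WHERE timestamp >= @start_time AND timestamp < @end_time
    #   LIMIT 10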

    def _execute_query(
        self,
        query: str,
        query_params: Optional[List[bigquery.ScalarQueryParameter]] = None,
        # query_filter_abilities is expected to describe the caller's permission scoping for the
        # query filter; it is accepted here but not yet enforced (TODO: figure this out properly).
        query_filter_abilities: Sequence[Optional['QueryFilterAbility[Any, Any]']] = (),
    ) -> List[Dict[str, Any]]:
        """Execute the supplied query in BigQuery."""
        job_config = bigquery.QueryJobConfig(
            query_parameters=query_params or [],
            use_query_cache=True,
        )
        try:
            query_job = self._client.query(query, job_config=job_config)
            results = query_job.result(timeout=DEFAULT_BIGQUERY_TIMEOUT)
            return [dict(row) for row in results]
        except Exception as e:
            logger.error(f'BigQuery query failed: {e}')
            raise

    def _get_where_conds(self) -> tuple[List[str], List[bigquery.ScalarQueryParameter]]:
        """Assemble the WHERE conditions and query parameters for a query."""
        # Start with the timestamp range filter.
        conds: List[str] = [
            'timestamp >= @start_time AND timestamp < @end_time',
        ]
        # Bind the start and end times as query parameters.
        params: List[bigquery.ScalarQueryParameter] = [
            bigquery.ScalarQueryParameter('start_time', 'TIMESTAMP', self.start),
            bigquery.ScalarQueryParameter('end_time', 'TIMESTAMP', self.end),
        ]
        # Add an entity filter if we have one.
        if self.entity_filter:
            conds.append(self.entity_filter.to_sql_filter())
            params.append(bigquery.ScalarQueryParameter('entity_id', 'STRING', self.entity_filter.id))
        if self.query_filter:
            translated_filter = self._parse_query_filter(self.query_filter)
            if translated_filter:
                conds.append(translated_filter['sql'])
                params.extend(translated_filter.get('params', []))
        return conds, params

    def _parse_query_filter(self, query_filter: str) -> Optional[Dict[str, Any]]:
        """Parse the query filter into SQL."""
        if not query_filter:
            return None
        validated_sources = parse_query_to_validated_ast(
            query_filter, rules_sources=ENGINE.instance().execution_graph.validated_sources
        )
        return BQTranslator(validated_sources=validated_sources).transform()
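
    # Shape note (inferred from usage in _get_where_conds, not verified against BQTranslator):
    # the translator is expected to return a dict with a 'sql' string and, optionally, a
    # 'params' list of bigquery.ScalarQueryParameter objects.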


class TimeseriesResultRow(BaseModel):
    """Format for timeseries result rows, matching what Druid produces and the UI expects."""

    timestamp: Any
    result: Dict[str, Any]


class TimeseriesBQQuery(BaseBQQuery):
    granularity: str
    agg_dims: Optional[List[str]] = None

    def _get_time_bucket_sql(self) -> str:
        """Return the TIMESTAMP_TRUNC expression for this query's granularity."""
        grans = {
            'minute': 'TIMESTAMP_TRUNC(timestamp, MINUTE)',
            'hour': 'TIMESTAMP_TRUNC(timestamp, HOUR)',
            'day': 'TIMESTAMP_TRUNC(timestamp, DAY)',
            'week': 'TIMESTAMP_TRUNC(timestamp, WEEK)',
            # TODO: consider adding a 'month' granularity if it turns out to be needed.
        }
        return grans.get(self.granularity, 'TIMESTAMP_TRUNC(timestamp, HOUR)')
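
    # Example: granularity='day' yields 'TIMESTAMP_TRUNC(timestamp, DAY)'; any unrecognized
    # granularity falls back to hourly buckets.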

    def execute(self) -> List[TimeseriesResultRow]:
        bucket = self._get_time_bucket_sql()
        if self.agg_dims and self.entity_filter:
            # One COUNTIF per aggregation dimension, counting rows where that dimension matches the entity.
            agg_selects = [f'COUNTIF(CAST({dim} AS STRING) = @entity_id) AS {dim}' for dim in self.agg_dims]
            select_clause = f'{bucket} AS timestamp, ' + ', '.join(agg_selects)
        else:
            select_clause = f'{bucket} AS timestamp, COUNT(*) AS count'
        conds, params = self._get_where_conds()
        query = self._build_base_query(
            select_clause=select_clause,
            where_conds=conds,
            group_by='timestamp',
            order_by='timestamp',
        )
        results = self._execute_query(query, params)
        # Reshape rows into the {timestamp, result} structure the UI expects.
        transformed_results: List[TimeseriesResultRow] = []
        for row in results:
            timestamp = row.pop('timestamp')
            transformed_results.append(
                TimeseriesResultRow(
                    timestamp=timestamp,
                    result=row,
                )
            )
        return transformed_results
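
    # Illustrative query for the default path (dataset/table names assumed), granularity='hour'
    # and no agg_dims:
    #   SELECT TIMESTAMP_TRUNC(timestamp, HOUR) AS timestamp, COUNT(*) AS count
    #   FROM `my_dataset.my_table`
    #   WHERE timestamp >= @start_time AND timestamp < @end_time
    #   GROUP BY timestamp
    #   ORDER BY timestamp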


class PaginatedScanBigQueryQuery(BaseBQQuery):
    limit: int = 100
    next_page: Optional[str] = None
    order: Ordering = Ordering.DESCENDING

    def execute(
        self, query_filter_abilities: Sequence[Optional['QueryFilterAbility[Any, Any]']] = ()
    ) -> PaginatedScanResult:
        # Fetch one extra row so we can tell whether another page exists.
        paginated_limit = self.limit + 1
        conds, params = self._get_where_conds()
        if self.next_page:
            # The cursor is a base64-encoded millisecond timestamp of the last row seen.
            date_in_milliseconds = int(base64.b64decode(self.next_page.encode('utf-8')))
            pagination_datetime = datetime.fromtimestamp(date_in_milliseconds // 1000, tz=timezone.utc)
            if self.order == Ordering.ASCENDING:
                conds.append('timestamp >= @page_cursor')
            else:
                conds.append('timestamp < @page_cursor')
            params.append(bigquery.ScalarQueryParameter('page_cursor', 'TIMESTAMP', pagination_datetime))
        select_clause = 'action_id, timestamp'
        order_dir = 'ASC' if self.order == Ordering.ASCENDING else 'DESC'
        query = self._build_base_query(
            select_clause=select_clause,
            where_conds=conds,
            order_by=f'timestamp {order_dir}',
            limit=paginated_limit,
        )
        results = self._execute_query(query, params, query_filter_abilities)
        if not results:
            return PaginatedScanResult(action_ids=[], next_page=None)
        next_page = None
        if len(results) == paginated_limit:
            # Drop the sentinel row and use its timestamp as the cursor for the next page.
            last_row = results.pop()
            timestamp_ms = int(last_row['timestamp'].timestamp() * 1000)
            timestamp_string = str(timestamp_ms).encode('utf-8')
            next_page = base64.b64encode(timestamp_string).decode('utf-8')
        action_ids = [int(row['action_id']) for row in results]
        return PaginatedScanResult(action_ids=action_ids, next_page=next_page)
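
    # Cursor round-trip example (values illustrative): a last-row timestamp of
    # 2024-01-01 00:00:00 UTC becomes str(1704067200000) -> base64 'MTcwNDA2NzIwMDAwMA=='.
    # Decoding reverses the steps, but note the cursor is truncated to whole seconds
    # (date_in_milliseconds // 1000) when the datetime is rebuilt.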


class GroupByApproximateCountBigQueryQuery(BaseBQQuery):
    dim: str

    def execute(self) -> int:
        select_clause = f'APPROX_COUNT_DISTINCT({self.dim}) AS cardinality'
        conds, params = self._get_where_conds()
        query = self._build_base_query(select_clause=select_clause, where_conds=conds)
        results = self._execute_query(query, params)
        if results and 'cardinality' in results[0]:
            return int(results[0]['cardinality'])
        return -1
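
    # Example (column name hypothetical): dim='UserId' builds
    #   SELECT APPROX_COUNT_DISTINCT(UserId) AS cardinality FROM `<dataset>.<table>` WHERE ...
    # A return value of -1 signals that no usable result row came back.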


class DimensionData(BaseModel):
    count: int

    class Config:
        extra = 'allow'


class PeriodData(BaseModel):
    timestamp: datetime
    result: List[DimensionData]


class DimensionDifference(BaseModel):
    dimension_key: str | None
    current_count: int
    previous_count: int
    difference: int
    percentage_change: float | None


class ComparisonData(BaseModel):
    differences: List[DimensionDifference]


class TopNPoPResponse(BaseModel):
    current_period: List[PeriodData]
    previous_period: List[PeriodData] | None
    comparison: List[ComparisonData] | None


class TopNBigQueryQuery(BaseBQQuery):
    dimension: str
    limit: int = 100

    def execute(
        self,
        query_filter_abilities: Sequence[Optional['QueryFilterAbility[Any, Any]']] = (),
        calculate_previous_period: bool = True,
    ) -> TopNPoPResponse:
        current_results = self._execute_single_period(
            start=self.start, end=self.end, query_filter_abilities=query_filter_abilities
        )
        sanitized_current_results = self._sanitize_results(current_results)
        if not calculate_previous_period:
            return TopNPoPResponse(current_period=sanitized_current_results, previous_period=None, comparison=None)
        # The previous period immediately precedes the current one and has the same duration.
        period_duration = self.end - self.start
        previous_start = self.start - period_duration
        previous_end = self.start
        config = CONFIG.instance()
        # The 90-day default may be on the high side, but seems okay given the size of the data
        # being queried; it is configurable via MAX_HISTORICAL_QUERY_WINDOW_DAYS.
        max_historical_query_window_days = config.get_int('MAX_HISTORICAL_QUERY_WINDOW_DAYS', 90)
        if previous_start.replace(tzinfo=timezone.utc) < (
            datetime.now(timezone.utc) - timedelta(days=max_historical_query_window_days)
        ):
            # The previous period falls outside the allowed lookback window; skip the comparison.
            return TopNPoPResponse(current_period=sanitized_current_results, previous_period=None, comparison=None)
        previous_results = self._execute_single_period(
            start=previous_start, end=previous_end, query_filter_abilities=query_filter_abilities
        )
        sanitized_previous_results = self._sanitize_results(previous_results)
        pop_results = self._analyze_pop_results(sanitized_current_results, sanitized_previous_results)
        return pop_results

    def _execute_single_period(
        self,
        start: datetime,
        end: datetime,
        query_filter_abilities: Sequence[Optional['QueryFilterAbility[Any, Any]']] = (),
    ) -> List[Dict[str, Any]]:
        select_clause = f'{self.dimension}, COUNT(*) AS count'
        # Temporarily swap in this period's bounds so _get_where_conds picks them up,
        # then restore the originals afterwards.
        original_start, original_end = self.start, self.end
        self.start, self.end = start, end
        try:
            conds, params = self._get_where_conds()
            query = self._build_base_query(
                select_clause=select_clause,
                where_conds=conds,
                group_by=self.dimension,
                order_by='count DESC',
                limit=self.limit,
            )
            results = self._execute_query(query, params, query_filter_abilities)
            return results
        finally:
            self.start, self.end = original_start, original_end

    def _sanitize_results(self, results: List[Dict[str, Any]]) -> List[PeriodData]:
        if not results:
            return []
        dimension_data = []
        for result in results:
            try:
                dimension_value = result.get(self.dimension)
                count = result.get('count', 0)
                data_dict = {'count': count, self.dimension: dimension_value}
                dimension_data.append(DimensionData(**data_dict))
            except Exception as e:
                logger.error(f'Failed to parse result: {result}, error: {e}')
                continue
        if dimension_data:
            return [PeriodData(timestamp=self.end, result=dimension_data)]
        return []

    # Period-over-period comparison between the current and previous top-N results.
    def _analyze_pop_results(
        self, current_results: List[PeriodData], previous_results: List[PeriodData]
    ) -> TopNPoPResponse:
        if not previous_results:
            return TopNPoPResponse(current_period=current_results, previous_period=None, comparison=None)
        dimension_key = self.dimension
        comparison = []
        for current_result, previous_result in zip(current_results, previous_results):
            current_map = {getattr(item, dimension_key): item.count for item in current_result.result}
            previous_map = {getattr(item, dimension_key): item.count for item in previous_result.result}
            dimension_differences = []
            all_keys = set(current_map.keys()) | set(previous_map.keys())
            for item in all_keys:
                current_count = current_map.get(item, 0)
                previous_count = previous_map.get(item, 0)
                # Skip keys that only appear in the previous period.
                if current_count == 0:
                    continue
                difference = current_count - previous_count
                # Percentage change is undefined when the previous count is zero.
                pct_change = (difference / previous_count * 100) if previous_count else None
                dimension_differences.append(
                    DimensionDifference(
                        dimension_key=item,
                        current_count=current_count,
                        previous_count=previous_count,
                        difference=difference,
                        percentage_change=pct_change,
                    )
                )
            comparison.append(ComparisonData(differences=dimension_differences))
        return TopNPoPResponse(
            current_period=current_results,
            previous_period=previous_results,
            comparison=comparison,
        )
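

# Illustrative usage sketch (not part of the original gist). It assumes the BIGQUERY, ENGINE,
# and CONFIG singletons have already been configured by the surrounding application, which is
# why it is left commented out here.
#
# from datetime import datetime, timedelta, timezone
#
# end = datetime.now(timezone.utc)
# query = TimeseriesBQQuery(
#     start=end - timedelta(days=1),
#     end=end,
#     query_filter='',
#     entity_filter=None,
#     granularity='hour',
# )
# for row in query.execute():
#     print(row.timestamp, row.result)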