import logging
from typing import Any, Dict, Generic, List, TypeVar

from flask_appbuilder import Model
from marshmallow import Schema, fields

from superset import security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.commands.exceptions import (
    DatasetForbiddenError,
    DatasetNotFoundError,
    DatasetRefreshFailedError,
)
from superset.datasets.dao import DatasetDAO
from superset.exceptions import SupersetSecurityException

logger = logging.getLogger(__name__)

# ``typing.Generic`` only accepts type variables as parameters, so the model
# type handled by a command is expressed through a TypeVar bound to ``Model``.
ModelT = TypeVar("ModelT", bound=Model)


class CommandSchema(Schema):
    """Default schema used by ``NewCommand`` to deserialize its payload."""

    number = fields.Integer()


class NewCommand(Generic[ModelT]):
    """
    Base class for commands that operate on one or more models.

    Subclasses override ``schema`` with the Marshmallow schema describing
    their payload, and implement ``load_models`` and ``run``.
    """

    schema = CommandSchema()

    def __init__(self, deserialized: Dict[str, Any]):
        """
        Build the command and eagerly resolve the models it works on.

        :param deserialized: payload already deserialized by ``schema``
        """
        self.models = self.load_models(deserialized)

    @classmethod
    def from_serialized(cls, serialized: Dict[str, Any]) -> "NewCommand[ModelT]":
        """
        Instantiate the command from a serialized dictionary.

        The dictionary is passed through ``cls.schema.load``; any exception
        raised by the schema propagates to the caller.

        :param serialized: raw payload to be loaded by the schema
        :returns: an instance of the class the method was called on
        """
        deserialized = cls.schema.load(serialized)
        # Dispatch through ``cls`` so subclasses build themselves rather than
        # the abstract base class.
        return cls.from_deserialized(deserialized)

    @classmethod
    def from_deserialized(
        cls, deserialized: Dict[str, Any]
    ) -> "NewCommand[ModelT]":
        """
        Instantiate the command from a deserialized dictionary.

        This is useful when the command is called from an API endpoint, since
        the API can deserialize the request body into a dictionary and call
        the Marshmallow schema validation.

        :param deserialized: payload already deserialized by the schema
        :returns: an instance of the class the method was called on
        """
        return cls(deserialized)

    def load_models(self, deserialized: Dict[str, Any]) -> List[ModelT]:
        """
        Load models from the deserialized dictionary.

        Even if the command works on a single model it should return a list
        with a single element.

        :param deserialized: payload already deserialized by the schema
        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError("Subclasses must implement load_models")

    def run(self) -> Any:
        """
        Execute the command.

        :raises NotImplementedError: always; subclasses must override
        """
        raise NotImplementedError("Subclasses must implement run")


class RefreshDatasetSchema(Schema):
    """Payload schema for ``RefreshDatasetCommand``: a required integer ``pk``."""

    pk = fields.Integer(required=True)


class RefreshDatasetCommand(NewCommand[SqlaTable]):
    """Command that refreshes the metadata of a single dataset."""

    schema = RefreshDatasetSchema()

    def load_models(self, deserialized: Dict[str, Any]) -> List[SqlaTable]:
        """
        Look up the dataset identified by ``deserialized["pk"]``.

        :param deserialized: payload containing the ``pk`` of the dataset
        :returns: a single-element list holding the dataset
        :raises DatasetNotFoundError: if the DAO lookup returns nothing
        :raises DatasetForbiddenError: if the ownership check fails
        """
        model = DatasetDAO.find_by_id(deserialized["pk"])
        if not model:
            raise DatasetNotFoundError()

        # check ownership
        try:
            security_manager.raise_for_ownership(model)
        except SupersetSecurityException as ex:
            raise DatasetForbiddenError() from ex

        return [model]

    def run(self) -> SqlaTable:
        """
        Call ``fetch_metadata`` on the loaded dataset.

        :returns: the dataset that was refreshed
        :raises DatasetRefreshFailedError: if ``fetch_metadata`` raises; the
            original exception is logged and chained as the cause
        """
        model = self.models[0]
        try:
            model.fetch_metadata()
        except Exception as ex:
            logger.exception(
                "An error occurred while fetching dataset metadata"
            )
            raise DatasetRefreshFailedError() from ex
        else:
            return model