Created
March 23, 2025 01:59
-
-
Save jim-my/12845dc8b71efe70df306efd798e3254 to your computer and use it in GitHub Desktop.
Format a SQL query using the sqlglot library (sqlfluff may sometimes be a better option).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """Format SQL query using sqlglot library.""" | |
| import argparse | |
| import sys | |
| import typing as t | |
| import sqlglot | |
| import sqlglot.expressions as exp | |
| import sqlparse | |
class MyFormatter(sqlglot.Generator):
    """Custom formatter for SQL query.

    Wraps an already configured generator instance and forwards every
    attribute that is not resolved on this object to it, while overriding
    ``maybe_comment`` so that comments are rendered as ``-- ...`` (one
    single-line comment) or as one starred ``/* ... */`` block (several
    comments, or a comment that spans lines).

    NOTE(review): this class derives from the base ``sqlglot.Generator``, so
    class-level attributes and methods are found on the base class before
    ``__getattr__`` is consulted; dialect-specific overrides defined on the
    wrapped generator's class may therefore be bypassed -- confirm.

    NOTE(review): a ``--`` comment swallows whatever follows it on the same
    line; this presumably relies on pretty mode placing the next token on a
    new line -- verify for inline constructs such as function arguments.
    """

    def __init__(self, _original_object):
        # Deliberately does not call super().__init__(): generator settings
        # are looked up on the wrapped object through __getattr__ instead.
        self._original_object = _original_object

    def __getattr__(self, name):
        """Delegate attribute access to the original object.

        Raises:
            AttributeError: If ``_original_object`` itself is missing, which
                happens when an instance is created without ``__init__``
                (e.g. by ``copy`` or ``pickle``). Without this guard the
                lookup of ``self._original_object`` would re-enter
                ``__getattr__`` until a ``RecursionError``.
        """
        if name == "_original_object":
            raise AttributeError(name)
        return getattr(self._original_object, name)

    def maybe_comment(
        self,
        sql: str,
        expression: t.Optional[exp.Expression] = None,
        comments: t.Optional[t.List[str]] = None,
        separated: bool = False,
    ) -> str:
        """Attach comments to a generated SQL fragment.

        Args:
            sql: The SQL text generated so far for ``expression``.
            expression: Expression whose ``comments`` are used when
                ``comments`` is not given.
            comments: Explicit comments that take precedence over the
                expression's own comments.
            separated: Put the comment before ``sql``, separated by
                ``self.sep()``, instead of appending it after ``sql``.

        Returns:
            ``sql`` unchanged when comments are disabled, absent, all empty,
            or excluded for this expression type; otherwise ``sql`` combined
            with the rendered comment.
        """
        comments = (
            ((expression and expression.comments) if comments is None else comments)  # type: ignore
            if self.comments
            else None
        )

        if not comments or isinstance(expression, self.EXCLUDE_COMMENTS):
            return sql

        # Drop empty entries up front so that a list holding only empty
        # comments leaves the SQL untouched instead of emitting an empty block.
        kept = [comment for comment in comments if comment]
        if not kept:
            return sql

        # A "--" comment ends at the first newline, so a comment that itself
        # contains a line break must use the block form to stay commented out.
        needs_block = len(kept) > 1 or any("\n" in comment for comment in kept)

        if needs_block:  # multiple-line comment
            comments_sql = (
                "/*\n *"
                + "\n *".join(f"{self.pad_comment(comment)}" for comment in kept)
                + "\n */"
            )
        else:  # single-line comment
            comments_sql = f"-- {self.pad_comment(kept[0])}"

        comments_sql = self._replace_line_breaks(comments_sql)

        if separated or isinstance(expression, self.WITH_SEPARATED_COMMENTS):
            return (
                f"{self.sep()}{comments_sql}{sql}"
                if not sql or sql[0].isspace()
                else f"{comments_sql}{self.sep()}{sql}"
            )

        return f"{sql} {comments_sql}"
def format_sql(sql, dialect="redshift"):
    """
    Pretty-print one SQL statement with sqlglot.

    The statement is parsed with ``sqlglot.parse_one`` for ``dialect``, a
    generator is created from that dialect with the options below, and the
    result is rendered through ``MyFormatter``.

    # REF: https://github.com/tobymao/sqlglot/blob/32a86d38b7935bb04644ef2ebc07589d5e040e34/sqlglot/generator.py#L40

    Generator options (library defaults as described in the reference above;
    the values used here are the keyword arguments in the body):
        pretty: Format the produced SQL string. Default: False.
        identify: When to quote identifiers.
            False (default): only where the dialect makes it mandatory.
            True or 'always': always quote.
            'safe': only quote identifiers that are case insensitive.
        normalize: Lowercase identifiers. Default: False.
        pad: Pad size in a formatted string. Default: 2.
        indent: Indentation size in a formatted string. Default: 2.
        normalize_functions: Function-name normalization.
            "upper" or True (default): uppercase names.
            "lower": lowercase names.
            False: leave names as written.
        unsupported_level: Behavior on unsupported expressions.
            Default: ErrorLevel.WARN.
        max_unsupported: Maximum number of messages in a raised
            UnsupportedError; only relevant when unsupported_level is
            ErrorLevel.RAISE. Default: 3.
        leading_comma: Leading instead of trailing commas in select
            expressions; only relevant in pretty mode. Default: False.
        max_text_width: Maximum characters in a segment before a new line is
            started in pretty mode. It measures a segment, not the whole
            line, hence the small default. Default: 80.
        comments: Preserve comments in the output SQL. Default: True.

    Args:
        sql: Text of a single SQL statement.
        dialect: Name of the sqlglot dialect used for parsing and generation.

    Returns:
        The formatted SQL string.
    """
    tree = sqlglot.parse_one(sql, dialect=dialect)

    dialect_obj = sqlglot.Dialect.get_or_raise(dialect)
    base_generator = dialect_obj.generator(
        identify=False,
        pretty=True,
        indent=4,
        # E.g. padding before c1 and c2 for `echo "select c1, c2" | sqlglot-format.py`
        pad=4,
        leading_comma=True,
        comments=True,
        normalize_functions="upper",
    )

    # Wrap the dialect generator so comments use the custom rendering.
    return MyFormatter(base_generator).generate(tree)
def read_sql_from_source(file):
    """Return everything that can be read from an open file-like object.

    Args:
        file: Object exposing ``read()``, such as an opened file or stdin.

    Returns:
        Whatever ``file.read()`` returns.
    """
    contents = file.read()
    return contents
def main():
    """Parse command-line arguments, then format and print each statement.

    Reads SQL from the optional ``file`` argument (stdin when omitted), splits
    it into statements with ``sqlparse.split`` and prints every formatted
    statement followed by ``;`` and a blank line. Prints the usage line and
    exits with status 1 when no file is given and stdin is a terminal.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "file", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    parser.add_argument("--dialect", default="databricks")
    args = parser.parse_args()

    # Nothing is piped in and no file was given: show usage instead of
    # blocking on interactive input.
    if args.file is sys.stdin and sys.stdin.isatty():
        parser.print_usage()
        sys.exit(1)

    try:
        sql = read_sql_from_source(args.file)
    finally:
        # Release the handle argparse opened; stdin is left open for the caller.
        if args.file is not sys.stdin:
            args.file.close()

    for each_sql in sqlparse.split(sql):
        # Skip blank fragments, which cannot be parsed as a statement.
        if not each_sql.strip():
            continue
        print(format_sql(each_sql, dialect=args.dialect) + ";\n")
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment