Skip to content

Instantly share code, notes, and snippets.

@jim-my
Created March 23, 2025 01:59
Show Gist options
  • Save jim-my/12845dc8b71efe70df306efd798e3254 to your computer and use it in GitHub Desktop.
Save jim-my/12845dc8b71efe70df306efd798e3254 to your computer and use it in GitHub Desktop.

Revisions

  1. jim-my created this gist Mar 23, 2025.
    149 changes: 149 additions & 0 deletions format-sql-sqlglot.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,149 @@
    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    """Format SQL query using sqlglot library."""

    import argparse
    import sys
    import typing as t

    import sqlglot
    import sqlglot.expressions as exp
    import sqlparse


    class MyFormatter(sqlglot.Generator):
    """Custom formatter for SQL query."""

    def __init__(self, _original_object):
    self._original_object = _original_object

    def __getattr__(self, name):
    """Delegate attribute access to the original object."""
    return getattr(self._original_object, name)

    def maybe_comment(
    self,
    sql: str,
    expression: t.Optional[exp.Expression] = None,
    comments: t.Optional[t.List[str]] = None,
    separated: bool = False,
    ) -> str:
    comments = (
    ((expression and expression.comments) if comments is None else comments) # type: ignore
    if self.comments
    else None
    )

    if not comments or isinstance(expression, self.EXCLUDE_COMMENTS):
    return sql

    if len(comments) > 1: # multiple-line comment
    comments_sql = (
    "/*\n *"
    + "\n *".join(
    f"{self.pad_comment(comment)}" for comment in comments if comment
    )
    + "\n */"
    )
    else: # single-line comment
    comments_sql = "\n".join(
    f"-- {self.pad_comment(comment)}" for comment in comments if comment
    )

    if not comments_sql:
    return sql

    comments_sql = self._replace_line_breaks(comments_sql)

    if separated or isinstance(expression, self.WITH_SEPARATED_COMMENTS):
    return (
    f"{self.sep()}{comments_sql}{sql}"
    if not sql or sql[0].isspace()
    else f"{comments_sql}{self.sep()}{sql}"
    )

    return f"{sql} {comments_sql}"


    def format_sql(sql, dialect="redshift"):
    """
    Format SQL query using sqlglot/sqlparse library.
    # REF: https://github.com/tobymao/sqlglot/blob/32a86d38b7935bb04644ef2ebc07589d5e040e34/sqlglot/generator.py#L40
    Args(for ast.sql() below):
    pretty: Whether or not to format the produced SQL string.
    Default: False.
    identify: Determines when an identifier should be quoted. Possible values are:
    False (default): Never quote, except in cases where it's mandatory by the dialect.
    True or 'always': Always quote.
    'safe': Only quote identifiers that are case insensitive.
    normalize: Whether or not to normalize identifiers to lowercase.
    Default: False.
    pad: Determines the pad size in a formatted string.
    Default: 2.
    indent: Determines the indentation size in a formatted string.
    Default: 2.
    normalize_functions: Whether or not to normalize all function names. Possible values are:
    "upper" or True (default): Convert names to uppercase.
    "lower": Convert names to lowercase.
    False: Disables function name normalization.
    unsupported_level: Determines the generator's behavior when it encounters unsupported expressions.
    Default ErrorLevel.WARN.
    max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
    This is only relevant if unsupported_level is ErrorLevel.RAISE.
    Default: 3
    leading_comma: Determines whether or not the comma is leading or trailing in select expressions.
    This is only relevant when generating in pretty mode.
    Default: False
    max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
    The default is on the smaller end because the length only represents a segment and not the true
    line length.
    Default: 80
    comments: Whether or not to preserve comments in the output SQL code.
    Default: True
    """
    ast: sqlglot.Expression = sqlglot.parse_one(sql, dialect=dialect)
    generator_opts = {
    "identify": False,
    "pretty": True,
    "indent": 4,
    "pad": 4, # E.g. padding before c1 and c2 for `echo "select c1, c2" | sqlglot-format.py`
    "leading_comma": True,
    # max_text_width=80,
    "comments": True,
    "normalize_functions": "upper",
    }

    dia: sqlglot.Dialect = sqlglot.Dialect.get_or_raise(dialect)
    gt = dia.generator(**generator_opts)
    # return ast.sql(dialect=dialect, **generator_opts)
    return MyFormatter(gt).generate(ast)


    def read_sql_from_source(file):
    """Read SQL query from file or stdin."""
    return file.read()


    def main():
    """Main function to parse arguments and format SQL query."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
    "file", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    # parser.add_argument("--dialect", default="redshift")
    parser.add_argument("--dialect", default="databricks")
    args = parser.parse_args()

    if args.file is sys.stdin and sys.stdin.isatty():
    parser.print_usage()
    sys.exit(1)

    sql = read_sql_from_source(args.file)
    sqls = sqlparse.split(sql)
    for each_sql in sqls:
    print(format_sql(each_sql, dialect=args.dialect) + ";\n")


    if __name__ == "__main__":
    main()