Skip to content

Instantly share code, notes, and snippets.

@jim-my
Created March 23, 2025 01:59
Show Gist options
  • Save jim-my/12845dc8b71efe70df306efd798e3254 to your computer and use it in GitHub Desktop.
Save jim-my/12845dc8b71efe70df306efd798e3254 to your computer and use it in GitHub Desktop.
Format SQL queries using the sqlglot library (sqlfluff might sometimes be a better option).
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Format SQL query using sqlglot library."""
import argparse
import sys
import typing as t
import sqlglot
import sqlglot.expressions as exp
import sqlparse
class MyFormatter(sqlglot.Generator):
    """Generator wrapper that renders SQL comments in a custom style.

    Wraps an already-built generator object and forwards to it every
    attribute that normal lookup cannot find on this instance. A single
    one-line comment is rendered as ``-- text``; anything else is rendered
    as one ``/* ... */`` block with one ``*``-prefixed line per comment.

    NOTE(review): ``__getattr__`` only runs when normal lookup fails, so any
    name defined on ``sqlglot.Generator`` itself is resolved on that base
    class and never reaches the wrapped object. Dialect-specific overrides
    of such names are presumably bypassed -- confirm this is intended.
    """

    def __init__(self, _original_object):
        """Store the generator to delegate to.

        Args:
            _original_object: The generator instance that backs this
                formatter's attribute lookups.
        """
        # The base-class initializer is not invoked here; presumably so that
        # settings held by the wrapped generator are picked up through
        # __getattr__ instead of being shadowed -- TODO confirm.
        self._original_object = _original_object

    def __getattr__(self, name):
        """Delegate attribute access to the original object.

        Raises:
            AttributeError: If the wrapped object has not been stored yet,
                or if it does not provide ``name`` either.
        """
        # Without this guard, a lookup of ``_original_object`` before
        # __init__ has run (e.g. while the instance is being copied or
        # unpickled) would re-enter __getattr__ until RecursionError.
        if name == "_original_object":
            raise AttributeError(name)
        return getattr(self._original_object, name)

    def maybe_comment(
        self,
        sql: str,
        expression: t.Optional[exp.Expression] = None,
        comments: t.Optional[t.List[str]] = None,
        separated: bool = False,
    ) -> str:
        """Attach the applicable comments to ``sql``.

        Args:
            sql: The SQL fragment to decorate.
            expression: Expression the fragment was generated from; its
                ``comments`` are used when ``comments`` is not given.
            comments: Explicit comments that take precedence over the ones
                stored on ``expression``.
            separated: When true, place the comment before ``sql`` (split by
                ``self.sep()``) instead of after it.

        Returns:
            ``sql`` unchanged when there is nothing to render, otherwise
            ``sql`` combined with the rendered comment text.

        NOTE(review): in the non-separated case the comment is appended on
        the same line as ``sql``. If more SQL is later placed after it on
        that line, a ``--`` comment would swallow it -- verify against real
        output.
        """
        comments = (
            ((expression and expression.comments) if comments is None else comments)  # type: ignore
            if self.comments
            else None
        )

        if not comments or isinstance(expression, self.EXCLUDE_COMMENTS):
            return sql

        # Drop empty entries up front so they neither influence the style
        # choice nor produce an empty ``/* */`` block.
        kept = [comment for comment in comments if comment]
        if not kept:
            return sql

        # A ``--`` prefix only covers one physical line, so a comment that
        # itself spans several lines must use the block form as well.
        needs_block = len(kept) > 1 or any(
            "\n" in comment or "\r" in comment for comment in kept
        )
        if needs_block:
            comments_sql = (
                "/*\n *"
                + "\n *".join(self.pad_comment(comment) for comment in kept)
                + "\n */"
            )
        else:
            comments_sql = f"-- {self.pad_comment(kept[0])}"

        comments_sql = self._replace_line_breaks(comments_sql)

        if separated or isinstance(expression, self.WITH_SEPARATED_COMMENTS):
            return (
                f"{self.sep()}{comments_sql}{sql}"
                if not sql or sql[0].isspace()
                else f"{comments_sql}{self.sep()}{sql}"
            )

        return f"{sql} {comments_sql}"
def format_sql(sql, dialect="redshift"):
    """Pretty-print one SQL statement with sqlglot and ``MyFormatter``.

    The statement is parsed with the given dialect, a generator for that
    dialect is built with the fixed options below, and the result is
    rendered through ``MyFormatter``.

    Generator options used here:
        identify=False, pretty=True, indent=4, pad=4, leading_comma=True,
        comments=True, normalize_functions="upper".

    Option reference (see
    https://github.com/tobymao/sqlglot/blob/32a86d38b7935bb04644ef2ebc07589d5e040e34/sqlglot/generator.py#L40):
        pretty: Whether to format the produced SQL string. Default: False.
        identify: When identifiers are quoted. False (default): only when
            the dialect requires it; True or 'always': always; 'safe': only
            identifiers that are case insensitive.
        normalize: Whether to normalize identifiers to lowercase.
            Default: False.
        pad: Pad size in a formatted string. Default: 2.
        indent: Indentation size in a formatted string. Default: 2.
        normalize_functions: "upper" or True (default): uppercase function
            names; "lower": lowercase them; False: leave them untouched.
        unsupported_level: Behavior on unsupported expressions.
            Default: ErrorLevel.WARN.
        max_unsupported: Maximum number of unsupported messages included in
            a raised UnsupportedError; only relevant when unsupported_level
            is ErrorLevel.RAISE. Default: 3.
        leading_comma: Whether commas lead or trail in select expressions;
            only relevant in pretty mode. Default: False.
        max_text_width: Maximum characters in a segment before a new line
            is started in pretty mode. It measures a segment, not the true
            line length. Default: 80.
        comments: Whether to preserve comments in the output.
            Default: True.

    Args:
        sql: Text of the SQL statement to format.
        dialect: Name of the sqlglot dialect used for parsing and
            generation.

    Returns:
        The formatted SQL string.
    """
    tree: sqlglot.Expression = sqlglot.parse_one(sql, dialect=dialect)
    dialect_obj: sqlglot.Dialect = sqlglot.Dialect.get_or_raise(dialect)

    # Equivalent to ``tree.sql(dialect=dialect, ...)`` with the same options,
    # except that rendering goes through MyFormatter.
    base_generator = dialect_obj.generator(
        identify=False,
        pretty=True,
        indent=4,
        # E.g. padding before c1 and c2 for
        # `echo "select c1, c2" | sqlglot-format.py`
        pad=4,
        leading_comma=True,
        comments=True,
        normalize_functions="upper",
    )
    return MyFormatter(base_generator).generate(tree)
def read_sql_from_source(file):
    """Return the full contents of an already opened source.

    Args:
        file: Object exposing ``read()``, such as an open file or stdin.

    Returns:
        Whatever ``file.read()`` returns.
    """
    contents = file.read()
    return contents
def main():
    """Parse command-line arguments and print the formatted SQL.

    The optional positional ``file`` argument names the input; stdin is
    used when it is omitted. When no file is given and stdin is an
    interactive terminal, usage is printed and the process exits with
    status 1. The input is split into statements with ``sqlparse.split``,
    and each statement is formatted with ``--dialect`` (default
    ``databricks``) and printed followed by ``;`` and a blank line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "file", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    parser.add_argument("--dialect", default="databricks")
    args = parser.parse_args()

    # Nothing is piped in and no file was named: show usage rather than
    # blocking on terminal input.
    if args.file is sys.stdin and sys.stdin.isatty():
        parser.print_usage()
        sys.exit(1)

    try:
        sql = read_sql_from_source(args.file)
    finally:
        # argparse.FileType opens the file but never closes it; release the
        # handle once it has been read. stdin is left open for the process.
        if args.file is not sys.stdin:
            args.file.close()

    for each_sql in sqlparse.split(sql):
        print(format_sql(each_sql, dialect=args.dialect) + ";\n")
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment