Created
March 23, 2025 01:59
-
-
Save jim-my/12845dc8b71efe70df306efd798e3254 to your computer and use it in GitHub Desktop.
Revisions
-
jim-my created this gist
Mar 23, 2025 .There are no files selected for viewing
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Format SQL query using sqlglot library."""

import argparse
import sys
import typing as t

import sqlglot
import sqlglot.expressions as exp
import sqlparse


class MyFormatter(sqlglot.Generator):
    """Custom formatter for SQL query.

    Wraps an already-configured generator instance and overrides
    ``maybe_comment`` so that a lone one-line comment is rendered as
    ``-- comment`` and everything else as a ``/* ... */`` block.

    NOTE(review): ``__getattr__`` only runs when normal attribute lookup
    fails, so any method or class attribute defined on ``sqlglot.Generator``
    itself is resolved on this subclass before the wrapped generator is
    consulted. Dialect-specific overrides on the wrapped object may therefore
    be bypassed -- confirm the output matches the stock dialect generator.
    """

    def __init__(self, _original_object):
        # ``super().__init__()`` is not called: attributes that are not found
        # on this instance are read from the wrapped generator via
        # ``__getattr__``.
        self._original_object = _original_object

    def __getattr__(self, name):
        """Delegate attribute access to the original object."""
        # Without this guard, looking up ``_original_object`` before it has
        # been assigned (e.g. on an instance built with ``__new__``) would
        # re-enter ``__getattr__`` forever.
        if name == "_original_object":
            raise AttributeError(name)
        return getattr(self._original_object, name)

    def maybe_comment(
        self,
        sql: str,
        expression: t.Optional[exp.Expression] = None,
        comments: t.Optional[t.List[str]] = None,
        separated: bool = False,
    ) -> str:
        """Attach comments to a piece of generated SQL.

        Args:
            sql: The SQL text the comments belong to.
            expression: Expression the SQL was generated from; its
                ``comments`` are used when ``comments`` is None.
            comments: Explicit comments overriding the expression's own.
            separated: Put the comments on their own line(s) next to ``sql``
                instead of appending them after it.

        Returns:
            ``sql`` unchanged when there is nothing to emit, otherwise
            ``sql`` combined with the rendered comments.
        """
        comments = (
            ((expression and expression.comments) if comments is None else comments)  # type: ignore
            if self.comments
            else None
        )

        if not comments or isinstance(expression, self.EXCLUDE_COMMENTS):
            return sql

        # Decide on the rendering from the comments that will actually be
        # printed; empty entries must not force (or produce) a block comment.
        comments = [comment for comment in comments if comment]
        if not comments:
            return sql

        # A ``--`` comment ends at the first line break, so a single comment
        # that spans several lines must use the block form as well, otherwise
        # its continuation lines would be emitted as bare SQL.
        needs_block = len(comments) > 1 or any(
            "\n" in comment or "\r" in comment for comment in comments
        )

        if needs_block:
            # multiple-line comment; a literal "*/" inside the text would
            # close the block early, so it is broken up.
            body = "\n *".join(
                self.pad_comment(comment).replace("*/", "* /")
                for comment in comments
            )
            comments_sql = f"/*\n *{body}\n */"
        else:
            # single-line comment
            comments_sql = f"-- {self.pad_comment(comments[0])}"

        comments_sql = self._replace_line_breaks(comments_sql)

        if separated or isinstance(expression, self.WITH_SEPARATED_COMMENTS):
            return (
                f"{self.sep()}{comments_sql}{sql}"
                if not sql or sql[0].isspace()
                else f"{comments_sql}{self.sep()}{sql}"
            )

        # NOTE(review): a trailing ``--`` comment swallows anything the
        # caller later places on the same line (e.g. a comment attached to an
        # argument inside a function call) -- confirm this is acceptable for
        # the queries being formatted.
        return f"{sql} {comments_sql}"


def format_sql(sql: str, dialect: str = "redshift") -> str:
    """
    Format SQL query using sqlglot/sqlparse library.

    # REF: https://github.com/tobymao/sqlglot/blob/32a86d38b7935bb04644ef2ebc07589d5e040e34/sqlglot/generator.py#L40

    Args(for ast.sql() below):
        pretty: Whether or not to format the produced SQL string.
            Default: False.
        identify: Determines when an identifier should be quoted. Possible values are:
            False (default): Never quote, except in cases where it's mandatory by the dialect.
            True or 'always': Always quote.
            'safe': Only quote identifiers that are case insensitive.
        normalize: Whether or not to normalize identifiers to lowercase.
            Default: False.
        pad: Determines the pad size in a formatted string.
            Default: 2.
        indent: Determines the indentation size in a formatted string.
            Default: 2.
        normalize_functions: Whether or not to normalize all function names. Possible values are:
            "upper" or True (default): Convert names to uppercase.
            "lower": Convert names to lowercase.
            False: Disables function name normalization.
        unsupported_level: Determines the generator's behavior when it encounters unsupported expressions.
            Default ErrorLevel.WARN.
        max_unsupported: Maximum number of unsupported messages to include in a raised UnsupportedError.
            This is only relevant if unsupported_level is ErrorLevel.RAISE.
            Default: 3
        leading_comma: Determines whether or not the comma is leading or trailing in select expressions.
            This is only relevant when generating in pretty mode.
            Default: False
        max_text_width: The max number of characters in a segment before creating new lines in pretty mode.
            The default is on the smaller end because the length only represents a segment and not the true
            line length.
            Default: 80
        comments: Whether or not to preserve comments in the output SQL code.
            Default: True
    """
    ast: sqlglot.Expression = sqlglot.parse_one(sql, dialect=dialect)

    generator_opts = {
        "identify": False,
        "pretty": True,
        "indent": 4,
        "pad": 4,  # E.g. padding before c1 and c2 for `echo "select c1, c2" | sqlglot-format.py`
        "leading_comma": True,
        # max_text_width is left at the generator's default.
        "comments": True,
        "normalize_functions": "upper",
    }

    dia: sqlglot.Dialect = sqlglot.Dialect.get_or_raise(dialect)
    gt = dia.generator(**generator_opts)

    # ``ast.sql(dialect=dialect, **generator_opts)`` is the call these options
    # are documented for; the generator is wrapped instead so that comments go
    # through ``MyFormatter.maybe_comment``.
    return MyFormatter(gt).generate(ast)


def read_sql_from_source(file):
    """Read SQL query from file or stdin."""
    return file.read()


def main():
    """Main function to parse arguments and format SQL query."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "file", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    # The command-line default differs from format_sql()'s own default
    # dialect ("redshift").
    parser.add_argument("--dialect", default="databricks")
    args = parser.parse_args()

    # No file argument and nothing piped in: there is no input to format.
    if args.file is sys.stdin and sys.stdin.isatty():
        parser.print_usage()
        sys.exit(1)

    # Close the handle argparse opened once the input has been read.
    with args.file as source:
        sql = read_sql_from_source(source)

    for each_sql in sqlparse.split(sql):
        # Skip blank fragments instead of handing them to the parser.
        if not each_sql.strip():
            continue
        print(format_sql(each_sql, dialect=args.dialect) + ";\n")


if __name__ == "__main__":
    main()