Skip to content

Instantly share code, notes, and snippets.

@komamitsu
Last active May 2, 2022 06:51
Show Gist options
  • Select an option

  • Save komamitsu/dd8ef34859d3b15198c6851b4adb6946 to your computer and use it in GitHub Desktop.

Select an option

Save komamitsu/dd8ef34859d3b15198c6851b4adb6946 to your computer and use it in GitHub Desktop.

Revisions

  1. komamitsu revised this gist May 2, 2022. 1 changed file with 34 additions and 51 deletions.
    85 changes: 34 additions & 51 deletions RewriteSampleWithCalcite.kotlin
    Original file line number Diff line number Diff line change
    @@ -1,40 +1,6 @@
    import io.trino.jdbc.TrinoDriver
    import org.apache.calcite.avatica.util.Casing
    import org.apache.calcite.avatica.util.Quoting
    import org.apache.calcite.avatica.util.TimeUnit
    import org.apache.calcite.sql.SqlBasicCall
    import org.apache.calcite.sql.SqlCall
    import org.apache.calcite.sql.SqlDataTypeSpec
    import org.apache.calcite.sql.SqlIdentifier
    import org.apache.calcite.sql.SqlIntervalQualifier
    import org.apache.calcite.sql.SqlKind
    import org.apache.calcite.sql.SqlNode
    import org.apache.calcite.sql.SqlNodeList
    import org.apache.calcite.sql.SqlNumericLiteral
    import org.apache.calcite.sql.SqlUserDefinedTypeNameSpec
    import org.apache.calcite.sql.`fun`.SqlCastFunction
    import org.apache.calcite.sql.`fun`.SqlStdOperatorTable
    import org.apache.calcite.sql.parser.SqlParser
    import org.apache.calcite.sql.parser.SqlParserUtil
    import org.apache.calcite.sql.util.SqlShuttle
    import org.apache.calcite.sql.validate.SqlValidator
    import org.apache.calcite.tools.Frameworks
    import java.sql.Connection
    import java.sql.DriverManager
    import java.sql.SQLException


    fun calcite() {
    val schema = Frameworks.createRootSchema(true)
    val schemaName = "sample_datasets"
    val customSchema = JdbcSchema.create(schema, schemaName, ImmutableMap.of<String, Any>(
    "jdbcDriver", "io.trino.jdbc.TrinoDriver",
    "jdbcUrl", DB_URL,
    "jdbcUser", USER,
    "jdbcSchema", schemaName
    ))
    schema.add(schemaName, customSchema)

    val config = Frameworks.newConfigBuilder()
    .defaultSchema(schema)
    .sqlValidatorConfig(SqlValidator.Config.DEFAULT)
    @@ -76,8 +42,24 @@ ORDER BY [TempTableQuerySchema].[YsnsAls_0002] ASC
    """.trimIndent()

    val node = planner.parse(sql)
    val validated = planner.validate(node)
    println(validated.toSqlString { c ->
    val typeFactory = JavaTypeFactoryImpl(config.typeSystem)
    class SqlRewriter : SqlValidatorImpl(
    config.operatorTable,
    CalciteCatalogReader(
    CalciteSchema.from(schema),
    CalciteSchema.from(schema).path(null),
    typeFactory,
    null
    ),
    typeFactory,
    config.sqlValidatorConfig
    ) {
    fun rewrite(node: SqlNode): SqlNode {
    return super.performUnconditionalRewrites(node, false)
    }
    }
    val rewritten = SqlRewriter().rewrite(node.accept(MySqlVisitor())!!)
    println(rewritten.toSqlString { c ->
    c.withDialect(PrestoSqlDialect.DEFAULT)
    .withAlwaysUseParentheses(false)
    .withSubQueryStyle(SqlWriter.SubQueryStyle.HYDE)
    @@ -114,14 +96,15 @@ class MySqlVisitor : SqlShuttle() {
    )
    }
    SqlKind.OTHER_FUNCTION -> {

    if (call.operator.name.uppercase() == "DATEADD") {
    val intervalType = when (call.operand<SqlIdentifier>(0).simple.lowercase()) {
    "yy", "yyyy" -> SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.YEAR, call.parserPosition)
    "mm", "m" -> SqlIntervalQualifier(TimeUnit.MONTH, TimeUnit.MONTH, call.parserPosition)
    "dd", "d" -> SqlIntervalQualifier(TimeUnit.DAY, TimeUnit.DAY, call.parserPosition)
    "hh" -> SqlIntervalQualifier(TimeUnit.HOUR, TimeUnit.HOUR, call.parserPosition)
    "mi", "n" -> SqlIntervalQualifier(TimeUnit.MINUTE, TimeUnit.MINUTE, call.parserPosition)
    "ss", "s" -> SqlIntervalQualifier(TimeUnit.SECOND, TimeUnit.SECOND, call.parserPosition)
    val timeunit = when (call.operand<SqlIdentifier>(0).simple.lowercase()) {
    "yy", "yyyy" -> "year"
    "mm", "m" -> "month"
    "dd", "d" -> "day"
    "hh" -> "hour"
    "mi", "n" -> "minute"
    "ss", "s" -> "second"
    else -> throw IllegalArgumentException("Unexpected identifier: ${call.operand<SqlCall>(0)}")
    }
    val origDiff = call.operand<SqlNode>(1)
    @@ -139,14 +122,14 @@ class MySqlVisitor : SqlShuttle() {
    else -> visit(origTarget as SqlBasicCall)
    }
    return SqlBasicCall(
    SqlStdOperatorTable.DATETIME_PLUS,
    SqlUnresolvedFunction(
    SqlIdentifier("DATE_ADD", call.parserPosition),
    null, null, null, null, SqlFunctionCategory.TIMEDATE
    ),
    SqlNodeList.of(
    target,
    SqlBasicCall(
    SqlStdOperatorTable.INTERVAL,
    SqlNodeList.of(diff, intervalType),
    call.parserPosition
    )
    SqlLiteral.createCharString(timeunit, call.parserPosition),
    diff,
    target
    ),
    call.parserPosition
    )
    @@ -155,4 +138,4 @@ class MySqlVisitor : SqlShuttle() {
    }
    return super.visit(call)
    }
    }
    }
  2. komamitsu revised this gist Apr 29, 2022. 1 changed file with 19 additions and 27 deletions.
    46 changes: 19 additions & 27 deletions RewriteSampleWithCalcite.kotlin
    Original file line number Diff line number Diff line change
    @@ -25,7 +25,18 @@ import java.sql.SQLException


    fun calcite() {
    val schema = Frameworks.createRootSchema(true)
    val schemaName = "sample_datasets"
    val customSchema = JdbcSchema.create(schema, schemaName, ImmutableMap.of<String, Any>(
    "jdbcDriver", "io.trino.jdbc.TrinoDriver",
    "jdbcUrl", DB_URL,
    "jdbcUser", USER,
    "jdbcSchema", schemaName
    ))
    schema.add(schemaName, customSchema)

    val config = Frameworks.newConfigBuilder()
    .defaultSchema(schema)
    .sqlValidatorConfig(SqlValidator.Config.DEFAULT)
    .parserConfig(
    SqlParser.config()
    @@ -34,21 +45,9 @@ fun calcite() {
    .withQuotedCasing(Casing.UNCHANGED)
    .withUnquotedCasing(Casing.UNCHANGED)
    )
    .defaultSchema(Frameworks.createRootSchema(true))
    .build()
    val planner = Frameworks.getPlanner(config)
    /*
    val sql = """
    SELECT
    CAST(
    DateAdd(yy,YEAR(tq_gnOjk2mbD.t) - 1904, DateAdd(mm,1 - 1, DateAdd(dd, 1 - 1, '1904-01-01'))) AS DateTime
    ) AS YsnsAls_0002
    FROM
    [tq_gnOjk2mbD_cTable] [tq_gnOjk2mb]
    """.trimIndent()
    */

    /*
    val sql = """
    WITH tq_gnOjk2mbD_cTable AS (
    SELECT *, from_unixtime([time]) as t
    @@ -74,24 +73,17 @@ SELECT [YsnsAls_0002] FROM (
    ) AS [TempTableQuerySchema]
    WHERE (rn > 0 AND rn <= 50001 )
    ORDER BY [TempTableQuerySchema].[YsnsAls_0002] ASC
    """.trimIndent()
    */

    val sql = """
    WITH cte AS (
    select 42 as i
    )
    SELECT i FROM (
    SELECT [i] from cte
    ) AS [alias]
    ORDER BY [alias].[i]
    """.trimIndent()

    val node = planner.parse(sql)
    println(node)
    // FileWriter("/tmp/presto.sql").use {
    // it.write(node.accept(MySqlVisitor())!!.toSqlString(PostgresqlSqlDialect.DEFAULT).toString())
    // }
    val validated = planner.validate(node)
    println(validated.toSqlString { c ->
    c.withDialect(PrestoSqlDialect.DEFAULT)
    .withAlwaysUseParentheses(false)
    .withSubQueryStyle(SqlWriter.SubQueryStyle.HYDE)
    .withClauseStartsLine(false)
    .withClauseEndsLine(false)
    })
    }

    class MySqlVisitor : SqlShuttle() {
  3. komamitsu revised this gist Apr 27, 2022. 1 changed file with 147 additions and 21 deletions.
    168 changes: 147 additions & 21 deletions RewriteSampleWithCalcite.kotlin
    Original file line number Diff line number Diff line change
    @@ -1,39 +1,165 @@
    import io.trino.jdbc.TrinoDriver
    import org.apache.calcite.avatica.util.Casing
    import org.apache.calcite.avatica.util.Quoting
    import org.apache.calcite.avatica.util.TimeUnit
    import org.apache.calcite.sql.SqlBasicCall
    import org.apache.calcite.sql.SqlCall
    import org.apache.calcite.sql.SqlDataTypeSpec
    import org.apache.calcite.sql.SqlIdentifier
    import org.apache.calcite.sql.SqlIntervalQualifier
    import org.apache.calcite.sql.SqlKind
    import org.apache.calcite.sql.SqlNode
    import org.apache.calcite.sql.SqlNodeList
    import org.apache.calcite.sql.SqlNumericLiteral
    import org.apache.calcite.sql.SqlUserDefinedTypeNameSpec
    import org.apache.calcite.sql.`fun`.SqlCastFunction
    import org.apache.calcite.sql.`fun`.SqlStdOperatorTable
    import org.apache.calcite.sql.parser.SqlParser
    import org.apache.calcite.sql.parser.SqlParserUtil
    import org.apache.calcite.sql.util.SqlShuttle
    import org.apache.calcite.sql.validate.SqlValidator
    import org.apache.calcite.tools.Frameworks
    import java.sql.Connection
    import java.sql.DriverManager
    import java.sql.SQLException


    fun calcite() {
    val config = Frameworks.newConfigBuilder()
    .sqlValidatorConfig(SqlValidator.Config.DEFAULT)
    .parserConfig(SqlParser.config())
    .parserConfig(
    SqlParser.config()
    .withQuoting(Quoting.BRACKET)
    .withCaseSensitive(true)
    .withQuotedCasing(Casing.UNCHANGED)
    .withUnquotedCasing(Casing.UNCHANGED)
    )
    .defaultSchema(Frameworks.createRootSchema(true))
    .build()
    val planner = Frameworks.getPlanner(config)
    /*
    val sql = """
    SELECT
    "CAST"(
    DateAdd(yy,YEAR(tq_gnOjk2mbD.t) - 1904, DateAdd(mm,1 - 1, DateAdd(dd, 1 - 1, '1904-01-01')))
    )
    AS YsnsAls_0002
    CAST(
    DateAdd(yy,YEAR(tq_gnOjk2mbD.t) - 1904, DateAdd(mm,1 - 1, DateAdd(dd, 1 - 1, '1904-01-01'))) AS DateTime
    ) AS YsnsAls_0002
    FROM
    tq_gnOjk2mbD_cTable tq_gnOjk2mbD
    [tq_gnOjk2mbD_cTable] [tq_gnOjk2mb]
    """.trimIndent()
    */

    /*
    val sql = """
    WITH tq_gnOjk2mbD_cTable AS (
    SELECT *, from_unixtime([time]) as t
    FROM [sample_datasets].[www_access]
    )
    SELECT [YsnsAls_0002] FROM (
    SELECT
    ROW_NUMBER() OVER (ORDER BY (SELECT [TempTableQuerySchema].[YsnsAls_0002]) ASC) as rn,
    (CAST(DateAdd(yy,YEAR([TempTableQuerySchema].[YsnsAls_0002]) - 1904, DateAdd(mm,1 - 1, DateAdd(dd, 1 - 1, '1904-01-01'))) AS DateTime)) AS [YsnsAls_0002]
    FROM (
    SELECT ([TempTableQuerySchema].[YsnsAls_0002]) AS [YsnsAls_0002]
    FROM (
    SELECT ([TempTableQuerySchema].[YsnsAls_0002]) AS [YsnsAls_0002]
    FROM (
    SELECT ([TempTableQuerySchema].[YsnsAls_0002]) AS [YsnsAls_0002] FROM (
    SELECT (CAST(DateAdd(yy,YEAR([tq_gnOjk2mbD].[t]) - 1904, DateAdd(mm,1 - 1, DateAdd(dd, 1 - 1, '1904-01-01'))) AS DateTime)) AS [YsnsAls_0002]
    FROM [tq_gnOjk2mbD_cTable] [tq_gnOjk2mbD]
    ) AS [TempTableQuerySchema]
    ) AS [TempTableQuerySchema]
    GROUP BY [TempTableQuerySchema].[YsnsAls_0002]
    ) AS [TempTableQuerySchema]
    ) AS [TempTableQuerySchema]
    ) AS [TempTableQuerySchema]
    WHERE (rn > 0 AND rn <= 50001 )
    ORDER BY [TempTableQuerySchema].[YsnsAls_0002] ASC
    """.trimIndent()
    */

    val sql = """
    WITH cte AS (
    select 42 as i
    )
    SELECT i FROM (
    SELECT [i] from cte
    ) AS [alias]
    ORDER BY [alias].[i]
    """.trimIndent()

    val node = planner.parse(sql)
    println(node.accept(MySqlVisitor()))
    println(node)
    // FileWriter("/tmp/presto.sql").use {
    // it.write(node.accept(MySqlVisitor())!!.toSqlString(PostgresqlSqlDialect.DEFAULT).toString())
    // }
    }

    class MySqlVisitor : SqlShuttle() {
    override fun visit(call: SqlCall?): SqlNode? {
    val sqlCall = call!!
    if (sqlCall.kind == SqlKind.OTHER_FUNCTION
    && sqlCall.operator.name.uppercase() == "DATEADD") {
    val op = sqlCall.operator
    val operands: List<SqlNode> = visit(SqlNodeList.of(SqlParserPos.ZERO, sqlCall.operandList)) as SqlNodeList
    return SqlBasicCall(
    SqlFunction(
    "MY_DATA_ADD",
    op.kind,
    op.returnTypeInference,
    op.operandTypeInference,
    op.operandTypeChecker,
    SqlFunctionCategory.TIMEDATE
    ), operands, call.parserPosition)
    call!!
    /*
    println("<<<<<<<${call.kind}>>>>>>> ${call.operator} : ${call.operandList}")
    if (call.kind == SqlKind.CAST) {
    call.operandList.withIndex().forEach {
    println(">>>>>>>>>>>>>>>>>>>>>> ${it.index}: ${it.value.kind} : ${it.value}")
    }
    }
    */

    when (call.kind) {
    SqlKind.CAST -> {
    val origDstType = call.operand<SqlDataTypeSpec>(1)
    val dstType = if (origDstType.typeNameSpec.typeName.simple.uppercase() == "DATETIME") {
    SqlUserDefinedTypeNameSpec("TIMESTAMP", call.parserPosition)
    }
    else {
    origDstType.typeNameSpec
    }
    return SqlBasicCall(
    SqlCastFunction(),
    visit(SqlNodeList.of(call.operand(0), SqlDataTypeSpec(dstType, call.parserPosition))) as SqlNodeList,
    call.parserPosition
    )
    }
    SqlKind.OTHER_FUNCTION -> {
    if (call.operator.name.uppercase() == "DATEADD") {
    val intervalType = when (call.operand<SqlIdentifier>(0).simple.lowercase()) {
    "yy", "yyyy" -> SqlIntervalQualifier(TimeUnit.YEAR, TimeUnit.YEAR, call.parserPosition)
    "mm", "m" -> SqlIntervalQualifier(TimeUnit.MONTH, TimeUnit.MONTH, call.parserPosition)
    "dd", "d" -> SqlIntervalQualifier(TimeUnit.DAY, TimeUnit.DAY, call.parserPosition)
    "hh" -> SqlIntervalQualifier(TimeUnit.HOUR, TimeUnit.HOUR, call.parserPosition)
    "mi", "n" -> SqlIntervalQualifier(TimeUnit.MINUTE, TimeUnit.MINUTE, call.parserPosition)
    "ss", "s" -> SqlIntervalQualifier(TimeUnit.SECOND, TimeUnit.SECOND, call.parserPosition)
    else -> throw IllegalArgumentException("Unexpected identifier: ${call.operand<SqlCall>(0)}")
    }
    val origDiff = call.operand<SqlNode>(1)
    val diff =
    when (origDiff.kind) {
    SqlKind.LITERAL -> visit(origDiff as SqlNumericLiteral)
    else -> visit(origDiff as SqlBasicCall)
    }
    val origTarget = call.operand<SqlNode>(2)
    val target =
    when (origTarget.kind) {
    SqlKind.LITERAL -> visit(
    SqlParserUtil.parseTimestampLiteral(origTarget.toString(), call.parserPosition)
    )
    else -> visit(origTarget as SqlBasicCall)
    }
    return SqlBasicCall(
    SqlStdOperatorTable.DATETIME_PLUS,
    SqlNodeList.of(
    target,
    SqlBasicCall(
    SqlStdOperatorTable.INTERVAL,
    SqlNodeList.of(diff, intervalType),
    call.parserPosition
    )
    ),
    call.parserPosition
    )
    }
    }
    }
    return super.visit(call)
    }
  4. komamitsu created this gist Apr 27, 2022.
    40 changes: 40 additions & 0 deletions RewriteSampleWithCalcite.kotlin
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,40 @@
    fun calcite() {
    val config = Frameworks.newConfigBuilder()
    .sqlValidatorConfig(SqlValidator.Config.DEFAULT)
    .parserConfig(SqlParser.config())
    .defaultSchema(Frameworks.createRootSchema(true))
    .build()
    val planner = Frameworks.getPlanner(config)
    val sql = """
    SELECT
    "CAST"(
    DateAdd(yy,YEAR(tq_gnOjk2mbD.t) - 1904, DateAdd(mm,1 - 1, DateAdd(dd, 1 - 1, '1904-01-01')))
    )
    AS YsnsAls_0002
    FROM
    tq_gnOjk2mbD_cTable tq_gnOjk2mbD
    """.trimIndent()
    val node = planner.parse(sql)
    println(node.accept(MySqlVisitor()))
    }

    class MySqlVisitor : SqlShuttle() {
    override fun visit(call: SqlCall?): SqlNode? {
    val sqlCall = call!!
    if (sqlCall.kind == SqlKind.OTHER_FUNCTION
    && sqlCall.operator.name.uppercase() == "DATEADD") {
    val op = sqlCall.operator
    val operands: List<SqlNode> = visit(SqlNodeList.of(SqlParserPos.ZERO, sqlCall.operandList)) as SqlNodeList
    return SqlBasicCall(
    SqlFunction(
    "MY_DATA_ADD",
    op.kind,
    op.returnTypeInference,
    op.operandTypeInference,
    op.operandTypeChecker,
    SqlFunctionCategory.TIMEDATE
    ), operands, call.parserPosition)
    }
    return super.visit(call)
    }
    }