Skip to content

Instantly share code, notes, and snippets.

@anthonny
Forked from longcao/SparkCopyPostgres.scala
Created August 8, 2017 07:03
Show Gist options
  • Save anthonny/817517f37f4977e8daae279d38e83bb0 to your computer and use it in GitHub Desktop.
Save anthonny/817517f37f4977e8daae279d38e83bb0 to your computer and use it in GitHub Desktop.

Revisions

  1. @longcao longcao revised this gist Jan 4, 2017. 1 changed file with 6 additions and 4 deletions.
    10 changes: 6 additions & 4 deletions SparkCopyPostgres.scala
    Original file line number Diff line number Diff line change
    @@ -26,10 +26,12 @@ def rowsToInputStream(rows: Iterator[Row], delimiter: String): InputStream = {
    (row.mkString(delimiter) + "\n").getBytes
    }.flatten

    override def read(): Int = if (bytes.hasNext) {
    bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
    } else {
    -1
    new InputStream {
    override def read(): Int = if (bytes.hasNext) {
    bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
    } else {
    -1
    }
    }
    }

  2. @longcao longcao revised this gist Jul 11, 2016. 1 changed file with 4 additions and 2 deletions.
    6 changes: 4 additions & 2 deletions SparkCopyPostgres.scala
    Original file line number Diff line number Diff line change
    @@ -26,8 +26,10 @@ def rowsToInputStream(rows: Iterator[Row], delimiter: String): InputStream = {
    (row.mkString(delimiter) + "\n").getBytes
    }.flatten

    new InputStream {
    override def read(): Int = if (bytes.hasNext) bytes.next.toInt else -1
    override def read(): Int = if (bytes.hasNext) {
    bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
    } else {
    -1
    }
    }

  3. @longcao longcao revised this gist Jul 11, 2016. 1 changed file with 1 addition and 0 deletions.
    1 change: 1 addition & 0 deletions SparkCopyPostgres.scala
    Original file line number Diff line number Diff line change
    @@ -31,6 +31,7 @@ def rowsToInputStream(rows: Iterator[Row], delimiter: String): InputStream = {
    }
    }

    // Beware: this will open a db connection for every partition of your DataFrame.
    frame.foreachPartition { rows =>
    val conn = cf()
    val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
  4. @longcao longcao revised this gist Jul 11, 2016. No changes.
  5. @longcao longcao created this gist Jul 11, 2016.
    43 changes: 43 additions & 0 deletions SparkCopyPostgres.scala
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,43 @@
    import java.io.InputStream

    import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
    import org.apache.spark.sql.{ DataFrame, Row }

    import org.postgresql.copy.CopyManager
    import org.postgresql.core.BaseConnection

    val jdbcUrl = s"jdbc:postgresql://..." // db credentials elided
    val connectionProperties = {
    val props = new java.util.Properties()

    props.setProperty("driver", "org.postgresql.Driver")

    props
    }

    // Spark reads the "driver" property to allow users to override the default driver selected, otherwise
    // it picks the Redshift driver, which doesn't support JDBC CopyManager.
    // https://github.com/apache/spark/blob/v1.6.1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala#L44-51
    val cf: () => Connection = JdbcUtils.createConnectionFactory(jdbcUrl, connectionProperties)

    // Convert every partition (an `Iterator[Row]`) to bytes (InputStream)
    def rowsToInputStream(rows: Iterator[Row], delimiter: String): InputStream = {
    val bytes: Iterator[Byte] = rows.map { row =>
    (row.mkString(delimiter) + "\n").getBytes
    }.flatten

    new InputStream {
    override def read(): Int = if (bytes.hasNext) bytes.next.toInt else -1
    }
    }

    frame.foreachPartition { rows =>
    val conn = cf()
    val cm = new CopyManager(conn.asInstanceOf[BaseConnection])

    cm.copyIn(
    """COPY my_schema._mytable FROM STDIN WITH (NULL 'null', FORMAT CSV, DELIMITER E'\t')""", // adjust COPY settings as you desire, options from https://www.postgresql.org/docs/9.5/static/sql-copy.html
    rowsToInputStream(rows, "\t"))

    conn.close()
    }