Skip to content

Instantly share code, notes, and snippets.

@yahorbarkouski
Last active July 10, 2024 01:26
Show Gist options
  • Select an option

  • Save yahorbarkouski/bcfbf2bf1b10ff7757f2c629eab33a46 to your computer and use it in GitHub Desktop.

Select an option

Save yahorbarkouski/bcfbf2bf1b10ff7757f2c629eab33a46 to your computer and use it in GitHub Desktop.

Revisions

  1. yahorbarkouski revised this gist Nov 30, 2023. 1 changed file with 22 additions and 17 deletions.
    39 changes: 22 additions & 17 deletions PGVectorBinding.kt
    Original file line number Diff line number Diff line change
    @@ -1,13 +1,6 @@
    import org.jooq.Binding
    import org.jooq.BindingGetResultSetContext
    import org.jooq.BindingGetSQLInputContext
    import org.jooq.BindingGetStatementContext
    import org.jooq.BindingRegisterContext
    import org.jooq.BindingSQLContext
    import org.jooq.BindingSetSQLOutputContext
    import org.jooq.BindingSetStatementContext
    import org.jooq.Converter
    import org.jooq.*
    import org.jooq.impl.DSL
    import java.sql.SQLFeatureNotSupportedException
    import java.sql.Types

    @Suppress("UNCHECKED_CAST")
    @@ -31,6 +24,14 @@ class PGVectorBinding : Binding<Any, List<Double>> {
    }
    }

    override fun sql(ctx: BindingSQLContext<List<Double>>) {
    ctx.render().visit(DSL.`val`(ctx.convert(converter()).value())).sql("::vector")
    }

    override fun register(ctx: BindingRegisterContext<List<Double>>) {
    ctx.statement().registerOutParameter(ctx.index(), Types.ARRAY)
    }

    override fun get(ctx: BindingGetResultSetContext<List<Double>>) {
    val resultSet = ctx.resultSet()
    val vectorAsString = resultSet.getString(ctx.index())
    @@ -42,15 +43,19 @@ class PGVectorBinding : Binding<Any, List<Double>> {
    ctx.statement().setString(ctx.index(), value?.let { converter().to(it) as String } ?: "[]")
    }

    override fun register(ctx: BindingRegisterContext<List<Double>>) {
    ctx.statement().registerOutParameter(ctx.index(), Types.ARRAY)
    override fun get(ctx: BindingGetStatementContext<List<Double>>) {
    val statement = ctx.statement()
    val vectorAsString = statement.getString(ctx.index())
    ctx.value(converter().from(vectorAsString))
    }

    override fun sql(ctx: BindingSQLContext<List<Double>>) {
    ctx.render().visit(DSL.`val`(ctx.convert(converter()).value())).sql("::vector")
    // the below methods aren't needed in Postgres:

    override fun get(ctx: BindingGetSQLInputContext<List<Double>>?) {
    throw SQLFeatureNotSupportedException()
    }

    override fun get(ctx: BindingGetSQLInputContext<List<Double>>?) {}
    override fun get(ctx: BindingGetStatementContext<List<Double>>?) {}
    override fun set(ctx: BindingSetSQLOutputContext<List<Double>>?) {}
    }
    override fun set(ctx: BindingSetSQLOutputContext<List<Double>>?) {
    throw SQLFeatureNotSupportedException()
    }
    }
  2. yahorbarkouski created this gist Nov 30, 2023.
    56 changes: 56 additions & 0 deletions PGVectorBinding.kt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,56 @@
    import org.jooq.Binding
    import org.jooq.BindingGetResultSetContext
    import org.jooq.BindingGetSQLInputContext
    import org.jooq.BindingGetStatementContext
    import org.jooq.BindingRegisterContext
    import org.jooq.BindingSQLContext
    import org.jooq.BindingSetSQLOutputContext
    import org.jooq.BindingSetStatementContext
    import org.jooq.Converter
    import org.jooq.impl.DSL
    import java.sql.Types

    @Suppress("UNCHECKED_CAST")
    class PGVectorBinding : Binding<Any, List<Double>> {

    override fun converter(): Converter<Any, List<Double>> {
    return object : Converter<Any, List<Double>> {
    override fun from(databaseObject: Any?): List<Double> {
    return databaseObject?.let { v ->
    v.toString().removeSurrounding("[", "]").split(",").map { it.toDouble() }
    } ?: emptyList()
    }

    override fun to(userObject: List<Double>): Any {
    return userObject.toString()
    }

    override fun fromType(): Class<Any> = Any::class.java

    override fun toType(): Class<List<Double>> = List::class.java as Class<List<Double>>
    }
    }

    override fun get(ctx: BindingGetResultSetContext<List<Double>>) {
    val resultSet = ctx.resultSet()
    val vectorAsString = resultSet.getString(ctx.index())
    ctx.value(converter().from(vectorAsString))
    }

    override fun set(ctx: BindingSetStatementContext<List<Double>>) {
    val value = ctx.value()
    ctx.statement().setString(ctx.index(), value?.let { converter().to(it) as String } ?: "[]")
    }

    override fun register(ctx: BindingRegisterContext<List<Double>>) {
    ctx.statement().registerOutParameter(ctx.index(), Types.ARRAY)
    }

    override fun sql(ctx: BindingSQLContext<List<Double>>) {
    ctx.render().visit(DSL.`val`(ctx.convert(converter()).value())).sql("::vector")
    }

    override fun get(ctx: BindingGetSQLInputContext<List<Double>>?) {}
    override fun get(ctx: BindingGetStatementContext<List<Double>>?) {}
    override fun set(ctx: BindingSetSQLOutputContext<List<Double>>?) {}
    }