Skip to content

Instantly share code, notes, and snippets.

@juliolinarez
Forked from hadees/arel_helpers.rb
Created December 3, 2021 13:09
Show Gist options
  • Select an option

  • Save juliolinarez/6bca9abe5b3f239f653bcd39ea22f944 to your computer and use it in GitHub Desktop.

Select an option

Save juliolinarez/6bca9abe5b3f239f653bcd39ea22f944 to your computer and use it in GitHub Desktop.

Revisions

  1. @hadees hadees created this gist Mar 5, 2016.
    224 changes: 224 additions & 0 deletions arel_helpers.rb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,224 @@
    module ArelHelpers
    extend self

    def self.included(base)
    base.extend self
    end

    def asterisk(arel_table_or_model)
    arel_table, columns = case arel_table_or_model
    when Arel::Table
    [arel_table_or_model, arel_table_or_model.engine.columns]
    when ->(possible_model) { UtilitiesHelper.is_model?(possible_model) }
    [arel_table_or_model.arel_table, arel_table_or_model.columns]
    else
    raise ArgumentError, "Must pass in an arel table or model"
    end
    columns.map { |c| arel_table[c.name] }
    end

    def greatest(*args)
    Arel::Nodes::NamedFunction.new "greatest", args
    end

    def least(*args)
    Arel::Nodes::NamedFunction.new "least", args
    end

    def cast(pred, type)
    Arel::Nodes::NamedFunction.new "cast", [pred.as(type)]
    end

    def null_if(column, value)
    Arel::Nodes::NamedFunction.new "NULLIF", [column, value]
    end

    def predicate(pred, true_value, false_value)
    Arel::Nodes::SqlLiteral.new("CASE WHEN #{sqlv(pred)} THEN #{sqlv(true_value)} ELSE #{sqlv(false_value)} END")
    end

    def tsrange(lower_or_range, upper = nil)
    Arel::Nodes::NamedFunction.new "tsrange", range_params(lower_or_range, upper)
    end

    def tstzrange(lower_or_range, upper = nil)
    Arel::Nodes::NamedFunction.new "tstzrange", range_params(lower_or_range, upper)
    end

    def overlap(a, b)
    Arel::Nodes::InfixOperation.new "&&", a, b
    end

    def coalesce(*args)
    Arel::Nodes::NamedFunction.new "coalesce", args
    end

    def hstore_key(hstore, key)
    Arel::Nodes::InfixOperation.new "->", hstore, cloneable(key)
    end

    def concat(*args)
    Arel::Nodes::NamedFunction.new "concat", args
    end

    def mod(a, b)
    Arel::Nodes::InfixOperation.new "%", a, b
    end

    def to_char(input, format)
    Arel::Nodes::NamedFunction.new "to_char", [input, format]
    end

    def string_agg(input, delimiter)
    Arel::Nodes::NamedFunction.new "string_agg", [input, delimiter]
    end

    def between(pred, lower_or_range, upper = nil)
    Arel::Nodes::Between.new(pred, Arel::Nodes::And.new(range_params(lower_or_range, upper)))
    end

    def unnest(array)
    Arel::Nodes::NamedFunction.new "unnest", [array]
    end

    def array_agg(expression)
    Arel::Nodes::NamedFunction.new "array_agg", [expression]
    end

    def lower(expression)
    Arel::Nodes::NamedFunction.new "lower", [expression]
    end

    def accumulative_or(array)
    array.inject do |expressions, expression|
    if expressions === expression
    expression
    else
    expressions.or(expression)
    end
    end
    end

    def array_intersect(a1, a2, opts = {})
    select1 = unnest(sqlv(a1))
    select2 = unnest(sqlv(a2))

    if !opts[:case_sensitive]
    select1 = lower cast(select1, "text")
    select2 = lower cast(select2, "text")
    end

    Arel::Nodes::SqlLiteral.new <<-SQL
    ARRAY(
    SELECT #{sqlv(select1)} INTERSECT
    SELECT #{sqlv(select2)}
    )
    SQL
    end

    def descendants_search(table, id, max_depth: 999)
    tree_sql = Arel::Nodes::SqlLiteral.new <<-SQL
    WITH RECURSIVE descendants_search(id, path) AS (
    SELECT id, ARRAY[id]
    FROM #{table.name}
    WHERE id = #{id}
    UNION ALL
    SELECT #{table.name}.id, (path || #{table.name}.id)
    FROM descendants_search
    JOIN #{table.name}
    ON descendants_search.id = #{table.name}.reports_to_id
    WHERE NOT #{table.name}.id = ANY(path)
    AND NOT array_length(path,1) > #{max_depth}
    )
    SELECT id
    FROM descendants_search
    WHERE id != #{id}
    ORDER BY array_length(path, 1), path
    SQL
    table[:id].in(tree_sql)
    end

    def ancestor_search(table, id)
    tree_sql = Arel::Nodes::SqlLiteral.new <<-SQL
    WITH RECURSIVE ancestor_search(id, reports_to_id, path) AS (
    SELECT id, reports_to_id, ARRAY[id]
    FROM #{table.name}
    WHERE id = #{id}
    UNION ALL
    SELECT #{table.name}.id, #{table.name}.reports_to_id, (path || #{table.name}.id)
    FROM ancestor_search
    JOIN #{table.name}
    ON ancestor_search.reports_to_id = #{table.name}.id
    WHERE NOT #{table.name}.id = ANY(path)
    )
    SELECT id
    FROM ancestor_search
    WHERE id != #{id}
    ORDER BY array_length(path, 1), path
    SQL
    table[:id].in(tree_sql)
    end

    def sqlv(node)
    case node
    when ->(n) { n.respond_to?(:to_sql) }
    node.to_sql
    when Arel::Attributes::Attribute
    Arel::Nodes::SqlLiteral.new "\"#{node.relation.name}\".\"#{node.name}\""
    when Array, Range
    value = node.map { |x| x.is_a?(String) ? "'#{x}'" : x }.join(",")
    Arel::Nodes::SqlLiteral.new "ARRAY[#{value}]"
    when Time, DateTime, Date
    Arel::Nodes.build_quoted node
    when String
    Arel::Nodes.build_quoted node
    else
    Arel::Nodes::SqlLiteral.new node.to_s
    end
    end

    def array_agg(expression)
    Arel::Nodes::NamedFunction.new "array_agg", [expression]
    end

    def between(pred, lower_or_range, upper = nil)
    Arel::Nodes::Between.new(pred, Arel::Nodes::And.new(range_params(lower_or_range, upper)))
    end

    # This is a special ordering SQL used inside methods like array_agg
    # http://www.postgresql.org/docs/current/static/sql-expressions.html#SYNTAX-AGGREGATES
    def order_by(a, b)
    Arel::Nodes::SqlLiteral.new "#{sqlv(a)} ORDER BY #{sqlv(b)}"
    end

    def range_params(lower_or_range, upper = nil)
    case lower_or_range
    when Range
    lower = lower_or_range.min
    upper = lower_or_range.max
    else
    lower = lower_or_range
    end
    [sqlv(lower), sqlv(upper)]
    end

    def cloneable(obj)
    case obj
    when Symbol
    Arel::Nodes.build_quoted obj.to_s
    else
    obj
    end
    end

    def self.sort(node, order)
    case order.try(:to_sym)
    when :asc
    Arel::Nodes::Ascending.new node
    when :desc
    Arel::Nodes::Descending.new node
    else
    raise ArgumentError, "Must pass in either :asc or :desc"
    end
    end
    end