# Helpers for building Arel nodes and raw SQL fragments (mostly
# PostgreSQL-flavoured functions and operators).
#
# The module extends itself, so every instance method below can be called as
# `ArelHelpers.foo(...)`. Including the module also extends the including
# class/module with the same methods (see `included`).
module ArelHelpers
  extend self

  # Hook run on `include ArelHelpers`: also exposes the helpers on the
  # including class/module itself.
  def self.included(base)
    base.extend self
  end

  # Returns one arel attribute per column of the given table or model.
  #
  # @param arel_table_or_model an Arel::Table, or anything for which
  #   UtilitiesHelper.is_model? returns true
  # @raise [ArgumentError] when the argument is neither
  def asterisk(arel_table_or_model)
    arel_table, columns =
      case arel_table_or_model
      when Arel::Table
        [arel_table_or_model, arel_table_or_model.engine.columns]
      when ->(possible_model) { UtilitiesHelper.is_model?(possible_model) }
        [arel_table_or_model.arel_table, arel_table_or_model.columns]
      else
        raise ArgumentError, "Must pass in an arel table or model"
      end

    columns.map { |c| arel_table[c.name] }
  end

  # greatest(args...)
  def greatest(*args)
    Arel::Nodes::NamedFunction.new "greatest", args
  end

  # least(args...)
  def least(*args)
    Arel::Nodes::NamedFunction.new "least", args
  end

  # cast(pred AS type); relies on `pred.as(type)` to build the alias node.
  def cast(pred, type)
    Arel::Nodes::NamedFunction.new "cast", [pred.as(type)]
  end

  # NULLIF(column, value)
  def null_if(column, value)
    Arel::Nodes::NamedFunction.new "NULLIF", [column, value]
  end

  # CASE WHEN pred THEN true_value ELSE false_value END, as a SQL literal.
  # Each argument is passed through `sqlv` and interpolated into the string.
  def predicate(pred, true_value, false_value)
    Arel::Nodes::SqlLiteral.new(
      "CASE WHEN #{sqlv(pred)} THEN #{sqlv(true_value)} ELSE #{sqlv(false_value)} END"
    )
  end

  # tsrange(lower, upper); accepts either two bounds or a single Range.
  def tsrange(lower_or_range, upper = nil)
    Arel::Nodes::NamedFunction.new "tsrange", range_params(lower_or_range, upper)
  end

  # tstzrange(lower, upper); accepts either two bounds or a single Range.
  def tstzrange(lower_or_range, upper = nil)
    Arel::Nodes::NamedFunction.new "tstzrange", range_params(lower_or_range, upper)
  end

  # a && b
  def overlap(a, b)
    Arel::Nodes::InfixOperation.new "&&", a, b
  end

  # coalesce(args...)
  def coalesce(*args)
    Arel::Nodes::NamedFunction.new "coalesce", args
  end

  # hstore -> key; Symbol keys are converted to quoted strings by `cloneable`.
  def hstore_key(hstore, key)
    Arel::Nodes::InfixOperation.new "->", hstore, cloneable(key)
  end

  # concat(args...)
  def concat(*args)
    Arel::Nodes::NamedFunction.new "concat", args
  end

  # a % b
  def mod(a, b)
    Arel::Nodes::InfixOperation.new "%", a, b
  end

  # to_char(input, format)
  def to_char(input, format)
    Arel::Nodes::NamedFunction.new "to_char", [input, format]
  end

  # string_agg(input, delimiter)
  def string_agg(input, delimiter)
    Arel::Nodes::NamedFunction.new "string_agg", [input, delimiter]
  end

  # pred BETWEEN lower AND upper; accepts either two bounds or a single Range.
  def between(pred, lower_or_range, upper = nil)
    Arel::Nodes::Between.new(
      pred,
      Arel::Nodes::And.new(range_params(lower_or_range, upper))
    )
  end

  # unnest(array)
  def unnest(array)
    Arel::Nodes::NamedFunction.new "unnest", [array]
  end

  # array_agg(expression)
  def array_agg(expression)
    Arel::Nodes::NamedFunction.new "array_agg", [expression]
  end

  # lower(expression)
  def lower(expression)
    Arel::Nodes::NamedFunction.new "lower", [expression]
  end

  # Folds the given expressions together with `.or`.
  #
  # When the accumulated expression is `===` the next one, the next one is
  # kept as-is instead of being OR-ed in. With an empty array, `inject`
  # returns nil; with a single element, that element is returned.
  def accumulative_or(array)
    array.inject do |expressions, expression|
      if expressions === expression
        expression
      else
        expressions.or(expression)
      end
    end
  end

  # Builds `ARRAY(SELECT unnest(a1) INTERSECT SELECT unnest(a2))`.
  #
  # Unless `opts[:case_sensitive]` is truthy, both sides are wrapped in
  # `lower(cast(... AS text))` before being intersected.
  def array_intersect(a1, a2, opts = {})
    select1 = unnest(sqlv(a1))
    select2 = unnest(sqlv(a2))

    unless opts[:case_sensitive]
      select1 = lower cast(select1, "text")
      select2 = lower cast(select2, "text")
    end

    Arel::Nodes::SqlLiteral.new <<-SQL
      ARRAY(
        SELECT #{sqlv(select1)}
        INTERSECT
        SELECT #{sqlv(select2)}
      )
    SQL
  end

  # Returns `table[:id] IN (<recursive CTE>)`, where the CTE starts at the row
  # with the given id and repeatedly joins rows whose `reports_to_id` equals
  # an already-visited id. The starting id itself is excluded from the result.
  #
  # The `path` array guards against cycles, and expansion stops once the path
  # is longer than `max_depth`.
  #
  # NOTE(review): `table.name`, `id` and `max_depth` are interpolated into the
  # SQL verbatim, with no quoting or type coercion -- callers must not pass
  # untrusted values.
  def descendants_search(table, id, max_depth: 999)
    tree_sql = Arel::Nodes::SqlLiteral.new <<-SQL
      WITH RECURSIVE descendants_search(id, path) AS (
        SELECT id, ARRAY[id]
        FROM #{table.name}
        WHERE id = #{id}
        UNION ALL
        SELECT #{table.name}.id, (path || #{table.name}.id)
        FROM descendants_search
        JOIN #{table.name} ON descendants_search.id = #{table.name}.reports_to_id
        WHERE NOT #{table.name}.id = ANY(path)
          AND NOT array_length(path,1) > #{max_depth}
      )
      SELECT id FROM descendants_search
      WHERE id != #{id}
      ORDER BY array_length(path, 1), path
    SQL

    table[:id].in(tree_sql)
  end

  # Returns `table[:id] IN (<recursive CTE>)`, where the CTE starts at the row
  # with the given id and repeatedly follows `reports_to_id` to the row it
  # points at. The starting id itself is excluded from the result.
  #
  # The `path` array guards against cycles.
  #
  # NOTE(review): `table.name` and `id` are interpolated into the SQL
  # verbatim, with no quoting or type coercion -- callers must not pass
  # untrusted values.
  def ancestor_search(table, id)
    tree_sql = Arel::Nodes::SqlLiteral.new <<-SQL
      WITH RECURSIVE ancestor_search(id, reports_to_id, path) AS (
        SELECT id, reports_to_id, ARRAY[id]
        FROM #{table.name}
        WHERE id = #{id}
        UNION ALL
        SELECT #{table.name}.id, #{table.name}.reports_to_id, (path || #{table.name}.id)
        FROM ancestor_search
        JOIN #{table.name} ON ancestor_search.reports_to_id = #{table.name}.id
        WHERE NOT #{table.name}.id = ANY(path)
      )
      SELECT id FROM ancestor_search
      WHERE id != #{id}
      ORDER BY array_length(path, 1), path
    SQL

    table[:id].in(tree_sql)
  end

  # Converts a value into something usable as a SQL fragment.
  #
  # Branches are tried in order:
  # * anything responding to `to_sql`  -> the result of `to_sql`
  # * Arel attribute                   -> `"table"."column"` literal
  # * Array / Range                    -> `ARRAY[...]` literal; String
  #                                       elements are single-quoted
  # * Time / DateTime / Date / String  -> `Arel::Nodes.build_quoted`
  # * anything else                    -> literal of `to_s`
  def sqlv(node)
    case node
    when ->(n) { n.respond_to?(:to_sql) }
      node.to_sql
    when Arel::Attributes::Attribute
      Arel::Nodes::SqlLiteral.new "\"#{node.relation.name}\".\"#{node.name}\""
    when Array, Range
      elements = node.map do |x|
        # Double embedded single quotes so a value such as O'Brien cannot
        # terminate the SQL string literal early.
        x.is_a?(String) ? "'#{x.gsub("'", "''")}'" : x
      end
      Arel::Nodes::SqlLiteral.new "ARRAY[#{elements.join(",")}]"
    when Time, DateTime, Date
      Arel::Nodes.build_quoted node
    when String
      Arel::Nodes.build_quoted node
    else
      Arel::Nodes::SqlLiteral.new node.to_s
    end
  end

  # This is a special ordering SQL used inside methods like array_agg
  # http://www.postgresql.org/docs/current/static/sql-expressions.html#SYNTAX-AGGREGATES
  def order_by(a, b)
    Arel::Nodes::SqlLiteral.new "#{sqlv(a)} ORDER BY #{sqlv(b)}"
  end

  # Normalizes "two bounds or one Range" into `[sqlv(lower), sqlv(upper)]`.
  #
  # For a Range the bounds come from `Range#min` / `Range#max` and any
  # explicit `upper` argument is ignored.
  # NOTE(review): `min`/`max` are not always the same as `begin`/`end`
  # (e.g. for end-exclusive ranges) -- confirm that is what callers expect.
  def range_params(lower_or_range, upper = nil)
    case lower_or_range
    when Range
      lower = lower_or_range.min
      upper = lower_or_range.max
    else
      lower = lower_or_range
    end

    [sqlv(lower), sqlv(upper)]
  end

  # Turns a Symbol into a quoted string node; any other object is returned
  # untouched.
  def cloneable(obj)
    case obj
    when Symbol
      Arel::Nodes.build_quoted obj.to_s
    else
      obj
    end
  end

  # Wraps `node` in an ascending or descending ordering node.
  #
  # Defined on the module only (`ArelHelpers.sort`), unlike the helpers above.
  #
  # @param order :asc / :desc (or anything whose `to_sym` yields one of them)
  # @raise [ArgumentError] for any other order
  def self.sort(node, order)
    case order.try(:to_sym)
    when :asc
      Arel::Nodes::Ascending.new node
    when :desc
      Arel::Nodes::Descending.new node
    else
      raise ArgumentError, "Must pass in either :asc or :desc"
    end
  end
end